diff --git a/fastapi_crudrouter/core/_base.py b/fastapi_crudrouter/core/_base.py index 376d54e..e45d33f 100644 --- a/fastapi_crudrouter/core/_base.py +++ b/fastapi_crudrouter/core/_base.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Callable, Generic, List, Optional, Type, Union from fastapi import APIRouter, HTTPException @@ -10,7 +10,7 @@ NOT_FOUND = HTTPException(404, "Item not found") -class CRUDGenerator(Generic[T], APIRouter): +class CRUDGenerator(Generic[T], APIRouter, ABC): schema: Type[T] create_schema: Type[T] update_schema: Type[T] diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index 118047b..ef3562e 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -33,7 +33,7 @@ def schema_factory( } name = schema_cls.__name__ + name - schema = create_model(__model_name=name, **fields) # type: ignore + schema: Type[T] = create_model(__model_name=name, **fields) # type: ignore return schema diff --git a/tests/test_base.py b/tests/test_base.py index ace1c40..399ec14 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,3 +1,6 @@ +from abc import ABC +from typing import Type + import pytest from fastapi import APIRouter, FastAPI @@ -14,13 +17,30 @@ from tests import Potato -def test_router_type(): - assert issubclass(CRUDGenerator, APIRouter) - assert issubclass(SQLAlchemyCRUDRouter, APIRouter) - assert issubclass(MemoryCRUDRouter, APIRouter) - assert issubclass(OrmarCRUDRouter, APIRouter) - assert issubclass(GinoCRUDRouter, APIRouter) - assert issubclass(DatabasesCRUDRouter, APIRouter) +@pytest.fixture( + params=[ + GinoCRUDRouter, + SQLAlchemyCRUDRouter, + MemoryCRUDRouter, + OrmarCRUDRouter, + GinoCRUDRouter, + DatabasesCRUDRouter, + ] +) +def subclass(request) -> Type[CRUDGenerator]: + return request.param + + +def test_router_is_subclass_of_crud_generator(subclass): + assert issubclass(subclass, CRUDGenerator) + + +def test_router_is_subclass_of_api_router(subclass): + assert issubclass(subclass, APIRouter) + + +def test_base_class_is_abstract(): + assert issubclass(CRUDGenerator, ABC) def test_raise_not_implemented(): @@ -35,9 +55,7 @@ def bar(): methods = CRUDGenerator.get_routes() for m in methods: - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError): app.include_router(CRUDGenerator(schema=Potato)) setattr(CRUDGenerator, f"_{m}", foo) - - app.include_router(CRUDGenerator(schema=Potato))