diff --git a/fastapi_crudrouter/core/_base.py b/fastapi_crudrouter/core/_base.py index f6379e8..e45d33f 100644 --- a/fastapi_crudrouter/core/_base.py +++ b/fastapi_crudrouter/core/_base.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Any, Callable, Generic, List, Optional, Type, Union from fastapi import APIRouter, HTTPException @@ -9,7 +10,7 @@ NOT_FOUND = HTTPException(404, "Item not found") -class CRUDGenerator(Generic[T], APIRouter): +class CRUDGenerator(Generic[T], APIRouter, ABC): schema: Type[T] create_schema: Type[T] update_schema: Type[T] @@ -176,21 +177,27 @@ def remove_api_route(self, path: str, methods: List[str]) -> None: ): self.routes.remove(route) + @abstractmethod def _get_all(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod def _get_one(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod def _create(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod def _update(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod def _delete_one(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod def _delete_all(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index 118047b..ef3562e 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -33,7 +33,7 @@ def schema_factory( } name = schema_cls.__name__ + name - schema = create_model(__model_name=name, **fields) # type: ignore + schema: Type[T] = create_model(__model_name=name, **fields) # type: ignore return schema diff --git a/tests/test_base.py b/tests/test_base.py index 9d44667..399ec14 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,27 +1,49 @@ +from abc import ABC +from typing import Type + import pytest from fastapi import APIRouter, FastAPI + from fastapi_crudrouter import ( GinoCRUDRouter, MemoryCRUDRouter, OrmarCRUDRouter, SQLAlchemyCRUDRouter, + DatabasesCRUDRouter, ) # noinspection PyProtectedMember from fastapi_crudrouter.core._base import CRUDGenerator - from tests import Potato -def test_router_type(): - assert issubclass(CRUDGenerator, APIRouter) - assert issubclass(SQLAlchemyCRUDRouter, APIRouter) - assert issubclass(MemoryCRUDRouter, APIRouter) - assert issubclass(OrmarCRUDRouter, APIRouter) - assert issubclass(GinoCRUDRouter, APIRouter) +@pytest.fixture( + params=[ + GinoCRUDRouter, + SQLAlchemyCRUDRouter, + MemoryCRUDRouter, + OrmarCRUDRouter, + GinoCRUDRouter, + DatabasesCRUDRouter, + ] +) +def subclass(request) -> Type[CRUDGenerator]: + return request.param + + +def test_router_is_subclass_of_crud_generator(subclass): + assert issubclass(subclass, CRUDGenerator) -def test_get_one(): +def test_router_is_subclass_of_api_router(subclass): + assert issubclass(subclass, APIRouter) + + +def test_base_class_is_abstract(): + assert issubclass(CRUDGenerator, ABC) + + +def test_raise_not_implemented(): app = FastAPI() def foo(*args, **kwargs): @@ -30,14 +52,10 @@ def bar(): return bar - foo()() - methods = CRUDGenerator.get_routes() for m in methods: - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError): app.include_router(CRUDGenerator(schema=Potato)) setattr(CRUDGenerator, f"_{m}", foo) - - app.include_router(CRUDGenerator(schema=Potato))