diff --git a/piccolo_api/crud/endpoints.py b/piccolo_api/crud/endpoints.py index 847e6e5c..026b8547 100644 --- a/piccolo_api/crud/endpoints.py +++ b/piccolo_api/crud/endpoints.py @@ -24,7 +24,6 @@ from piccolo.query.methods.select import Select from piccolo.table import Table from piccolo.utils.encoding import dump_json -from pydantic.error_wrappers import ValidationError from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, Router @@ -38,7 +37,7 @@ ) from .exceptions import MalformedQuery, db_exception_handler -from .serializers import Config, create_pydantic_model +from .serializers import create_pydantic_model from .validators import Validators, apply_validators if t.TYPE_CHECKING: # pragma: no cover @@ -307,7 +306,7 @@ def _pydantic_model_output( self, include_readable: bool = False, include_columns: t.Tuple[Column, ...] = (), - nested: t.Union[bool, t.Tuple[Column, ...]] = False, + nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False, ) -> t.Type[pydantic.BaseModel]: return create_pydantic_model( self.table, @@ -343,8 +342,8 @@ def pydantic_model_plural( self, include_readable=False, include_columns: t.Tuple[Column, ...] = (), - nested: t.Union[bool, t.Tuple[Column, ...]] = False, - ): + nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False, + ) -> t.Type[pydantic.BaseModel]: """ This is for when we want to serialise many copies of the model. """ @@ -358,7 +357,9 @@ def pydantic_model_plural( ) return pydantic.create_model( str(self.table.__name__) + "Plural", - __config__=Config, + __config__=pydantic.config.ConfigDict( + arbitrary_types_allowed=True + ), rows=(t.List[base_model], None), ) @@ -367,7 +368,7 @@ async def get_schema(self, request: Request) -> JSONResponse: """ Return a representation of the model, so a UI can generate a form. """ - return JSONResponse(self.pydantic_model.schema()) + return JSONResponse(self.pydantic_model.model_json_schema()) ########################################################################### @@ -713,7 +714,7 @@ def _apply_filters( """ fields = params.fields if fields: - model_dict = self.pydantic_model_optional(**fields).dict() + model_dict = self.pydantic_model_optional(**fields).model_dump() for field_name in fields.keys(): value = model_dict.get(field_name, ...) if value is ...: @@ -778,7 +779,9 @@ async def get_all( nested: t.Union[bool, t.Tuple[Column, ...]] if visible_fields: nested = tuple( - i for i in visible_fields if len(i._meta.call_chain) > 0 + i._meta.call_chain[-1] + for i in visible_fields + if len(i._meta.call_chain) > 0 ) else: visible_fields = self.table._meta.columns @@ -865,7 +868,7 @@ async def get_all( include_readable=include_readable, include_columns=tuple(visible_fields), nested=nested, - )(rows=rows).json() + )(rows=rows).model_dump_json() return CustomJSONResponse(json, headers=headers) ########################################################################### @@ -894,19 +897,19 @@ async def post_single( cleaned_data = self._clean_data(data) try: model = self.pydantic_model(**cleaned_data) - except ValidationError as exception: + except pydantic.ValidationError as exception: return Response(str(exception), status_code=400) if issubclass(self.table, BaseUser): try: - user = await self.table.create_user(**model.dict()) + user = await self.table.create_user(**model.model_dump()) json = dump_json({"id": user.id}) return CustomJSONResponse(json, status_code=201) except Exception as e: return Response(f"Error: {e}", status_code=400) else: try: - row = self.table(**model.dict()) + row = self.table(**model.model_dump()) if self._hook_map: row = await execute_post_hooks( hooks=self._hook_map, @@ -969,7 +972,7 @@ async def get_new(self, request: Request) -> CustomJSONResponse: row_dict.pop(column_name) return CustomJSONResponse( - self.pydantic_model_optional(**row_dict).json() + self.pydantic_model_optional(**row_dict).model_dump_json() ) ########################################################################### @@ -1053,11 +1056,13 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response: return Response(str(exception), status_code=400) # Visible fields - nested: t.Union[bool, t.Tuple[Column, ...]] + nested: t.Union[bool, t.Tuple[ForeignKey, ...]] visible_fields = split_params.visible_fields if visible_fields: nested = tuple( - i for i in visible_fields if len(i._meta.call_chain) > 0 + i._meta.call_chain[-1] + for i in visible_fields + if len(i._meta.call_chain) > 0 ) else: visible_fields = self.table._meta.columns @@ -1098,7 +1103,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response: include_readable=split_params.include_readable, include_columns=tuple(visible_fields), nested=nested, - )(**row).json() + )(**row).model_dump_json() ) @apply_validators @@ -1113,7 +1118,7 @@ async def put_single( try: model = self.pydantic_model(**cleaned_data) - except ValidationError as exception: + except pydantic.ValidationError as exception: return Response(str(exception), status_code=400) cls = self.table @@ -1123,7 +1128,6 @@ async def put_single( } try: - await cls.update(values).where( cls._meta.primary_key == row_id ).run() @@ -1143,7 +1147,7 @@ async def patch_single( try: model = self.pydantic_model_optional(**cleaned_data) - except ValidationError as exception: + except pydantic.ValidationError as exception: return Response(str(exception), status_code=400) cls = self.table @@ -1168,7 +1172,9 @@ async def patch_single( for key in data.keys() } except AttributeError: - unrecognised_keys = set(data.keys()) - set(model.dict().keys()) + unrecognised_keys = set(data.keys()) - set( + model.model_dump().keys() + ) return Response( f"Unrecognised keys - {unrecognised_keys}.", status_code=400, @@ -1195,7 +1201,7 @@ async def patch_single( ) assert new_row return CustomJSONResponse( - self.pydantic_model(**new_row).json() + self.pydantic_model(**new_row).model_dump_json() ) except ValueError: return Response( diff --git a/piccolo_api/crud/serializers.py b/piccolo_api/crud/serializers.py index 5472340b..916a6e33 100644 --- a/piccolo_api/crud/serializers.py +++ b/piccolo_api/crud/serializers.py @@ -1,3 +1,3 @@ -from piccolo.utils.pydantic import Config, create_pydantic_model # noqa +from piccolo.utils.pydantic import create_pydantic_model # noqa -__all__ = ["Config", "create_pydantic_model"] +__all__ = ["create_pydantic_model"] diff --git a/piccolo_api/crud/validators.py b/piccolo_api/crud/validators.py index a86bbaa3..a6c1cdc5 100644 --- a/piccolo_api/crud/validators.py +++ b/piccolo_api/crud/validators.py @@ -85,7 +85,7 @@ def apply_validators(function): :class:`PiccoloCRUD`. """ - async def run_validators(*args, **kwargs): + async def run_validators(*args, **kwargs) -> None: piccolo_crud: PiccoloCRUD = args[0] validators = piccolo_crud.validators diff --git a/piccolo_api/fastapi/endpoints.py b/piccolo_api/fastapi/endpoints.py index 78d01ad9..ba7c1828 100644 --- a/piccolo_api/fastapi/endpoints.py +++ b/piccolo_api/fastapi/endpoints.py @@ -76,6 +76,36 @@ class ReferencesModel(BaseModel): references: t.List[ReferenceModel] +def _get_type(type_: t.Type) -> t.Type: + """ + Extract the inner type from an optional if necessary, otherwise return + the type as is. + + For example:: + + >>> get_type(Optional[int]) + int + + >>> get_type(int) + int + + >>> get_type(list[str]) + list[str] + + """ + origin = t.get_origin(type_) + + # Note: even if `t.Optional` is passed in, the origin is still a + # `t.Union`. + if origin is t.Union: + args = t.get_args(type_) + + if len(args) == 2 and None in args: + return [i for i in args if i is not None][0] + + return type_ + + class FastAPIWrapper: """ Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI. @@ -413,8 +443,11 @@ def modify_signature( ), ] - for field_name, _field in model.__fields__.items(): - type_ = _field.outer_type_ + for field_name, _field in model.model_fields.items(): + annotation = _field.annotation + assert annotation is not None + type_ = _get_type(annotation) + parameters.append( Parameter( name=field_name, diff --git a/piccolo_api/media/local.py b/piccolo_api/media/local.py index 5f6e19a8..a1638883 100644 --- a/piccolo_api/media/local.py +++ b/piccolo_api/media/local.py @@ -173,7 +173,7 @@ async def get_file_keys(self) -> t.List[str]: Returns the file key for each file we have stored. """ file_keys = [] - for (_, _, filenames) in os.walk(self.media_path): + for _, _, filenames in os.walk(self.media_path): file_keys.extend(filenames) break diff --git a/piccolo_api/media/s3.py b/piccolo_api/media/s3.py index 807a3a9a..1dbc5457 100644 --- a/piccolo_api/media/s3.py +++ b/piccolo_api/media/s3.py @@ -23,10 +23,10 @@ def __init__( column: t.Union[Text, Varchar, Array], bucket_name: str, folder_name: t.Optional[str] = None, - connection_kwargs: t.Dict[str, t.Any] = None, + connection_kwargs: t.Optional[t.Dict[str, t.Any]] = None, sign_urls: bool = True, signed_url_expiry: int = 3600, - upload_metadata: t.Dict[str, t.Any] = None, + upload_metadata: t.Optional[t.Dict[str, t.Any]] = None, executor: t.Optional[Executor] = None, allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS, allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS, @@ -130,9 +130,9 @@ def __init__( self.boto3 = boto3 self.bucket_name = bucket_name - self.upload_metadata = upload_metadata + self.upload_metadata = upload_metadata or {} self.folder_name = folder_name - self.connection_kwargs = connection_kwargs + self.connection_kwargs = connection_kwargs or {} self.sign_urls = sign_urls self.signed_url_expiry = signed_url_expiry self.executor = executor or ThreadPoolExecutor(max_workers=10) @@ -181,7 +181,7 @@ def store_file_sync( file_key = self.generate_file_key(file_name=file_name, user=user) extension = file_key.rsplit(".", 1)[-1] client = self.get_client() - upload_metadata: t.Dict[str, t.Any] = self.upload_metadata or {} + upload_metadata: t.Dict[str, t.Any] = self.upload_metadata if extension in CONTENT_TYPE: upload_metadata["ContentType"] = CONTENT_TYPE[extension] @@ -374,9 +374,7 @@ def __hash__(self): return hash( ( "s3", - self.connection_kwargs.get("endpoint_url") - if self.connection_kwargs - else None, + self.connection_kwargs.get("endpoint_url"), self.bucket_name, self.folder_name, ) diff --git a/piccolo_api/shared/auth/junction.py b/piccolo_api/shared/auth/junction.py index 57a87db3..7e9aa769 100644 --- a/piccolo_api/shared/auth/junction.py +++ b/piccolo_api/shared/auth/junction.py @@ -21,7 +21,6 @@ def __init__(self, backends: t.Sequence[AuthenticationBackend]): async def authenticate( self, conn: HTTPConnection ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: - for backend in self.backends: try: response = await backend.authenticate(conn=conn) diff --git a/piccolo_api/shared/middleware/junction.py b/piccolo_api/shared/middleware/junction.py index b73508de..bf07eee6 100644 --- a/piccolo_api/shared/middleware/junction.py +++ b/piccolo_api/shared/middleware/junction.py @@ -15,13 +15,11 @@ def __init__(self, *routers: Router) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send): for router in self.routers: try: - asgi = await router(scope, receive=receive, send=send) + await router(scope, receive=receive, send=send) except HTTPException as exception: if exception.status_code != 404: raise exception else: - if getattr(asgi, "status_code", None) == 404: - continue return raise HTTPException(status_code=404) diff --git a/piccolo_api/token_auth/endpoints.py b/piccolo_api/token_auth/endpoints.py index 928657d1..26472ceb 100644 --- a/piccolo_api/token_auth/endpoints.py +++ b/piccolo_api/token_auth/endpoints.py @@ -42,7 +42,6 @@ async def get_token(self, username: str, password: str) -> t.Optional[str]: class TokenAuthLoginEndpoint(HTTPEndpoint): - token_provider: TokenProvider = PiccoloTokenProvider() async def post(self, request: Request) -> Response: diff --git a/pyproject.toml b/pyproject.toml index 61d2c70e..d4812e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 79 -target-version = ['py37', 'py38', 'py39'] +target-version = ['py38', 'py39', 'py310', 'py311'] [tool.isort] profile = "black" @@ -9,7 +9,6 @@ line_length = 79 [tool.mypy] [[tool.mypy.overrides]] module = [ - "asyncpg.pgproto.pgproto", "asyncpg.exceptions", "jinja2", "uvicorn", diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 71c89d90..a22b0286 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -1,6 +1,6 @@ -black>=21.7b0 -isort==5.10.1 +black==23.7.0 +isort==5.12.0 twine==4.0.2 -mypy==0.950 +mypy==1.5.1 pip-upgrader==1.4.15 -wheel==0.40.0 +wheel==0.41.2 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9f847a5f..6c04781f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,7 @@ Jinja2>=2.11.0 -piccolo[postgres]>=0.104.0 -pydantic[email]>=1.6,<2.0 +piccolo[postgres]>=1.0a1 +pydantic[email]>=2.0 python-multipart>=0.0.5 -fastapi>=0.87.0,<0.100.0 +fastapi>=0.100.0 PyJWT>=2.0.0 httpx>=0.20.0 diff --git a/tests/crud/test_crud_endpoints.py b/tests/crud/test_crud_endpoints.py index 9a42944a..b16b5660 100644 --- a/tests/crud/test_crud_endpoints.py +++ b/tests/crud/test_crud_endpoints.py @@ -183,7 +183,6 @@ def test_patch_succeeds(self): self.assertEqual(movies[0]["name"], new_name) def test_patch_user_new_password(self): - client = TestClient(PiccoloCRUD(table=BaseUser, read_only=False)) json = { @@ -212,7 +211,6 @@ def test_patch_user_new_password(self): self.assertEqual(response.status_code, 200) def test_patch_user_old_password(self): - client = TestClient(PiccoloCRUD(table=BaseUser, read_only=False)) json = { @@ -241,7 +239,6 @@ def test_patch_user_old_password(self): self.assertEqual(response.status_code, 200) def test_patch_user_fails(self): - client = TestClient(PiccoloCRUD(table=BaseUser, read_only=False)) json = { @@ -282,6 +279,28 @@ def test_patch_fails(self): response = client.patch(f"/{movie.id}/", json={"foo": "bar"}) self.assertEqual(response.status_code, 400) + def test_patch_validation_error(self): + """ + Check if Pydantic validation error works. + """ + client = TestClient(PiccoloCRUD(table=Movie, read_only=False)) + + movie = Movie(name="Star Wars", rating=93) + movie.save().run_sync() + + response = client.patch( + f"/{movie.id}/", + json={"name": 95, "rating": "95"}, + ) + self.assertIn("validation error", str(response.content)) + self.assertEqual(response.status_code, 400) + + # Make sure nothing changed in the database: + self.assertListEqual( + Movie.select(Movie.name, Movie.rating).run_sync(), + [{"name": "Star Wars", "rating": 93}], + ) + class TestIDs(TestCase): def setUp(self): @@ -442,31 +461,34 @@ def test_get_schema(self): self.assertEqual( response.json(), { - "title": "MovieIn", - "type": "object", + "help_text": None, + "primary_key_name": "id", "properties": { "name": { - "title": "Name", + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "maxLength": 100, - "extra": {"help_text": None, "choices": None}, - "nullable": False, + "title": "Name", "type": "string", }, "rating": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "title": "Rating", - "extra": {"help_text": None, "choices": None}, - "nullable": False, - "type": "integer", }, }, "required": ["name"], - "help_text": None, - "visible_fields_options": [ - "id", - "name", - "rating", - ], - "primary_key_name": "id", + "title": "MovieIn", + "type": "object", + "visible_fields_options": ["id", "name", "rating"], }, ) @@ -493,33 +515,31 @@ class Rating(Enum): self.assertEqual( response.json(), { - "title": "ReviewIn", - "type": "object", + "help_text": None, + "primary_key_name": "id", "properties": { "score": { - "title": "Score", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, "extra": { - "help_text": None, "choices": { - "bad": {"display_name": "Bad", "value": 1}, "average": { "display_name": "Average", "value": 2, }, + "bad": {"display_name": "Bad", "value": 1}, "good": {"display_name": "Good", "value": 3}, "great": {"display_name": "Great", "value": 4}, }, + "help_text": None, + "nullable": False, }, - "nullable": False, - "type": "integer", + "title": "Score", } }, - "help_text": None, - "visible_fields_options": [ - "id", - "score", - ], - "primary_key_name": "id", + "title": "ReviewIn", + "type": "object", + "visible_fields_options": ["id", "score"], }, ) @@ -538,30 +558,38 @@ def test_get_schema_with_joins(self): self.assertEqual( response.json(), { - "title": "RoleIn", - "type": "object", + "help_text": None, + "primary_key_name": "id", "properties": { "movie": { - "title": "Movie", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, "extra": { + "choices": None, "foreign_key": True, - "to": "movie", - "target_column": "id", "help_text": None, - "choices": None, + "nullable": True, + "target_column": "id", + "to": "movie", }, - "nullable": True, - "type": "integer", + "title": "Movie", }, "name": { + "anyOf": [ + {"maxLength": 100, "type": "string"}, + {"type": "null"}, + ], + "default": None, + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "title": "Name", - "extra": {"help_text": None, "choices": None}, - "nullable": False, - "maxLength": 100, - "type": "string", }, }, - "help_text": None, + "title": "RoleIn", + "type": "object", "visible_fields_options": [ "id", "movie", @@ -570,7 +598,6 @@ def test_get_schema_with_joins(self): "movie.rating", "name", ], - "primary_key_name": "id", }, ) @@ -630,6 +657,24 @@ def test_put_existing(self): self.assertEqual(Movie.count().run_sync(), 1) + def test_put_validation_error(self): + """ + Check if Pydantic validation error works. + """ + client = TestClient(PiccoloCRUD(table=Movie, read_only=False)) + + movie = Movie(name="Star Wars", rating=93) + movie.save().run_sync() + + response = client.put( + f"/{movie.id}/", + json={"name": 95, "rating": "95"}, + ) + self.assertIn("validation error", str(response.content)) + self.assertEqual(response.status_code, 400) + + self.assertEqual(Movie.count().run_sync(), 1) + def test_put_new(self): """ We expect a 404 - we don't allow PUT requests to create new resources. @@ -1164,7 +1209,6 @@ def test_post_user_fails(self): self.assertEqual(response.status_code, 400) def test_validation_error(self): - """ Make sure a post returns a validation error with incorrect or missing data. diff --git a/tests/crud/test_validators.py b/tests/crud/test_validators.py index c485e7b9..67548ffd 100644 --- a/tests/crud/test_validators.py +++ b/tests/crud/test_validators.py @@ -5,7 +5,8 @@ from piccolo.columns import Integer, Varchar from piccolo.columns.readable import Readable from piccolo.table import Table -from starlette.exceptions import ExceptionMiddleware, HTTPException +from starlette.exceptions import HTTPException +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.testclient import TestClient from piccolo_api.crud.endpoints import PiccoloCRUD, Validators diff --git a/tests/csrf/test_csrf.py b/tests/csrf/test_csrf.py index 710a5433..ae3ecafc 100644 --- a/tests/csrf/test_csrf.py +++ b/tests/csrf/test_csrf.py @@ -1,6 +1,6 @@ from unittest import TestCase -from starlette.exceptions import ExceptionMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.testclient import TestClient from piccolo_api.csrf.middleware import ( @@ -28,7 +28,6 @@ async def app(scope, receive, send): class TestCSRFMiddleware(TestCase): - csrf_token = CSRFMiddleware.get_new_token() incorrect_csrf_token = "abc123" diff --git a/tests/fastapi/test_fastapi_endpoints.py b/tests/fastapi/test_fastapi_endpoints.py index 1284a163..0cc9d59f 100644 --- a/tests/fastapi/test_fastapi_endpoints.py +++ b/tests/fastapi/test_fastapi_endpoints.py @@ -100,65 +100,78 @@ def test_schema(self): self.assertEqual( response.json(), { - "title": "MovieIn", - "type": "object", + "help_text": None, + "primary_key_name": "id", "properties": { "name": { + "anyOf": [ + {"maxLength": 100, "type": "string"}, + {"type": "null"}, + ], + "default": None, + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "title": "Name", - "extra": {"help_text": None, "choices": None}, - "maxLength": 100, - "nullable": False, - "type": "string", }, "rating": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "title": "Rating", - "extra": {"help_text": None, "choices": None}, - "nullable": False, - "type": "integer", }, }, - "help_text": None, - "visible_fields_options": [ - "id", - "name", - "rating", - ], - "primary_key_name": "id", + "title": "MovieIn", + "type": "object", + "visible_fields_options": ["id", "name", "rating"], }, ) def test_schema_joins(self): client = TestClient(app) response = client.get("/roles/schema/") - self.assertEqual(response.status_code, 200) self.assertEqual( response.json(), { - "title": "RoleIn", - "type": "object", + "help_text": None, + "primary_key_name": "id", "properties": { "movie": { - "title": "Movie", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, "extra": { + "choices": None, "foreign_key": True, - "to": "movie", - "target_column": "id", "help_text": None, - "choices": None, + "nullable": True, + "target_column": "id", + "to": "movie", }, - "nullable": True, - "type": "integer", + "title": "Movie", }, "name": { + "anyOf": [ + {"maxLength": 100, "type": "string"}, + {"type": "null"}, + ], + "default": None, + "extra": { + "choices": None, + "help_text": None, + "nullable": False, + }, "title": "Name", - "extra": {"help_text": None, "choices": None}, - "nullable": False, - "maxLength": 100, - "type": "string", }, }, - "help_text": None, + "title": "RoleIn", + "type": "object", "visible_fields_options": [ "id", "movie", @@ -167,7 +180,6 @@ def test_schema_joins(self): "movie.rating", "name", ], - "primary_key_name": "id", }, ) @@ -197,7 +209,13 @@ def test_references(self): def test_delete(self): client = TestClient(app) - response = client.delete("/movies/?id=1") + response = client.delete("/movies/1/") + self.assertEqual(response.status_code, 204) + self.assertEqual(response.content, b"") + + def test_allow_bulk_delete(self): + client = TestClient(app) + response = client.delete("/movies/") self.assertEqual(response.status_code, 204) self.assertEqual(response.content, b"") diff --git a/tests/jwt_auth/test_jwt_endpoints.py b/tests/jwt_auth/test_jwt_endpoints.py index 5555491c..48652e34 100644 --- a/tests/jwt_auth/test_jwt_endpoints.py +++ b/tests/jwt_auth/test_jwt_endpoints.py @@ -12,7 +12,6 @@ class TestLoginEndpoint(TestCase): - credentials = {"username": "Bob", "password": "bob123"} def setUp(self): diff --git a/tests/session_auth/test_session.py b/tests/session_auth/test_session.py index 9f1f78be..78318123 100644 --- a/tests/session_auth/test_session.py +++ b/tests/session_auth/test_session.py @@ -6,8 +6,8 @@ from piccolo.utils.sync import run_sync from starlette.authentication import requires from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import ExceptionMiddleware from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Mount, Route, Router from starlette.testclient import TestClient diff --git a/tests/token_auth/test_endpoints.py b/tests/token_auth/test_endpoints.py index f867a73a..0b295664 100644 --- a/tests/token_auth/test_endpoints.py +++ b/tests/token_auth/test_endpoints.py @@ -14,7 +14,6 @@ class TestLoginEndpoint(TestCase): - credentials = {"username": "Bob", "password": "bob123"} def setUp(self):