From b2b510c500f7da7d5ae563218f231ff6b38563bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Loipf=C3=BChrer?= Date: Fri, 27 Dec 2024 17:34:20 +0100 Subject: [PATCH] fix(core): adapt tests and exceptions to sftkit changes --- abrechnung/admin.py | 2 +- abrechnung/application/groups.py | 2 +- abrechnung/application/users.py | 9 ++++++++- abrechnung/demo.py | 2 +- abrechnung/mailer.py | 2 +- abrechnung/util.py | 9 +++++++++ tests/common.py | 11 ++++++----- tests/http_tests/common.py | 7 +++++++ tests/http_tests/test_auth.py | 2 +- tests/http_tests/test_groups.py | 8 ++++---- tools/generate_dummy_data.py | 2 +- 11 files changed, 40 insertions(+), 16 deletions(-) diff --git a/abrechnung/admin.py b/abrechnung/admin.py index cf8b4dbe..6142b258 100644 --- a/abrechnung/admin.py +++ b/abrechnung/admin.py @@ -16,7 +16,7 @@ async def create_user(config: Config, name: str, email: str, skip_email_check: b print("Passwords do not match!") return - database = get_database(config) + database = get_database(config.database) db_pool = await database.create_pool() user_service = UserService(db_pool, config) user_service.enable_registration = True diff --git a/abrechnung/application/groups.py b/abrechnung/application/groups.py index 8d066f41..091e02a3 100644 --- a/abrechnung/application/groups.py +++ b/abrechnung/application/groups.py @@ -330,7 +330,7 @@ async def delete_group(self, *, conn: Connection, user: User, group_id: int): group_id, ) if n_members != 1: - raise PermissionError(f"Can only delete a group when you are the last member") + raise InvalidArgument(f"Can only delete a group when you are the last member") await conn.execute("delete from grp where id = $1", group_id) diff --git a/abrechnung/application/users.py b/abrechnung/application/users.py index 3992f46c..bffd7b2c 100644 --- a/abrechnung/application/users.py +++ b/abrechnung/application/users.py @@ -12,6 +12,7 @@ from abrechnung.config import Config from abrechnung.domain.users import Session, User +from abrechnung.util import is_valid_uuid ALGORITHM = "HS256" @@ -246,6 +247,8 @@ async def register_user( @with_db_transaction async def confirm_registration(self, *, conn: Connection, token: str) -> int: + if not is_valid_uuid(token): + raise InvalidArgument(f"Invalid confirmation token") row = await conn.fetchrow( "select user_id, valid_until from pending_registration where token = $1", token, @@ -342,6 +345,8 @@ async def request_email_change(self, *, conn: Connection, user: User, password: @with_db_transaction async def confirm_email_change(self, *, conn: Connection, token: str) -> int: + if not is_valid_uuid(token): + raise InvalidArgument(f"Invalid confirmation token") row = await conn.fetchrow( "select user_id, new_email, valid_until from pending_email_change where token = $1", token, @@ -360,7 +365,7 @@ async def confirm_email_change(self, *, conn: Connection, token: str) -> int: async def request_password_recovery(self, *, conn: Connection, email: str): user_id = await conn.fetchval("select id from usr where email = $1", email) if not user_id: - raise PermissionError + raise InvalidArgument("permission denied") await conn.execute( "insert into pending_password_recovery (user_id) values ($1)", @@ -369,6 +374,8 @@ async def request_password_recovery(self, *, conn: Connection, email: str): @with_db_transaction async def confirm_password_recovery(self, *, conn: Connection, token: str, new_password: str) -> int: + if not is_valid_uuid(token): + raise InvalidArgument(f"Invalid confirmation token") row = await conn.fetchrow( "select user_id, valid_until from pending_password_recovery where token = $1", token, diff --git a/abrechnung/demo.py b/abrechnung/demo.py index e8cb3328..10dc5c86 100644 --- a/abrechnung/demo.py +++ b/abrechnung/demo.py @@ -14,7 +14,7 @@ async def cleanup(config: Config): deletion_threshold = datetime.now() - config.demo.wipe_interval - database = get_database(config) + database = get_database(config.database) db_pool = await database.create_pool() async with db_pool.acquire() as conn: async with conn.transaction(): diff --git a/abrechnung/mailer.py b/abrechnung/mailer.py index 8065fdcc..bbf27bb1 100644 --- a/abrechnung/mailer.py +++ b/abrechnung/mailer.py @@ -18,7 +18,7 @@ def __init__(self, config: Config): self.config = config self.events: Optional[asyncio.Queue] = None self.psql: Connection | None = None - self.database = get_database(config) + self.database = get_database(config.database) self.mailer = None self.logger = logging.getLogger(__name__) diff --git a/abrechnung/util.py b/abrechnung/util.py index 40bffa46..5c068247 100644 --- a/abrechnung/util.py +++ b/abrechnung/util.py @@ -1,5 +1,6 @@ import logging import re +import uuid from datetime import datetime, timedelta, timezone postgres_timestamp_format = re.compile( @@ -63,3 +64,11 @@ def log_setup(setting, default=1): def clamp(number, smallest, largest): """return number but limit it to the inclusive given value range""" return max(smallest, min(number, largest)) + + +def is_valid_uuid(val: str): + try: + uuid.UUID(val) + return True + except ValueError: + return False diff --git a/tests/common.py b/tests/common.py index 958848dd..6e33412e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -24,11 +24,12 @@ def get_test_db_config() -> DatabaseConfig: return DatabaseConfig( - user=os.environ.get("TEST_DB_USER", "abrechnung-test"), - password=os.environ.get("TEST_DB_PASSWORD", "asdf1234"), - host=os.environ.get("TEST_DB_HOST", "localhost"), - dbname=os.environ.get("TEST_DB_DATABASE", "abrechnung-test"), + user=os.environ.get("TEST_DB_USER"), + password=os.environ.get("TEST_DB_PASSWORD"), + host=os.environ.get("TEST_DB_HOST"), + dbname=os.environ.get("TEST_DB_DATABASE", "abrechnung_test"), port=int(os.environ.get("TEST_DB_PORT", 5432)), + sslrootcert=None, ) @@ -57,7 +58,7 @@ async def get_test_db() -> Pool: """ get a connection pool to the test database """ - database = get_database(TEST_CONFIG) + database = get_database(TEST_CONFIG.database) pool = await database.create_pool() await reset_schema(pool) diff --git a/tests/http_tests/common.py b/tests/http_tests/common.py index c62aecfd..694666af 100644 --- a/tests/http_tests/common.py +++ b/tests/http_tests/common.py @@ -1,5 +1,6 @@ # pylint: disable=attribute-defined-outside-init from httpx import ASGITransport, AsyncClient +from sftkit.http._context import ContextMiddleware from abrechnung.http.api import Api from tests.common import TEST_CONFIG, BaseTestCase @@ -12,6 +13,12 @@ async def asyncSetUp(self) -> None: self.http_service = Api(config=self.test_config) await self.http_service._setup() + # workaround for bad testability in sftkit + self.http_service.server.api.add_middleware( + ContextMiddleware, + context=self.http_service.context, + ) + self.transport = ASGITransport(app=self.http_service.server.api) self.client = AsyncClient(transport=self.transport, base_url="https://abrechnung.sft.lol") self.transaction_service = self.http_service.transaction_service diff --git a/tests/http_tests/test_auth.py b/tests/http_tests/test_auth.py index 92c54ffa..62b9c6df 100644 --- a/tests/http_tests/test_auth.py +++ b/tests/http_tests/test_auth.py @@ -175,7 +175,7 @@ async def test_reset_password(self): f"/api/v1/auth/recover_password", json={"email": "fooo@stusta.de"}, ) - self.assertEqual(403, resp.status_code) + self.assertEqual(400, resp.status_code) resp = await self.client.post( f"/api/v1/auth/recover_password", diff --git a/tests/http_tests/test_groups.py b/tests/http_tests/test_groups.py index 8e61b0c0..5a719b7a 100644 --- a/tests/http_tests/test_groups.py +++ b/tests/http_tests/test_groups.py @@ -51,7 +51,7 @@ async def test_create_group(self): group = await self._fetch_group(group_id) self.assertEqual("name", group["name"]) - await self._fetch_group(13333, 404) + await self._fetch_group(13333, 400) resp = await self._post( f"/api/v1/groups/{group_id}", @@ -128,12 +128,12 @@ async def test_delete_group(self): ) resp = await self._delete(f"/api/v1/groups/{group_id}") - self.assertEqual(403, resp.status_code) + self.assertEqual(400, resp.status_code) resp = await self._post(f"/api/v1/groups/{group_id}/leave") self.assertEqual(204, resp.status_code) - await self._fetch_group(group_id, expected_status=404) + await self._fetch_group(group_id, expected_status=400) resp = await self.client.delete( f"/api/v1/groups/{group_id}", @@ -345,7 +345,7 @@ async def test_get_account(self): self.assertEqual(422, resp.status_code) resp = await self._get(f"/api/v1/groups/{group_id}/accounts/13232") - self.assertEqual(404, resp.status_code) + self.assertEqual(400, resp.status_code) async def test_invites(self): group_id = await self.group_service.create_group( diff --git a/tools/generate_dummy_data.py b/tools/generate_dummy_data.py index 87e04e43..683aaf2f 100755 --- a/tools/generate_dummy_data.py +++ b/tools/generate_dummy_data.py @@ -35,7 +35,7 @@ async def main( ): config = read_config(Path(config_path)) - database = get_database(config) + database = get_database(config.database) db_pool = await database.create_pool() user_service = UserService(db_pool, config) group_service = GroupService(db_pool, config)