From 28df02bd4894248877f43e8e0ef3a2f851596647 Mon Sep 17 00:00:00 2001 From: Roman Reznikov Date: Sun, 10 Jul 2022 14:23:19 +0400 Subject: [PATCH] #22: feat: add common `MongoRepository` for working with data from MongoDB --- src/ugc/containers.py | 7 +++ src/ugc/infrastructure/db/repositories.py | 73 ++++++++++++++++++++++- src/ugc/infrastructure/db/types.py | 1 + 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 src/ugc/infrastructure/db/types.py diff --git a/src/ugc/containers.py b/src/ugc/containers.py index 2f0fa5a..5f2461f 100644 --- a/src/ugc/containers.py +++ b/src/ugc/containers.py @@ -1,4 +1,5 @@ import logging.config +from functools import partial import orjson from dependency_injector import containers, providers @@ -48,6 +49,12 @@ class Container(containers.DeclarativeContainer): repositories.RedisRepository, ) + mongo_repository_factory = partial( + providers.Factory, + provides=repositories.MongoRepository, + db=mongo_client, + ) + kafka_producer_client = providers.Resource( producers.init_kafka_producer_client, servers=providers.List(config.KAFKA_URL), diff --git a/src/ugc/infrastructure/db/repositories.py b/src/ugc/infrastructure/db/repositories.py index f8094c2..fe14c65 100644 --- a/src/ugc/infrastructure/db/repositories.py +++ b/src/ugc/infrastructure/db/repositories.py @@ -1,12 +1,22 @@ import functools from operator import and_ -from typing import Generic, Type, TypeVar +from typing import Any, Generic, Type, TypeVar from aredis_om.model.model import NotFoundError, RedisModel +from bson import ObjectId +from pymongo import ASCENDING +from ugc.domain.factories import BaseModelFactory +from ugc.domain.types import BaseModel from ugc.helpers import resolve_callables +from .clients import MongoCollectionClient, MongoDatabaseClient +from .types import PaginationCursor + _RM = TypeVar("_RM", bound=RedisModel) +_BM = TypeVar("_BM", bound=BaseModel) + +MongoResult = dict[str, Any] class RedisRepository(Generic[_RM]): @@ -72,3 +82,64 @@ def _extract_model_params(defaults: dict | None, **kwargs) -> dict: params = {k: v for k, v in kwargs.items()} params.update(defaults) return params + + +class MongoRepository(Generic[_BM]): + """Репозиторий для работы с данными из MongoDB.""" + + def __init__(self, db: MongoDatabaseClient, factory: BaseModelFactory, collection_name: str) -> None: + assert isinstance(db, MongoDatabaseClient) + self._db = db + + assert isinstance(factory, BaseModelFactory) + self._factory = factory + + assert isinstance(collection_name, str) + self._collection_name = collection_name + + @property + def collection(self) -> MongoCollectionClient: + return self._db[self._collection_name] + + def get_paginated_results_iter( + self, *, + limit: int, + cursor: PaginationCursor = None, + ordering: tuple[str, int] | None = None, filter_query: dict | None = None, pagination_query: dict | None = None, + ): + """Получение итератора по результатам с пагинацией.""" + if ordering is None: + ordering = ("_id", ASCENDING) + if pagination_query is None: + pagination_query = {"_id": {"$gt": ObjectId(cursor)}} + + if cursor is None: + return self.collection.find(filter_query).limit(limit).sort(*ordering) + filter_query.update(pagination_query) + return self.collection.find(filter_query).limit(limit).sort(*ordering) + + async def get_paginated_results( + self, *, + limit: int, + cursor: PaginationCursor = None, cursor_field: str, + ordering: tuple[str, int] | None = None, filter_query: dict | None = None, pagination_query: dict | None = None, + ) -> tuple[list[_BM], PaginationCursor]: + """Получение списка документов с использованием `cursor-based` пагинации.""" + raw_results = self.get_paginated_results_iter( + limit=limit, cursor=cursor, ordering=ordering, filter_query=filter_query, pagination_query=pagination_query) + results = [ + self._factory.create_from_serialized(self.fix_id_field(raw_result)) + async for raw_result in raw_results + ] + + if not results: + return [], None + + new_cursor = results[-1].__getattribute__(cursor_field) + return results, new_cursor + + @staticmethod + def fix_id_field(data: MongoResult) -> MongoResult: + """Переименование поля `_id`.""" + data["id"] = str(data.pop("_id")) + return data diff --git a/src/ugc/infrastructure/db/types.py b/src/ugc/infrastructure/db/types.py new file mode 100644 index 0000000..963c5df --- /dev/null +++ b/src/ugc/infrastructure/db/types.py @@ -0,0 +1 @@ +PaginationCursor = str | None