diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index ddc8582..96b1bc6 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1264,30 +1264,36 @@ async def get_multi_joined( stmt = stmt.offset(offset).limit(limit) result = await db.execute(stmt) - data: list[dict] = [dict(row) for row in result.mappings().all()] + data: list[Union[dict, BaseModel]] = [] + for row in result.mappings().all(): + row_dict = dict(row) - if nest_joins: - data = [_nest_join_data(row, join_definitions) for row in data] + if nest_joins: + row_dict = _nest_join_data(row_dict, join_definitions) + + if return_as_model: + if schema_to_select is None: + raise ValueError( + "schema_to_select must be provided when return_as_model is True." + ) + try: + model_instance = schema_to_select(**row_dict) + data.append(model_instance) + except ValidationError as e: + raise ValueError( + f"Data validation error for schema {schema_to_select.__name__}: {e}" + ) + else: + data.append(row_dict) response: dict[str, Any] = {"data": data} if return_total_count: - total_count = await self.count(db=db, joins_config=joins_config, **kwargs) + total_count: int = await self.count( + db=db, joins_config=joins_config, **kwargs + ) response["total_count"] = total_count - if return_as_model: - if not schema_to_select: - raise ValueError( - "schema_to_select must be provided when return_as_model is True." - ) - try: - model_data: list[BaseModel] = [schema_to_select(**row) for row in data] - response["data"] = model_data - except ValidationError as e: - raise ValueError( - f"Data validation error for schema {schema_to_select.__name__}: {e}" - ) - return response async def get_multi_by_cursor(