diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 4c46291..db2a47b 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1913,19 +1913,23 @@ async def get_multi_joined( join_definitions = joins_config if joins_config else [] if join_model: - join_definitions.append( - JoinConfig( - model=join_model, - join_on=join_on - or _auto_detect_join_condition(self.model, join_model), - join_prefix=join_prefix, - schema_to_select=join_schema_to_select, - join_type=join_type, - alias=alias, - filters=join_filters, - relationship_type=relationship_type, + try: + join_definitions.append( + JoinConfig( + model=join_model, + join_on=join_on + if join_on is not None + else _auto_detect_join_condition(self.model, join_model), + join_prefix=join_prefix, + schema_to_select=join_schema_to_select, + join_type=join_type, + alias=alias, + filters=join_filters, + relationship_type=relationship_type, + ) ) - ) + except ValueError as e: # pragma: no cover + raise ValueError(f"Could not configure join: {str(e)}") stmt = self._prepare_and_apply_joins( stmt=stmt, joins_config=join_definitions, use_temporary_prefix=nest_joins @@ -1983,7 +1987,7 @@ async def get_multi_joined( ( join.join_prefix.rstrip("_") if join.join_prefix - else join.model.__name__ + else join.model.__tablename__ ): join.schema_to_select for join in join_definitions if join.schema_to_select diff --git a/fastcrud/crud/helper.py b/fastcrud/crud/helper.py index 402ee3e..17d0fe9 100644 --- a/fastcrud/crud/helper.py +++ b/fastcrud/crud/helper.py @@ -524,61 +524,93 @@ def _nest_multi_join_data( """ pre_nested_data = {} + for row in data: + if isinstance(row, BaseModel): + new_row = { + key: ([] if isinstance(value, list) else value) + for key, value in row.model_dump().items() + } + else: + new_row = { + key: ([] if isinstance(value, list) else value) + for key, value in row.items() + } + + primary_key_value = new_row[base_primary_key] + if primary_key_value not in pre_nested_data: + pre_nested_data[primary_key_value] = new_row + for join_config in joins_config: join_primary_key = _get_primary_key(join_config.model) + join_prefix = ( + join_config.join_prefix.rstrip("_") + if join_config.join_prefix + else join_config.model.__tablename__ + ) for row in data: - if isinstance(row, BaseModel): - new_row = { - key: (value[:] if isinstance(value, list) else value) - for key, value in row.model_dump().items() - } - else: - new_row = { - key: (value[:] if isinstance(value, list) else value) - for key, value in row.items() - } + row_dict = row if isinstance(row, dict) else row.model_dump() + primary_key_value = row_dict[base_primary_key] - primary_key_value = new_row[base_primary_key] - - if primary_key_value not in pre_nested_data: - for key, value in new_row.items(): - if isinstance(value, list) and any( - item[join_primary_key] is None for item in value - ): # pragma: no cover - new_row[key] = [] - elif ( - isinstance(value, dict) and value[join_primary_key] is None - ): # pragma: no cover - new_row[key] = None - - pre_nested_data[primary_key_value] = new_row - else: - existing_row = pre_nested_data[primary_key_value] - for key, value in new_row.items(): + if join_config.relationship_type == "one-to-many": + if join_prefix in row_dict: + value = row_dict[join_prefix] if isinstance(value, list): if any( item[join_primary_key] is None for item in value ): # pragma: no cover - existing_row[key] = [] + pre_nested_data[primary_key_value][join_prefix] = [] else: - existing_row[key].extend(value) + existing_items = { + item[join_primary_key] + for item in pre_nested_data[primary_key_value][ + join_prefix + ] + } + for item in value: + if item[join_primary_key] not in existing_items: + pre_nested_data[primary_key_value][ + join_prefix + ].append(item) + existing_items.add(item[join_primary_key]) + else: + if join_prefix in row_dict: + value = row_dict[join_prefix] + if ( + isinstance(value, dict) and value.get(join_primary_key) is None + ): # pragma: no cover + pre_nested_data[primary_key_value][join_prefix] = None + elif isinstance(value, dict): + pre_nested_data[primary_key_value][join_prefix] = value nested_data: list = list(pre_nested_data.values()) if return_as_model: - for i, item in enumerate(nested_data): + if not schema_to_select: # pragma: no cover + raise ValueError( + "schema_to_select must be provided when return_as_model is True." + ) + + converted_data = [] + for item in nested_data: if nested_schema_to_select: - for prefix, schema in nested_schema_to_select.items(): - if prefix in item: - if isinstance(item[prefix], list): - item[prefix] = [ - schema(**nested_item) for nested_item in item[prefix] + for prefix, nested_schema in nested_schema_to_select.items(): + prefix_key = prefix.rstrip("_") + if prefix_key in item: + if isinstance(item[prefix_key], list): + item[prefix_key] = [ + nested_schema(**nested_item) + for nested_item in item[prefix_key] ] else: # pragma: no cover - item[prefix] = schema(**item[prefix]) - if schema_to_select: - nested_data[i] = schema_to_select(**item) + item[prefix_key] = ( + nested_schema(**item[prefix_key]) + if item[prefix_key] is not None + else None + ) + + converted_data.append(schema_to_select(**item)) + return converted_data return nested_data diff --git a/fastcrud/endpoint/endpoint_creator.py b/fastcrud/endpoint/endpoint_creator.py index fd4cf2f..d81948c 100644 --- a/fastcrud/endpoint/endpoint_creator.py +++ b/fastcrud/endpoint/endpoint_creator.py @@ -372,7 +372,7 @@ async def endpoint( ) return paginated_response( crud_data=crud_data, - page=page, # type: ignore + page=page, # type: ignore items_per_page=items_per_page, # type: ignore ) @@ -382,8 +382,8 @@ async def endpoint( crud_data = await self.crud.get_multi( db, - offset=offset, # type: ignore - limit=limit, # type: ignore + offset=offset, # type: ignore + limit=limit, # type: ignore **filters, ) return crud_data # pragma: no cover