-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: implement ModelBindingsMapper utilizing a DataFrame
This change introduces a refactor of the important ModelBindingsMapper class using a pandas DataFrame for effecting grouping and aggregation and a CurryModel utility for partially instantiating a given Pydantic model with checks running for every partial application, allowing for fast validation failure. Closes #170. Aggregation behavior is now only triggered for actually aggregated fields, and not for top-level models as well. Therefore this also closes #181.
- Loading branch information
Showing
2 changed files
with
172 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,146 @@ | ||
"""ModelBindingsMapper: Functionality for mapping binding maps to a Pydantic model.""" | ||
|
||
from collections.abc import Iterator | ||
from typing import Any, Generic, get_args | ||
from collections.abc import Iterable, Iterator | ||
from itertools import tee | ||
from typing import Generic, get_args | ||
|
||
import pandas as pd | ||
from pandas.api.typing import DataFrameGroupBy | ||
from pydantic import BaseModel | ||
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance | ||
from rdfproxy.utils.mapper_utils import ( | ||
_collect_values_from_bindings, | ||
_get_group_by, | ||
_get_key_from_metadata, | ||
_is_list_basemodel_type, | ||
_is_list_type, | ||
get_model_bool_predicate, | ||
) | ||
from rdfproxy.utils.utils import CurryModel, FieldsBindingsMap | ||
|
||
|
||
class ModelBindingsMapper(Generic[_TModelInstance]): | ||
"""Utility class for mapping flat bindings to a (potentially nested) Pydantic model.""" | ||
"""Utility class for mapping bindings to nested/grouped Pydantic models. | ||
RDFProxy utilizes Pydantic models also as a modelling grammar for grouping | ||
and aggregation, mainly by treating the 'group_by' entry in ConfigDict in | ||
combination with list-type annoted model fields as grouping | ||
and aggregation indicators. _ModelBindingsMapper applies this grammar | ||
for mapping flat bindings to pontially nested and grouped Pydantic models. | ||
Note: _ModelBindingsMapper is intended for use in and - since no | ||
model sanity checking runs in the mapper itself - somewhat coupled to | ||
rdfproxy.SPARQLModelAdapter. The mapper can be useful in its own right | ||
though, in which case the initializer should be overwritten and | ||
model sanity checking should be added to the _ModelBindingsMapper subclass. | ||
""" | ||
|
||
def __init__(self, model: type[_TModelInstance], *bindings: dict): | ||
self.model = model | ||
self.bindings = bindings | ||
self._contexts = [] | ||
|
||
self.df = pd.DataFrame(data=self.bindings) | ||
self.df.replace(pd.NA, None, inplace=True) | ||
|
||
def get_models(self) -> list[_TModelInstance]: | ||
"""Generate a list of (potentially nested) Pydantic models based on (flat) bindings.""" | ||
return self._get_unique_models(self.model, self.bindings) | ||
|
||
def _get_unique_models(self, model, bindings): | ||
"""Call the mapping logic and collect unique and non-empty models.""" | ||
models = [] | ||
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(model) | ||
|
||
for _bindings in bindings: | ||
_model = model(**dict(self._generate_binding_pairs(model, **_bindings))) | ||
|
||
if model_bool_predicate(_model) and (_model not in models): | ||
models.append(_model) | ||
|
||
return models | ||
|
||
def _get_group_by(self, model) -> str: | ||
"""Get the group_by value from a model and register it in self._contexts.""" | ||
group_by: str = _get_group_by(model) | ||
|
||
if group_by not in self._contexts: | ||
self._contexts.append(group_by) | ||
|
||
return group_by | ||
|
||
def _generate_binding_pairs( | ||
self, | ||
model: type[BaseModel], | ||
**kwargs, | ||
) -> Iterator[tuple[str, Any]]: | ||
"""Generate an Iterator[tuple] projection of the bindings needed for model instantation.""" | ||
for k, v in model.model_fields.items(): | ||
if _is_list_basemodel_type(v.annotation): | ||
group_by: str = self._get_group_by(model) | ||
group_model, *_ = get_args(v.annotation) | ||
|
||
applicable_bindings = filter( | ||
lambda x: (x[group_by] == kwargs[group_by]) | ||
and (x[self._contexts[0]] == kwargs[self._contexts[0]]), | ||
self.bindings, | ||
"""Run the model mapping logic against bindings and collect a list of model instances.""" | ||
return list(self._instantiate_models(self.df, self.model)) | ||
|
||
def _instantiate_models( | ||
self, df: pd.DataFrame, model: type[_TModelInstance] | ||
) -> Iterator[_TModelInstance]: | ||
"""Generate potentially nested and grouped model instances from a dataframe. | ||
Note: The DataFrameGroupBy object must not be sorted, | ||
else the result set order will not be maintained. | ||
""" | ||
alias_map = FieldsBindingsMap(model=model) | ||
|
||
if (_group_by := model.model_config.get("group_by")) is None: | ||
for _, row in df.iterrows(): | ||
yield self._instantiate_ungrouped_model_from_row(row, model) | ||
else: | ||
group_by = alias_map[_group_by] | ||
group_by_object: DataFrameGroupBy = df.groupby(group_by, sort=False) | ||
|
||
for _, group_df in group_by_object: | ||
yield self._instantiate_grouped_model_from_df(group_df, model) | ||
|
||
def _instantiate_ungrouped_model_from_row( | ||
self, row: pd.Series, model: type[_TModelInstance] | ||
) -> _TModelInstance: | ||
"""Instantiate an ungrouped model from a pd.Series row. | ||
This handles the UNGROUPED code path in _ModelBindingsMapper._instantiate_models. | ||
""" | ||
alias_map = FieldsBindingsMap(model=model) | ||
curried_model = CurryModel(model=model) | ||
|
||
for field_name, field_info in model.model_fields.items(): | ||
if isinstance(nested_model := field_info.annotation, type(BaseModel)): | ||
curried_model( | ||
**{ | ||
field_name: self._instantiate_ungrouped_model_from_row( | ||
row, | ||
nested_model, # type: ignore | ||
) | ||
} | ||
) | ||
value = self._get_unique_models(group_model, applicable_bindings) | ||
|
||
elif _is_list_type(v.annotation): | ||
group_by: str = self._get_group_by(model) | ||
applicable_bindings = filter( | ||
lambda x: x[group_by] == kwargs[group_by], | ||
self.bindings, | ||
else: | ||
field_value = row.get(alias_map[field_name], None) or field_info.default | ||
curried_model(**{field_name: field_value}) | ||
|
||
model_instance = curried_model() | ||
assert isinstance(model_instance, model) # type narrow | ||
return model_instance | ||
|
||
@staticmethod | ||
def _get_unique_models(models: Iterable[_TModelInstance]) -> list[_TModelInstance]: | ||
"""Get a list of unique models from an iterable. | ||
Unless frozen=True is specified in a model class, | ||
Pydantic models instances are not hashable, i.e. dict.fromkeys | ||
is not feasable for acquiring ordered unique models. | ||
""" | ||
unique_models = [] | ||
models, _models = tee(models) | ||
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate( | ||
next(_models) | ||
) | ||
|
||
for model in models: | ||
if (model not in unique_models) and (model_bool_predicate(model)): | ||
unique_models.append(model) | ||
|
||
return unique_models | ||
|
||
def _instantiate_grouped_model_from_df( | ||
self, df: pd.DataFrame, model: type[_TModelInstance] | ||
) -> _TModelInstance: | ||
"""Instantiate a grouped model from a pd.DataFrame (a group dataframe). | ||
This handles the GROUPED code path in _ModelBindingsMapper._instantiate_models. | ||
""" | ||
alias_map = FieldsBindingsMap(model=model) | ||
curried_model = CurryModel(model=model) | ||
|
||
for field_name, field_info in model.model_fields.items(): | ||
if _is_list_basemodel_type(field_info.annotation): | ||
nested_model, *_ = get_args(field_info.annotation) | ||
value = self._get_unique_models( | ||
self._instantiate_models(df, nested_model) | ||
) | ||
|
||
binding_key: str = _get_key_from_metadata(v, default=k) | ||
value = _collect_values_from_bindings(binding_key, applicable_bindings) | ||
|
||
elif isinstance(v.annotation, type(BaseModel)): | ||
nested_model = v.annotation | ||
value = nested_model( | ||
**dict(self._generate_binding_pairs(nested_model, **kwargs)) | ||
elif _is_list_type(field_info.annotation): | ||
value = list(dict.fromkeys(df[alias_map[field_name]].dropna())) | ||
elif isinstance(nested_model := field_info.annotation, type(BaseModel)): | ||
first_row = df.iloc[0] | ||
value = self._instantiate_ungrouped_model_from_row( | ||
first_row, | ||
nested_model, # type: ignore | ||
) | ||
else: | ||
binding_key: str = _get_key_from_metadata(v, default=k) | ||
value = kwargs.get(binding_key, v.default) | ||
first_row = df.iloc[0] | ||
value = first_row.get(alias_map[field_name]) or field_info.default | ||
|
||
curried_model(**{field_name: value}) | ||
|
||
yield k, value | ||
model_instance = curried_model() | ||
assert isinstance(model_instance, model) # type narrow | ||
return model_instance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters