Skip to content

Commit

Permalink
refactor: implement ModelBindingsMapper utilizing a DataFrame
Browse files Browse the repository at this point in the history
This change introduces a refactor of the important ModelBindingsMapper
class using a pandas DataFrame for effecting grouping and aggregation
and a CurryModel utility for partially instantiating a given Pydantic
model with checks running for every partial application, allowing for
fast validation failure.

Closes #170.

Aggregation behavior is now only triggered for actually aggregated
fields, and not for top-level models as well. Therefore this also closes #181.
  • Loading branch information
lu-pl committed Jan 21, 2025
1 parent e5eca85 commit db0d1c4
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 63 deletions.
182 changes: 120 additions & 62 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,147 @@
"""ModelBindingsMapper: Functionality for mapping binding maps to a Pydantic model."""

from collections.abc import Iterator
from typing import Any, Generic, get_args
from itertools import chain
from typing import Generic, get_args

import pandas as pd
from pandas.api.typing import DataFrameGroupBy
from pydantic import BaseModel
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.mapper_utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
_is_list_basemodel_type,
_is_list_type,
get_model_bool_predicate,
)
from rdfproxy.utils.utils import CurryModel, FieldsBindingsMap


class ModelBindingsMapper(Generic[_TModelInstance]):
"""Utility class for mapping flat bindings to a (potentially nested) Pydantic model."""
"""Utility class for mapping bindings to nested/grouped Pydantic models.
RDFProxy utilizes Pydantic models also as a modelling grammar for grouping
and aggregation, mainly by treating the 'group_by' entry in ConfigDict in
combination with list-type annotated model fields as grouping
and aggregation indicators. ModelBindingsMapper applies this grammar
for mapping flat bindings to potentially nested and grouped Pydantic models.
Note: ModelBindingsMapper is intended for use in rdfproxy.SPARQLModelAdapter and -
since no model sanity checking runs in the mapper itself - somewhat coupled to
SPARQLModelAdapter. The mapper can be useful in its own right though.
For standalone use, the initializer should be overridden and model sanity checking
should be added to the ModelBindingsMapper subclass.
"""

def __init__(self, model: type[_TModelInstance], *bindings: dict):
self.model = model
self.bindings = bindings
self._contexts = []

self.df = pd.DataFrame(data=self.bindings)
self.df.replace(pd.NA, None, inplace=True)

def get_models(self) -> list[_TModelInstance]:
"""Generate a list of (potentially nested) Pydantic models based on (flat) bindings."""
return self._get_unique_models(self.model, self.bindings)

def _get_unique_models(self, model, bindings):
    """Build one model per binding and keep the unique, truthy ones.

    Each binding is expanded through self._generate_binding_pairs and fed
    to the model constructor. An instance is kept only if the model's bool
    predicate accepts it and an equal instance has not been collected yet;
    the order of first occurrence is preserved.
    """
    is_truthy: ModelBoolPredicate = get_model_bool_predicate(model)
    unique = []

    for binding in bindings:
        pairs = self._generate_binding_pairs(model, **binding)
        instance = model(**dict(pairs))

        # Predicate first, membership second (same evaluation order as before).
        if not is_truthy(instance):
            continue
        if instance in unique:
            continue
        unique.append(instance)

    return unique

def _get_group_by(self, model) -> str:
    """Return the model's group_by value, recording it in self._contexts.

    The value is obtained from the module-level _get_group_by helper and is
    appended to self._contexts only if it is not already present.
    """
    key: str = _get_group_by(model)

    if key in self._contexts:
        return key

    self._contexts.append(key)
    return key

def _generate_binding_pairs(
self,
model: type[BaseModel],
**kwargs,
) -> Iterator[tuple[str, Any]]:
"""Generate an Iterator[tuple] projection of the bindings needed for model instantation."""
for k, v in model.model_fields.items():
if _is_list_basemodel_type(v.annotation):
group_by: str = self._get_group_by(model)
group_model, *_ = get_args(v.annotation)

applicable_bindings = filter(
lambda x: (x[group_by] == kwargs[group_by])
and (x[self._contexts[0]] == kwargs[self._contexts[0]]),
self.bindings,
"""Run the model mapping logic against bindings and collect a list of model instances."""
return list(self._instantiate_models(self.df, self.model))

def _instantiate_models(
    self, df: pd.DataFrame, model: type[_TModelInstance]
) -> Iterator[_TModelInstance]:
    """Lazily yield model instances built from a dataframe.

    If the model's model_config has no 'group_by' entry, one instance is
    yielded per dataframe row (ungrouped path). Otherwise the 'group_by'
    value is looked up in a FieldsBindingsMap for the model, the dataframe
    is grouped by the resulting key, and one instance is yielded per group
    (grouped path).

    Note: The groupby call uses sort=False on purpose; a sorted
    DataFrameGroupBy would not maintain the result set order.
    """
    fields_to_bindings = FieldsBindingsMap(model=model)
    group_by_field = model.model_config.get("group_by")

    if group_by_field is not None:
        group_column = fields_to_bindings[group_by_field]
        groups: DataFrameGroupBy = df.groupby(group_column, sort=False)

        for _key, group_frame in groups:
            yield self._instantiate_grouped_model_from_df(group_frame, model)
        return

    for _index, row in df.iterrows():
        yield self._instantiate_ungrouped_model_from_row(row, model)

def _instantiate_ungrouped_model_from_row(
self, row: pd.Series, model: type[_TModelInstance]
) -> _TModelInstance:
"""Instantiate an ungrouped model from a pd.Series row.
This handles the UNGROUPED code path in _ModelBindingsMapper._instantiate_models.
"""
alias_map = FieldsBindingsMap(model=model)
curried_model = CurryModel(model=model)

for field_name, field_info in model.model_fields.items():
if isinstance(nested_model := field_info.annotation, type(BaseModel)):
curried_model(
**{
field_name: self._instantiate_ungrouped_model_from_row(
row,
nested_model, # type: ignore
)
}
)
value = self._get_unique_models(group_model, applicable_bindings)
else:
field_value = row.get(alias_map[field_name], None) or field_info.default
curried_model(**{field_name: field_value})

elif _is_list_type(v.annotation):
group_by: str = self._get_group_by(model)
applicable_bindings = filter(
lambda x: x[group_by] == kwargs[group_by],
self.bindings,
)
model_instance = curried_model()
assert isinstance(model_instance, model) # type narrow
return model_instance

@staticmethod
def _get_unique_models(models: Iterator[_TModelInstance]) -> list[_TModelInstance]:
"""Get a list of unique models from an iterable.
Unless frozen=True is specified in a model class,
Pydantic model instances are not hashable, i.e. dict.fromkeys
is not feasible for acquiring ordered unique models.
"""
unique_models = []

binding_key: str = _get_key_from_metadata(v, default=k)
value = _collect_values_from_bindings(binding_key, applicable_bindings)
_model = next(models, None)
assert _model is not None, "StopIteration should be unreachable"

elif isinstance(v.annotation, type(BaseModel)):
nested_model = v.annotation
value = nested_model(
**dict(self._generate_binding_pairs(nested_model, **kwargs))
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(_model)

for model in chain([_model], models):
if (model not in unique_models) and (model_bool_predicate(model)):
unique_models.append(model)

return unique_models

def _instantiate_grouped_model_from_df(
self, df: pd.DataFrame, model: type[_TModelInstance]
) -> _TModelInstance:
"""Instantiate a grouped model from a pd.DataFrame (a group dataframe).
This handles the GROUPED code path in _ModelBindingsMapper._instantiate_models.
"""
alias_map = FieldsBindingsMap(model=model)
curried_model = CurryModel(model=model)

for field_name, field_info in model.model_fields.items():
if _is_list_basemodel_type(field_info.annotation):
nested_model, *_ = get_args(field_info.annotation)
value = self._get_unique_models(
self._instantiate_models(df, nested_model)
)
elif _is_list_type(field_info.annotation):
value = list(dict.fromkeys(df[alias_map[field_name]].dropna()))
elif isinstance(nested_model := field_info.annotation, type(BaseModel)):
first_row = df.iloc[0]
value = self._instantiate_ungrouped_model_from_row(
first_row,
nested_model, # type: ignore
)
else:
binding_key: str = _get_key_from_metadata(v, default=k)
value = kwargs.get(binding_key, v.default)
first_row = df.iloc[0]
value = first_row.get(alias_map[field_name]) or field_info.default

curried_model(**{field_name: value})

yield k, value
model_instance = curried_model()
assert isinstance(model_instance, model) # type narrow
return model_instance
51 changes: 50 additions & 1 deletion rdfproxy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import UserDict
from collections.abc import Callable
from functools import partial
from typing import TypeVar
from typing import Any, Generic, Self, TypeVar

from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._types import SPARQLBinding
Expand Down Expand Up @@ -70,3 +70,52 @@ def __call__(self, query) -> str:
if tkwargs := {k: v for k, v in self.kwargs.items() if v is not None}:
return partial(self.f, **tkwargs)(query)
return query


class CurryModel(Generic[_TModelInstance]):
    """Partially apply keyword arguments to a Pydantic model constructor.

    Calling a CurryModel instance with kwargs validates every passed value
    against the corresponding model field and stores the kwargs in a cache.
    As soon as the cache holds a value for every model field, the call
    instantiates the Pydantic model and returns the model instance;
    until then it returns the CurryModel object itself.

    If the eager flag is True (default), the defaults of all non-required
    fields are put into the cache up front, so the model can be instantiated
    as soon as possible, i.e. as soon as all /required/ field values are provided.
    """

    def __init__(self, model: type[_TModelInstance], eager: bool = True) -> None:
        self.model = model
        self.eager = eager

        # With eager=True the cache starts out pre-filled with field defaults.
        self._kwargs_cache: dict = {}
        if eager:
            for name, info in model.model_fields.items():
                if not info.is_required():
                    self._kwargs_cache[name] = info.default

    def __repr__(self):  # pragma: no cover
        return f"CurryModel object {self._kwargs_cache}"

    @staticmethod
    def _validate_field(model: type[_TModelInstance], field: str, value: Any) -> Any:
        """Validate a value for a single field of the given model.

        The check runs the model's validator in assignment mode against an
        instance created with model_construct() and returns the validator's result.

        Note: Using a TypeVar for value is not possible here,
        because Pydantic might coerce values (if not in Strict Mode).
        """
        validate_assignment = model.__pydantic_validator__.validate_assignment
        return validate_assignment(model.model_construct(), field, value)

    def __call__(self, **kwargs: Any) -> Self | _TModelInstance:
        """Validate and cache kwargs; return the model once every field is covered."""
        for name, value in kwargs.items():
            self._validate_field(self.model, name, value)

        self._kwargs_cache.update(kwargs)

        if self.model.model_fields.keys() != self._kwargs_cache.keys():
            # Not all fields are known yet: keep currying.
            return self
        return self.model(**self._kwargs_cache)

0 comments on commit db0d1c4

Please sign in to comment.