Skip to content

Commit

Permalink
refactor: implement ModelBindingsMapper utilizing a DataFrame
Browse files Browse the repository at this point in the history
This change introduces a refactor of the important ModelBindingsMapper
class using a pandas DataFrame for effecting grouping and aggregation
and a CurryModel utility for partially instantiating a given Pydantic
model with checks running for every partial application, allowing for
fast validation failure.

Closes #170.

Aggregation behavior is now only triggered for actually aggregated
fields, and not for top-level models as well. Therefore this also closes #181.
  • Loading branch information
lu-pl committed Jan 20, 2025
1 parent 068f110 commit f133519
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 72 deletions.
198 changes: 127 additions & 71 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,145 @@
"""ModelBindingsMapper: Functionality for mapping binding maps to a Pydantic model."""

from collections.abc import Iterator
from typing import Any, Generic, get_args
from collections.abc import Iterable, Iterator
from typing import Generic, get_args

import numpy as np
import pandas as pd
from pandas.api.typing import DataFrameGroupBy
from pydantic import BaseModel
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.mapper_utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
_is_list_basemodel_type,
_is_list_type,
get_model_bool_predicate,
)
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.mapper_utils import _is_list_basemodel_type, _is_list_type
from rdfproxy.utils.utils import CurryModel, FieldsBindingsMap


class ModelBindingsMapper(Generic[_TModelInstance]):
"""Utility class for mapping flat bindings to a (potentially nested) Pydantic model."""
"""Utility class for mapping bindings to nested/grouped Pydantic models.
RDFProxy utilizes Pydantic models also as a modelling grammar for grouping
and aggregation, mainly by treating the 'group_by' entry in ConfigDict in
combination with list-type annotated model fields as grouping
and aggregation indicators. ModelBindingsMapper applies this grammar
for mapping flat bindings to potentially nested and grouped Pydantic models.
Note: ModelBindingsMapper is intended for use in and - since no
model sanity checking runs in the mapper itself - somewhat coupled to
rdfproxy.SPARQLModelAdapter. The mapper can be useful in its own right
though, in which case the initializer should be overridden and
model sanity checking should be added to the ModelBindingsMapper subclass.
"""

def __init__(self, model: type[_TModelInstance], *bindings: dict):
    """Store the target model and bindings and build a DataFrame from the bindings.

    Args:
        model: The Pydantic model class the bindings are mapped to.
        *bindings: Flat binding dicts; each dict becomes one DataFrame row.
    """
    self.model = model
    self.bindings = bindings
    # NOTE(review): _contexts is only read by _get_group_by and
    # _generate_binding_pairs - confirm it is still needed next to self.df.
    self._contexts = []

    self.df = pd.DataFrame(data=self.bindings)
    # Normalize NaN cells (e.g. keys missing from some bindings) to None.
    self.df.replace(np.nan, None, inplace=True)

def get_models(self) -> list[_TModelInstance]:
    """Generate a list of (potentially nested) Pydantic models based on (flat) bindings.

    Delegates to self._get_unique_models, passing the model class and the
    bindings stored on the instance.

    Returns:
        The list of model instances produced by self._get_unique_models.
    """
    return self._get_unique_models(self.model, self.bindings)

def _get_unique_models(self, model, bindings):
    """Map every binding to a model instance and keep the unique, accepted ones.

    For each binding, the keyword arguments for ``model`` are produced by
    self._generate_binding_pairs. An instance is collected only if the
    predicate obtained from get_model_bool_predicate(model) accepts it and
    an equal instance has not been collected before. Result order follows
    the order of ``bindings``.
    """
    collected = []
    predicate: ModelBoolPredicate = get_model_bool_predicate(model)

    for binding in bindings:
        pairs = self._generate_binding_pairs(model, **binding)
        instance = model(**dict(pairs))

        if not predicate(instance):
            continue
        # Models are compared by equality; a linear scan keeps insertion order.
        if instance in collected:
            continue
        collected.append(instance)

    return collected

def _get_group_by(self, model) -> str:
    """Return the group_by value of ``model``, recording it in self._contexts.

    The value is looked up with the module-level _get_group_by helper and
    appended to self._contexts the first time it is seen.
    """
    key: str = _get_group_by(model)

    already_registered = key in self._contexts
    if not already_registered:
        self._contexts.append(key)

    return key

def _generate_binding_pairs(
self,
model: type[BaseModel],
**kwargs,
) -> Iterator[tuple[str, Any]]:
"""Generate an Iterator[tuple] projection of the bindings needed for model instantation."""
for k, v in model.model_fields.items():
if _is_list_basemodel_type(v.annotation):
group_by: str = self._get_group_by(model)
group_model, *_ = get_args(v.annotation)

applicable_bindings = filter(
lambda x: (x[group_by] == kwargs[group_by])
and (x[self._contexts[0]] == kwargs[self._contexts[0]]),
self.bindings,
"""Run the model mapping logic against bindings and collect a list of model instances."""
return list(self._instantiate_models(self.df, self.model))

def _instantiate_models(
    self, df: pd.DataFrame, model: type[_TModelInstance]
) -> Iterator[_TModelInstance]:
    """Lazily yield (potentially nested and grouped) model instances for a dataframe.

    If the model config has no 'group_by' entry, one model is yielded per
    dataframe row. Otherwise the dataframe is grouped by the column that
    the alias map resolves the 'group_by' value to, and one model is
    yielded per group.

    Note: The groupby must stay unsorted (sort=False),
    else the result set order will not be maintained.
    """
    alias_map = FieldsBindingsMap(model=model)
    group_by_field = model.model_config.get("group_by")

    if group_by_field is not None:
        group_column = alias_map[group_by_field]
        groups: DataFrameGroupBy = df.groupby(group_column, sort=False)

        for _key, group_df in groups:
            yield self._instantiate_grouped_model_from_df(group_df, model)
        return

    for _index, row in df.iterrows():
        yield self._instantiate_ungrouped_model_from_row(row, model)

def _instantiate_ungrouped_model_from_row(
    self, row: pd.Series, model: type[_TModelInstance]
) -> _TModelInstance:
    """Instantiate an ungrouped model from a pd.Series row.

    This handles the UNGROUPED code path in _instantiate_models.

    Fields annotated with a model class are instantiated recursively from
    the same row. Every other field is looked up in the row under the key
    the alias map yields for the field name; the field default is used only
    if the row holds no value (None) for that key. Falsy but valid values
    such as 0, False or "" are kept as they are.

    Args:
        row: A single row of bindings.
        model: The model class to instantiate.

    Returns:
        An instance of ``model``.

    Raises:
        AssertionError: If the curried model did not resolve to an instance
            of ``model`` after all fields were applied.
    """
    alias_map = FieldsBindingsMap(model=model)
    curried_model = CurryModel(model=model)

    for field_name, field_info in model.model_fields.items():
        annotation = field_info.annotation

        if isinstance(annotation, type(BaseModel)):
            field_value = self._instantiate_ungrouped_model_from_row(
                row,
                annotation,  # type: ignore
            )
        else:
            field_value = row.get(alias_map[field_name], None)
            # Fall back on None only; `value or default` would also
            # discard legitimate falsy values like 0 or False.
            if field_value is None:
                field_value = field_info.default

        curried_model(**{field_name: field_value})

    model_instance = curried_model()
    assert isinstance(model_instance, model)  # type narrow
    return model_instance

def _get_scalar_type_from_grouped_df(self):
    """Placeholder without an implementation; does nothing and returns None.

    NOTE(review): nothing in the visible code calls this method - confirm
    whether it is still needed.
    """
    pass

@staticmethod
def _get_unique_models(models: Iterable[_TModelInstance]) -> list[_TModelInstance]:
    """Collect the unique models of an iterable, preserving first-seen order.

    Models are compared by equality against the already collected ones,
    because Pydantic model instances are not hashable unless frozen=True is
    set on the model class, which rules out dict.fromkeys for ordered
    de-duplication.

    Models for which no field value is truthy are skipped as well.
    """
    collected: list[_TModelInstance] = []

    for candidate in models:
        if candidate in collected:
            continue
        # Skip models that carry no truthy field value at all.
        if not any(dict(candidate).values()):
            continue
        collected.append(candidate)

    return collected

def _instantiate_grouped_model_from_df(
    self, df: pd.DataFrame, model: type[_TModelInstance]
) -> _TModelInstance:
    """Instantiate a grouped model from a pd.DataFrame (a group dataframe).

    This handles the GROUPED code path in _instantiate_models.

    Field values are resolved by annotation:
      - list of models: the nested model is instantiated against the whole
        group dataframe and the unique resulting models are collected.
      - other list types: the column values of the group with missing
        entries dropped and duplicates removed, in first-seen order.
      - a single nested model: instantiated from the first row of the group.
      - anything else: the value of the first row, with the field default
        used only if that value is None. Falsy but valid values such as
        0, False or "" are kept as they are.

    Args:
        df: The dataframe of a single group.
        model: The model class to instantiate.

    Returns:
        An instance of ``model``.

    Raises:
        AssertionError: If the curried model did not resolve to an instance
            of ``model`` after all fields were applied.
    """
    alias_map = FieldsBindingsMap(model=model)
    curried_model = CurryModel(model=model)

    for field_name, field_info in model.model_fields.items():
        annotation = field_info.annotation

        if _is_list_basemodel_type(annotation):
            nested_model, *_ = get_args(annotation)
            value = self._get_unique_models(
                self._instantiate_models(df, nested_model)
            )
        elif _is_list_type(annotation):
            # dict.fromkeys de-duplicates while keeping insertion order.
            value = list(dict.fromkeys(df[alias_map[field_name]].dropna()))
        elif isinstance(annotation, type(BaseModel)):
            value = self._instantiate_ungrouped_model_from_row(
                df.iloc[0],
                annotation,  # type: ignore
            )
        else:
            value = df.iloc[0].get(alias_map[field_name])
            # Fall back on None only; `value or default` would also
            # discard legitimate falsy values like 0 or False.
            if value is None:
                value = field_info.default

        curried_model(**{field_name: value})

    model_instance = curried_model()
    assert isinstance(model_instance, model)  # type narrow
    return model_instance
51 changes: 50 additions & 1 deletion rdfproxy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import UserDict
from collections.abc import Callable
from functools import partial
from typing import TypeVar
from typing import Any, Generic, Self, TypeVar

from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._types import SPARQLBinding
Expand Down Expand Up @@ -70,3 +70,52 @@ def __call__(self, query) -> str:
if tkwargs := {k: v for k, v in self.kwargs.items() if v is not None}:
return partial(self.f, **tkwargs)(query)
return query


class CurryModel(Generic[_TModelInstance]):
    """Curried constructor for a Pydantic model.

    Calling a CurryModel instance with kwargs validates each value against
    the validator of the respective model field and stores the kwargs in a
    cache. As soon as the cache covers all model fields, the call
    instantiates the Pydantic model and returns the model instance;
    until then it returns the CurryModel object itself.

    If the eager flag is True (default), the defaults of all non-required
    fields are put into the cache up front, which means that a model can be
    instantiated as soon as all /required/ field values are provided.
    """

    def __init__(self, model: type[_TModelInstance], eager: bool = True) -> None:
        self.model = model
        self.eager = eager

        initial_cache: dict = {}
        if eager:
            # Pre-fill with the defaults of every field that is not required.
            for name, info in model.model_fields.items():
                if not info.is_required():
                    initial_cache[name] = info.default

        self._kwargs_cache: dict = initial_cache

    def __repr__(self):  # pragma: no cover
        return f"CurryModel object {self._kwargs_cache}"

    @staticmethod
    def _validate_field(model: type[_TModelInstance], field: str, value: Any) -> Any:
        """Run the validation of a single model field against value.

        Note: A TypeVar cannot be used for value here,
        because Pydantic might coerce values (if not in Strict Mode).
        """
        validator = model.__pydantic_validator__
        return validator.validate_assignment(model.model_construct(), field, value)

    def __call__(self, **kwargs: Any) -> Self | _TModelInstance:
        # Validate everything first so that a failing value leaves the cache untouched.
        for name, value in kwargs.items():
            self._validate_field(self.model, name, value)

        self._kwargs_cache.update(kwargs)

        if self._kwargs_cache.keys() != self.model.model_fields.keys():
            return self
        return self.model(**self._kwargs_cache)

0 comments on commit f133519

Please sign in to comment.