Skip to content

Commit

Permalink
[model] namedtuple transform decorator (#22043)
Browse files Browse the repository at this point in the history
Improve the ergonomics of working with `NamedTuple` by using a custom
decorator that:
* hides the gross NamedTuple instantiation that you have to do to
subclass or override `__new__`
* bans using the object like a plain tuple (iteration and [] access) 
* optionally creates a type-checked `__new__` based on your field set

## How I Tested These Changes

added tests
  • Loading branch information
alangenfeld authored Jun 7, 2024
1 parent c5ad5ba commit dee425f
Show file tree
Hide file tree
Showing 5 changed files with 545 additions and 13 deletions.
15 changes: 13 additions & 2 deletions python_modules/dagster/dagster/_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Hashable, Optional
from typing import TYPE_CHECKING, Any, Dict, Hashable, Optional, TypeVar

from pydantic import BaseModel, ConfigDict, PrivateAttr
from typing_extensions import Annotated, Self, TypeAlias, TypeVar
from typing_extensions import Annotated, Self, TypeAlias

# decorator based models public API
from .decorator import (
IHaveNew as IHaveNew,
LegacyNamedTupleMixin as LegacyNamedTupleMixin,
as_dict as as_dict,
copy as copy,
dagster_model as dagster_model,
dagster_model_custom as dagster_model_custom,
has_generated_new as has_generated_new,
is_dagster_model as is_dagster_model,
)
from .pydantic_compat_layer import USING_PYDANTIC_2

if USING_PYDANTIC_2:
Expand Down
264 changes: 264 additions & 0 deletions python_modules/dagster/dagster/_model/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from abc import ABC
from functools import cached_property, partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Mapping,
NamedTuple,
Optional,
Type,
TypeVar,
Union,
overload,
)

from typing_extensions import dataclass_transform

import dagster._check as check
from dagster._check import EvalContext, build_check_call

# Type variable for "a class object"; lets the decorators be annotated as
# returning the same kind of thing they were handed.
TType = TypeVar("TType", bound=Type)
# Unconstrained type variable for helpers whose return type mirrors their argument.
TVal = TypeVar("TVal")


# Unique sentinel placed on generated model classes under _MODEL_MARKER_FIELD.
_MODEL_MARKER_VALUE = object()
# Attribute name that holds the sentinel.
# "I do want to release this as checkrepublic one day" - schrockn
_MODEL_MARKER_FIELD = "__checkrepublic__"
# Name of the generated, runtime checked __new__ function.
_GENERATED_NEW = "__checked_new__"


def _namedtuple_model_transform(
    cls: TType,
    *,
    checked: bool,
    with_new: bool,
    decorator_frames: int,
) -> TType:
    """Transforms the input class in to one that inherits a generated NamedTuple base class
    and:
        * bans tuple methods that don't make sense for a model object
        * creates a run time checked __new__ (optional).

    Args:
        cls: The decorated class. Its ``__annotations__`` define the fields of the
            generated NamedTuple base.
        checked: When true, the generated base gets a keyword-only ``__new__`` that runs
            one check call per field before building the tuple.
        with_new: Whether the decorated class is expected to define its own ``__new__``.
            Fails via ``check.failed`` if it is set but ``cls`` has no ``__new__``.
        decorator_frames: Extra frame offset contributed by the calling decorator, added
            to the depth handed to ``EvalContext.capture_from_frame``.

    Returns:
        A new class named like ``cls`` that derives from ``cls`` and the generated
        NamedTuple base, carrying the ``__module__``, ``__qualname__`` and ``__doc__``
        of ``cls``.
    """
    base = NamedTuple(f"_{cls.__name__}", cls.__annotations__.items())

    if checked:
        orig_new = base.__new__
        # NOTE(review): the frame offsets below appear to be counted relative to this
        # function's own frame, so this call should not be moved into a helper without
        # adjusting them.
        checks_builder = LazyCheckBuilder(
            cls.__annotations__,
            EvalContext.capture_from_frame(1 + decorator_frames),
            1 if with_new else 0,
        )

        # The function name is significant: has_generated_new() compares
        # __new__.__name__ against _GENERATED_NEW.
        def __checked_new__(cls: TType, **kwargs):
            for key, fn in checks_builder.checks.items():
                fn(kwargs[key])

            return orig_new(cls, **kwargs)

        base.__new__ = __checked_new__  # type: ignore # unhappy with dropping positional args

    if with_new and cls.__new__ is object.__new__:
        # verify the alignment since it impacts frame capture
        check.failed(f"Expected __new__ on {cls}, add it or switch from the _with_new decorator.")

    new_type = type(
        cls.__name__,
        (cls, base),
        {  # these will override an implementation on the class if it exists
            "__iter__": _banned_iter,
            "__getitem__": _banned_idx,
            "__hidden_iter__": base.__iter__,
            _MODEL_MARKER_FIELD: _MODEL_MARKER_VALUE,
            "__annotations__": cls.__annotations__,
            # type() would otherwise stamp the class with this module's name, an
            # un-nested __qualname__ and no docstring, so carry over the metadata of
            # the class the user actually wrote.
            "__module__": cls.__module__,
            "__qualname__": cls.__qualname__,
            "__doc__": cls.__doc__,
        },
    )

    return new_type  # type: ignore


@overload
def dagster_model(
    cls: TType,
) -> TType: ...  # Overload for using decorator with no ().


@overload
def dagster_model(
    *,
    checked: bool = True,
) -> Callable[[TType], TType]: ...  # Overload for using decorator used with args.


@dataclass_transform(
    kw_only_default=True,
    frozen_default=True,
)
def dagster_model(
    cls: Optional[TType] = None,
    *,
    checked: bool = True,
) -> Union[TType, Callable[[TType], TType]]:
    """A class decorator that will create an immutable model class based on the defined fields.

    Usable both bare (``@dagster_model``) and with arguments
    (``@dagster_model(checked=False)``).

    Args:
        cls: The class being decorated when the decorator is applied without
            parentheses; ``None`` when it is called with keyword arguments.
        checked: Whether or not to generate runtime type checked construction.

    Returns:
        The transformed class when applied directly to a class, otherwise a callable
        that performs the transformation on the class it is given.
    """
    # Compare against None rather than relying on truthiness: a class whose metaclass
    # defines __len__ or __bool__ can be falsy and must still be transformed.
    if cls is not None:
        # Called from inside this function, so one decorator frame is reported.
        return _namedtuple_model_transform(
            cls,
            checked=checked,
            with_new=False,
            decorator_frames=1,
        )
    else:
        # The partial is invoked by the decoration site itself: no decorator frame.
        return partial(
            _namedtuple_model_transform,
            checked=checked,
            with_new=False,
            decorator_frames=0,
        )


@overload
def dagster_model_custom(
    cls: TType,
) -> TType: ...  # Overload for using decorator with no ().


@overload
def dagster_model_custom(
    *,
    checked: bool = True,
) -> Callable[[TType], TType]: ...  # Overload for using decorator used with args.


def dagster_model_custom(
    cls: Optional[TType] = None,
    *,
    checked: bool = True,
) -> Union[TType, Callable[[TType], TType]]:
    """Variant of the dagster_model decorator to use to opt out of the dataclass_transform decorator behavior.
    This is often done to be able to override __new__, so the type checker respects your constructor.

        @dagster_model_custom
        class Coerced(IHaveNew):
            name: str

            def __new__(cls, name: Optional[str] = None):
                if not name:
                    name = "bob"

                return super().__new__(
                    cls,
                    name=name,
                )

    It would have been cool if we could do that with an argument and @overload but
    from https://peps.python.org/pep-0681/ " When applied to an overload,
    the dataclass_transform decorator still impacts all usage of the function."

    Args:
        cls: The class being decorated when the decorator is applied without
            parentheses; ``None`` when it is called with keyword arguments.
        checked: Whether or not to generate runtime type checked construction.

    Returns:
        The transformed class when applied directly to a class, otherwise a callable
        that performs the transformation on the class it is given.
    """
    # Compare against None rather than relying on truthiness: a class whose metaclass
    # defines __len__ or __bool__ can be falsy and must still be transformed.
    if cls is not None:
        # Called from inside this function, so one decorator frame is reported.
        return _namedtuple_model_transform(
            cls,
            checked=checked,
            with_new=True,
            decorator_frames=1,
        )
    else:
        # The partial is invoked by the decoration site itself: no decorator frame.
        return partial(
            _namedtuple_model_transform,
            checked=checked,
            with_new=True,
            decorator_frames=0,
        )


class IHaveNew:
    """Base class for @dagster_model_custom classes that define their own __new__.

    It contributes nothing at runtime. The keyword-accepting ``__new__`` below is only
    declared while type checking, so that ``super().__new__(cls, field=...)`` inside an
    overridden ``__new__`` does not get flagged as a type error.
    """

    if TYPE_CHECKING:

        def __new__(cls, **kwargs): ...


def is_dagster_model(obj) -> bool:
    """Whether or not this object was produced by a dagster_model decorator.

    Works for both the generated class and its instances, since the marker is looked
    up with a plain ``getattr``.
    """
    # Identity, not equality: the marker is a unique sentinel object, and `==` would
    # defer to whatever __eq__ the looked-up attribute happens to implement (which may
    # answer True for anything, or not even return a bool).
    return getattr(obj, _MODEL_MARKER_FIELD, None) is _MODEL_MARKER_VALUE


def has_generated_new(obj) -> bool:
    """Whether ``obj.__new__`` is the generated, runtime checked constructor.

    Detection is purely by function name: the ``__name__`` of ``obj.__new__`` is
    compared with the name recorded in ``_GENERATED_NEW``.
    """
    new_fn_name = obj.__new__.__name__
    return new_fn_name == _GENERATED_NEW


def as_dict(obj) -> Mapping[str, Any]:
    """Return a new dict mapping each field name of a model to its value.

    Values are read through ``__hidden_iter__`` and paired with ``_fields`` in order.

    Raises:
        Exception: If ``obj`` was not produced by a dagster_model decorator.
    """
    if is_dagster_model(obj):
        field_names = obj._fields
        field_values = obj.__hidden_iter__()
        return dict(zip(field_names, field_values))

    raise Exception("Only works for @dagster_model decorated classes")


def copy(obj: TVal, **kwargs) -> TVal:
    """Build a new instance of the model's class from its current field values,
    with any fields named in ``kwargs`` replaced by the given values.
    """
    field_values = dict(as_dict(obj))
    # Overrides are applied last so they win over the existing values.
    field_values.update(kwargs)
    return obj.__class__(**field_values)


class LegacyNamedTupleMixin(ABC):
    """Mixin to ease migration by adding NamedTuple utility methods.
    Inherit when converting an existing NamedTuple that has callsites to _replace / _asdict, ie.

        @dagster_model
        class AssetSubset(LegacyNamedTupleMixin):
            asset_key: AssetKey
            value: Union[bool, PartitionsSubset]
    """

    def _replace(self, **kwargs):
        """Return a copy of this model with the given fields replaced (delegates to ``copy``)."""
        return copy(self, **kwargs)

    def _asdict(self):
        """Return the fields of this model as a mapping (delegates to ``as_dict``)."""
        return as_dict(self)


class LazyCheckBuilder:
    # Class object to support building check calls on first use and keeping them.
    # This allows resolving ForwardRefs for types that were not available at initial definition.

    def __init__(self, field_set: dict, eval_ctx: EvalContext, new_frames: int):
        """Store the inputs needed to build the check calls later.

        Args:
            field_set: Mapping of field name to its annotated type.
            eval_ctx: Evaluation context handed to every ``build_check_call``; it is
                updated from the stack the first time ``checks`` is accessed.
            new_frames: Extra frame offset added to the depth passed to
                ``update_from_frame``.
        """
        self._field_set = field_set
        self._eval_ctx = eval_ctx
        self._new_frames = new_frames  # how many frames of __new__ there are

    @cached_property
    def checks(self) -> Mapping[str, Callable[[Any], Any]]:
        """Mapping of field name to the check callable built for that field.

        Built on first access and then cached on the instance by ``cached_property``.
        The frame offset below depends on where the first access happens in the call
        stack, so that access is expected to come from the generated ``__new__``.
        """
        # update the context with callsite locals/globals to resolve
        # ForwardRefs that were unavailable at definition time.

        # 3: checks -> __new__ -> callsite (+ optional override __new__ frame)
        self._eval_ctx.update_from_frame(3 + self._new_frames)

        return {
            name: build_check_call(ttype=ttype, name=name, eval_ctx=self._eval_ctx)
            for name, ttype in self._field_set.items()
        }


def _banned_iter(*args, **kwargs):
    """Replacement for ``__iter__`` on model classes; ignores its arguments and always raises."""
    message = "Iteration is not allowed on `@dagster_model`s."
    raise Exception(message)


def _banned_idx(*args, **kwargs):
    """Replacement for ``__getitem__`` on model classes; ignores its arguments and always raises."""
    message = "Index access is not allowed on `@dagster_model`s."
    raise Exception(message)
13 changes: 13 additions & 0 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

import dagster._check as check
import dagster._seven as seven
from dagster._model import as_dict, has_generated_new, is_dagster_model
from dagster._model.pydantic_compat_layer import ModelFieldCompat, model_fields
from dagster._utils import is_named_tuple_instance, is_named_tuple_subclass
from dagster._utils.warnings import disable_dagster_warnings
Expand Down Expand Up @@ -315,6 +316,7 @@ def whitelist_for_serdes(
)(__cls)
else: # decorator passed params
check.opt_class_param(serializer, "serializer", superclass=Serializer)

return _whitelist_for_serdes(
whitelist_map=_WHITELIST_MAP,
serializer=serializer,
Expand Down Expand Up @@ -635,10 +637,17 @@ def get_storage_name(self) -> str:

class NamedTupleSerializer(ObjectSerializer[T_NamedTuple]):
    def object_as_mapping(self, value: T_NamedTuple) -> Mapping[str, Any]:
        """Return the fields of ``value`` as a name -> value mapping.

        @dagster_model objects are converted with ``as_dict``; anything else is
        expected to provide the NamedTuple ``_asdict`` method.
        """
        return as_dict(value) if is_dagster_model(value) else value._asdict()

    @cached_property
    def constructor_param_names(self) -> Sequence[str]:
        """Names of the parameters accepted by the constructor of ``self.klass``."""
        klass = self.klass

        if has_generated_new(klass):
            # if its an @dagster_model generated new, just use annotations
            return list(klass.__annotations__)

        # otherwise read them off the signature of __new__
        return list(signature(klass.__new__).parameters)


Expand Down Expand Up @@ -1193,6 +1202,10 @@ def _unpack_value(
def _check_serdes_tuple_class_invariants(
klass: Type[NamedTuple], is_pickleable: bool = True
) -> None:
# can skip validation on @dagster_model generated new
if has_generated_new(klass):
return

sig_params = signature(klass.__new__).parameters
dunder_new_params = list(sig_params.values())

Expand Down
Loading

0 comments on commit dee425f

Please sign in to comment.