Skip to content

Commit

Permalink
[3/n][pythonic config] Layer in Pydantic compat abstraction layer"
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Sep 28, 2023
1 parent 039c24b commit 8924ae9
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docs/content/api/modules.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/searchindex.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/sections.json

Large diffs are not rendered by default.

102 changes: 77 additions & 25 deletions python_modules/dagster/dagster/_config/pythonic_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Set,
Type,
cast,
)

from pydantic import BaseModel, Extra
from pydantic.fields import (
ModelField,
)
from pydantic import BaseModel
from typing_extensions import TypeVar

import dagster._check as check
from dagster import (
Field,
Field as DagsterField,
Shape,
)
from dagster._config.field_utils import Permissive
from dagster._config.field_utils import (
EnvVar,
IntEnvVar,
Permissive,
)
from dagster._core.definitions.definition_config_schema import (
DefinitionConfigSchema,
)
Expand All @@ -37,6 +40,12 @@
IAttachDifferentObjectToOpContext as IAttachDifferentObjectToOpContext,
)
from .conversion_utils import _convert_pydantic_field, safe_is_subclass
from .pydantic_compat_layer import (
USING_PYDANTIC_2,
ModelFieldCompat,
model_config,
model_fields,
)
from .typing_utils import BaseConfigMeta

try:
Expand Down Expand Up @@ -66,7 +75,10 @@ class Config:
# Necessary to allow for caching decorators
arbitrary_types_allowed = True
# Avoid pydantic reading a cached property class as part of the schema
keep_untouched = (cached_property,)
if USING_PYDANTIC_2:
ignored_types = (cached_property,)
else:
keep_untouched = (cached_property,)
# Ensure the class is serializable, for caching purposes
frozen = True

Expand All @@ -85,7 +97,11 @@ def __setattr__(self, name: str, value: Any):
return super().__setattr__(name, value)
except (TypeError, ValueError) as e:
clsname = self.__class__.__name__
if "is immutable and does not support item assignment" in str(e):
if "Instance is frozen" in str( # Pydantic 2.x error
e
) or "is immutable and does not support item assignment" in str( # Pydantic 1.x error
e
):
if isinstance(self, ConfigurableResourceFactory):
raise DagsterInvalidInvocationError(
f"'{clsname}' is a Pythonic resource and does not support item assignment,"
Expand Down Expand Up @@ -125,6 +141,31 @@ def _is_field_internal(self, name: str) -> bool:
return name.endswith(INTERNAL_MARKER)


T = TypeVar("T")


def ensure_env_vars_set_post_init(set_value: T, input_value: Any) -> T:
"""Pydantic 2.x utility. Ensures that Pydantic field values are set to the appropriate
EnvVar or IntEnvVar objects post-model-instantiation, since Pydantic 2.x will cast
EnvVar or IntEnvVar values to raw strings or ints as part of the model instantiation process.
"""
if isinstance(set_value, dict) and isinstance(input_value, dict):
for key, value in input_value.items():
if isinstance(value, (EnvVar, IntEnvVar)):
set_value[key] = value
elif isinstance(value, (dict, list)):
set_value[key] = ensure_env_vars_set_post_init(set_value[key], value)
if isinstance(set_value, List) and isinstance(input_value, List):
for i in range(len(set_value)):
value = input_value[i]
if isinstance(value, (EnvVar, IntEnvVar)):
set_value[i] = value
elif isinstance(value, (dict, list)):
set_value[i] = ensure_env_vars_set_post_init(set_value[i], value)

return set_value


class Config(MakeConfigCacheable, metaclass=BaseConfigMeta):
"""Base class for Dagster configuration models, used to specify config schema for
ops and assets. Subclasses :py:class:`pydantic.BaseModel`.
Expand Down Expand Up @@ -163,13 +204,14 @@ def __init__(self, **config_dict) -> None:
modified_data = {}
for key, value in config_dict.items():
field = self.__fields__.get(key)
if field and not field.required and value is None:
field = model_fields(self).get(key)
if field and not field.is_required() and value is None:
continue

if field and field.field_info.discriminator:
if field and field.discriminator:
nested_dict = value

discriminator_key = check.not_none(field.discriminator_key)
discriminator_key = check.not_none(field.discriminator)
if isinstance(value, Config):
nested_dict = _discriminated_union_config_dict_to_selector_config_dict(
discriminator_key,
Expand All @@ -190,7 +232,13 @@ def __init__(self, **config_dict) -> None:
else:
modified_data[key] = value

for key, field in model_fields(self).items():
if field.is_required() and key not in modified_data:
modified_data[key] = None

super().__init__(**modified_data)
if USING_PYDANTIC_2:
self.__dict__ = ensure_env_vars_set_post_init(self.__dict__, modified_data)

def _convert_to_config_dictionary(self) -> Mapping[str, Any]:
"""Converts this Config object to a Dagster config dictionary, in the same format as the dictionary
Expand All @@ -201,7 +249,7 @@ def _convert_to_config_dictionary(self) -> Mapping[str, Any]:
"""
public_fields = self._get_non_none_public_field_values()
return {
k: _config_value_to_dict_representation(self.__fields__.get(k), v)
k: _config_value_to_dict_representation(model_fields(self).get(k), v)
for k, v in public_fields.items()
}

Expand All @@ -216,10 +264,10 @@ def _get_non_none_public_field_values(self) -> Mapping[str, Any]:
for key, value in self.__dict__.items():
if self._is_field_internal(key):
continue
field = self.__fields__.get(key)
field = model_fields(self).get(key)

if field:
alias = field.alias
alias = field.alias or key
output[alias] = value
else:
output[key] = value
Expand Down Expand Up @@ -255,7 +303,7 @@ def _discriminated_union_config_dict_to_selector_config_dict(
return wrapped_dict


def _config_value_to_dict_representation(field: Optional[ModelField], value: Any):
def _config_value_to_dict_representation(field: Optional[ModelFieldCompat], value: Any):
"""Converts a config value to a dictionary representation. If a field is provided, it will be used
to determine the appropriate dictionary representation in the case of discriminated unions.
"""
Expand All @@ -270,11 +318,11 @@ def _config_value_to_dict_representation(field: Optional[ModelField], value: Any
elif isinstance(value, IntEnvVar):
return {"env": value.name}
if isinstance(value, Config):
if field and field.discriminator_key:
if field and field.discriminator:
return {
k: v
for k, v in _discriminated_union_config_dict_to_selector_config_dict(
field.discriminator_key,
field.discriminator,
value._convert_to_config_dictionary(), # noqa: SLF001
).items()
}
Expand Down Expand Up @@ -331,7 +379,7 @@ def infer_schema_from_config_class(
fields_to_omit: Optional[Set[str]] = None,
) -> Field:
from .config import Config
from .resource import ConfigurableResourceFactory
from .resource import ConfigurableResourceFactory, _is_annotated_as_resource_type

"""Parses a structured config class and returns a corresponding Dagster config Field."""
fields_to_omit = fields_to_omit or set()
Expand All @@ -342,29 +390,33 @@ def infer_schema_from_config_class(
)

fields: Dict[str, Field] = {}
for pydantic_field in model_cls.__fields__.values():
if pydantic_field.name not in fields_to_omit:
if isinstance(pydantic_field.default, Field):
for key, pydantic_field_info in model_fields(model_cls).items():
if _is_annotated_as_resource_type(
pydantic_field_info.annotation, pydantic_field_info.metadata
):
continue

alias = pydantic_field_info.alias if pydantic_field_info.alias else key
if key not in fields_to_omit:
if isinstance(pydantic_field_info.default, Field):
raise DagsterInvalidDefinitionError(
"Using 'dagster.Field' is not supported within a Pythonic config or resource"
" definition. 'dagster.Field' should only be used in legacy Dagster config"
" schemas. Did you mean to use 'pydantic.Field' instead?"
)

try:
fields[pydantic_field.alias] = _convert_pydantic_field(
pydantic_field,
)
fields[alias] = _convert_pydantic_field(pydantic_field_info)
except DagsterInvalidConfigDefinitionError as e:
raise DagsterInvalidPythonicConfigDefinitionError(
config_class=model_cls,
field_name=pydantic_field.name,
field_name=key,
invalid_type=e.current_value,
is_resource=model_cls is not None
and safe_is_subclass(model_cls, ConfigurableResourceFactory),
)

shape_cls = Permissive if model_cls.__config__.extra == Extra.allow else Shape
shape_cls = Permissive if model_config(model_cls).get("extra") == "allow" else Shape

docstring = model_cls.__doc__.strip() if model_cls.__doc__ else None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Optional,
Type,
TypeVar,
Union,
)

from pydantic import ConstrainedFloat, ConstrainedInt, ConstrainedStr
Expand Down Expand Up @@ -58,6 +59,7 @@ class cached_property:
convert_potential_field,
)

from .pydantic_compat_layer import ModelFieldCompat, model_fields
from .type_check_utils import is_optional, safe_is_subclass


Expand Down Expand Up @@ -128,7 +130,9 @@ def _get_inner_field_if_exists(shape_type: Type, field: ModelField) -> Optional[
return None


def _convert_pydantic_field(pydantic_field: ModelField, model_cls: Optional[Type] = None) -> Field:
def _convert_pydantic_field(
pydantic_field: ModelFieldCompat, model_cls: Optional[Type] = None
) -> Field:
"""Transforms a Pydantic field into a corresponding Dagster config field.
Expand All @@ -139,26 +143,26 @@ def _convert_pydantic_field(pydantic_field: ModelField, model_cls: Optional[Type
"""
from .config import Config, infer_schema_from_config_class

if pydantic_field.field_info.discriminator:
if pydantic_field.discriminator:
return _convert_pydantic_descriminated_union_field(pydantic_field)

field_type = pydantic_field.annotation
if safe_is_subclass(field_type, Config):
inferred_field = infer_schema_from_config_class(
field_type,
description=pydantic_field.field_info.description,
description=pydantic_field.description,
)
return inferred_field
else:
if not pydantic_field.required and not is_optional(field_type):
if not pydantic_field.is_required() and not is_optional(field_type):
field_type = Optional[field_type]

config_type = _config_type_for_type_on_pydantic_field(field_type)

return Field(
config=config_type,
description=pydantic_field.field_info.description,
is_required=pydantic_field.required and not is_optional(field_type),
description=pydantic_field.description,
is_required=pydantic_field.is_required() and not is_optional(field_type),
default_value=(
pydantic_field.default
if pydantic_field.default is not None
Expand Down Expand Up @@ -226,7 +230,7 @@ def _config_type_for_type_on_pydantic_field(
return convert_potential_field(potential_dagster_type).config_type


def _convert_pydantic_descriminated_union_field(pydantic_field: ModelField) -> Field:
def _convert_pydantic_descriminated_union_field(pydantic_field: ModelFieldCompat) -> Field:
"""Builds a Selector config field from a Pydantic field which is a descriminated union.
For example:
Expand All @@ -253,24 +257,30 @@ class OpConfigWithUnion(Config):
"""
from .config import Config, infer_schema_from_config_class

sub_fields_mapping = pydantic_field.sub_fields_mapping
if not sub_fields_mapping or not all(
issubclass(pydantic_field.type_, Config) for pydantic_field in sub_fields_mapping.values()
):
field_type = pydantic_field.annotation

if not get_origin(field_type) == Union:
raise DagsterInvalidDefinitionError("Descriminated union must be a Union type.")

sub_fields = get_args(field_type)
if not all(issubclass(sub_field, Config) for sub_field in sub_fields):
raise NotImplementedError("Descriminated unions with non-Config types are not supported.")

sub_fields_mapping = {}
for sub_field in sub_fields:
sub_field_annotation = model_fields(sub_field)[pydantic_field.discriminator].annotation

for sub_field_key in get_args(sub_field_annotation):
sub_fields_mapping[sub_field_key] = sub_field
# First, we generate a mapping between the various discriminator values and the
# Dagster config fields that correspond to them. We strip the discriminator key
# from the fields, since the user should not have to specify it.

assert pydantic_field.sub_fields_mapping
dagster_config_field_mapping = {
discriminator_value: infer_schema_from_config_class(
field.type_,
field,
fields_to_omit=(
{pydantic_field.field_info.discriminator}
if pydantic_field.field_info.discriminator
else None
{pydantic_field.discriminator} if pydantic_field.discriminator else None
),
)
for discriminator_value, field in sub_fields_mapping.items()
Expand Down
Loading

0 comments on commit 8924ae9

Please sign in to comment.