Skip to content

Commit

Permalink
[pythonic config] Update config conversion logic to rely on annotatio…
Browse files Browse the repository at this point in the history
…n rather than Pydantic model type
  • Loading branch information
benpankow committed Sep 28, 2023
1 parent 23c3411 commit b6e711d
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 142 deletions.
2 changes: 1 addition & 1 deletion docs/content/api/modules.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/searchindex.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/sections.json

Large diffs are not rendered by default.

Binary file modified docs/next/public/objects.inv
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
'configFields': list([
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'a_string',
}),
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'an_unset_string',
Expand Down Expand Up @@ -42,14 +42,14 @@
'configFields': list([
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'a_string',
}),
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'an_unset_string',
Expand Down Expand Up @@ -138,14 +138,14 @@
'configFields': list([
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'a_string',
}),
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'an_unset_string',
Expand All @@ -170,14 +170,14 @@
'configFields': list([
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'a_string',
}),
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'an_unset_string',
Expand All @@ -202,14 +202,14 @@
'configFields': list([
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'a_string',
}),
dict({
'configType': dict({
'key': 'StringSourceType',
'key': 'Noneable.StringSourceType',
}),
'description': None,
'name': 'an_unset_string',
Expand Down
15 changes: 11 additions & 4 deletions python_modules/dagster/dagster/_config/pythonic_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .attach_other_object_to_context import (
IAttachDifferentObjectToOpContext as IAttachDifferentObjectToOpContext,
)
from .conversion_utils import _convert_pydantic_field, _is_pydantic_field_required, safe_is_subclass
from .conversion_utils import _convert_pydantic_field, safe_is_subclass
from .typing_utils import BaseConfigMeta

try:
Expand Down Expand Up @@ -163,6 +163,9 @@ def __init__(self, **config_dict) -> None:
modified_data = {}
for key, value in config_dict.items():
field = self.__fields__.get(key)
if field and not field.required and value is None:
continue

if field and field.field_info.discriminator:
nested_dict = value

Expand All @@ -186,6 +189,11 @@ def __init__(self, **config_dict) -> None:
}
else:
modified_data[key] = value

for key, field in self.__fields__.items():
if field.required and key not in modified_data:
modified_data[key] = None

super().__init__(**modified_data)

def _convert_to_config_dictionary(self) -> Mapping[str, Any]:
Expand Down Expand Up @@ -213,11 +221,10 @@ def _get_non_none_public_field_values(self) -> Mapping[str, Any]:
if self._is_field_internal(key):
continue
field = self.__fields__.get(key)
if field and value is None and not _is_pydantic_field_required(field):
continue

if field:
output[field.alias] = value
alias = field.alias
output[alias] = value
else:
output[key] = value
return output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
from enum import Enum
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Type,
TypeVar,
)

from pydantic import ConstrainedFloat, ConstrainedInt, ConstrainedStr
from typing_extensions import TypeAlias
from typing_extensions import Annotated, get_args, get_origin

from dagster import (
Enum as DagsterEnum,
)
from dagster._config.config_type import (
Array,
ConfigFloatInstance,
ConfigType,
Noneable,
)
Expand Down Expand Up @@ -45,10 +47,6 @@ class cached_property:


from pydantic.fields import (
SHAPE_DICT,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SINGLETON,
ModelField,
)

Expand All @@ -60,7 +58,7 @@ class cached_property:
convert_potential_field,
)

from .inheritance_utils import safe_is_subclass
from .type_check_utils import is_optional, safe_is_subclass


# This is from https://github.com/dagster-io/dagster/pull/11470
Expand Down Expand Up @@ -114,60 +112,20 @@ def _curry_config_schema(schema_field: Field, data: Any) -> DefinitionConfigSche
TResValue = TypeVar("TResValue")


PydanticShapeType: TypeAlias = int

MAPPING_TYPES = {SHAPE_MAPPING, SHAPE_DICT}
MAPPING_KEY_TYPE_TO_SCALAR = {
StringSource: str,
IntSource: int,
BoolSource: bool,
ConfigFloatInstance: float,
}


def _wrap_config_type(
shape_type: PydanticShapeType,
key_type: Optional[ConfigType],
config_type: ConfigType,
) -> ConfigType:
"""Based on a Pydantic shape type, wraps a config type in the appropriate Dagster config wrapper.
For example, if the shape type is a Pydantic list, the config type will be wrapped in an Array.
"""
if shape_type == SHAPE_SINGLETON:
return config_type
elif shape_type == SHAPE_LIST:
return Array(config_type)
elif shape_type in MAPPING_TYPES:
if key_type not in MAPPING_KEY_TYPE_TO_SCALAR:
raise NotImplementedError(
f"Pydantic shape type is a mapping, but key type {key_type} is not a valid "
"Map key type. Valid Map key types are: "
f"{', '.join([str(t) for t in MAPPING_KEY_TYPE_TO_SCALAR.keys()])}."
)
return Map(MAPPING_KEY_TYPE_TO_SCALAR[key_type], config_type)
else:
raise NotImplementedError(f"Pydantic shape type {shape_type} not supported.")


def _get_inner_field_if_exists(
shape_type: PydanticShapeType, field: ModelField
) -> Optional[ModelField]:
def _get_inner_field_if_exists(shape_type: Type, field: ModelField) -> Optional[ModelField]:
"""Grabs the inner Pydantic field type for a data structure such as a list or dictionary.
Returns None for types which have no inner field.
"""
# See https://github.com/pydantic/pydantic/blob/v1.10.3/pydantic/fields.py#L758 for
# where sub_fields is set.
if shape_type == SHAPE_SINGLETON:
return None
elif shape_type == SHAPE_LIST:
# List has a single subfield, which is the type of the list elements.
return check.not_none(field.sub_fields)[0]
elif shape_type in MAPPING_TYPES:
# Mapping has a single subfield, which is the type of the mapping values.
return check.not_none(field.sub_fields)[0]
else:
raise NotImplementedError(f"Pydantic shape type {shape_type} not supported.")
if safe_is_subclass(get_origin(shape_type), list):
return check.not_none(get_args(shape_type))[0]
if safe_is_subclass(get_origin(shape_type), dict) or safe_is_subclass(
get_origin(shape_type), Mapping
):
return check.not_none(get_args(shape_type))[1]
return None


def _convert_pydantic_field(pydantic_field: ModelField, model_cls: Optional[Type] = None) -> Field:
Expand All @@ -181,49 +139,26 @@ def _convert_pydantic_field(pydantic_field: ModelField, model_cls: Optional[Type
"""
from .config import Config, infer_schema_from_config_class

key_type = (
_config_type_for_pydantic_field(pydantic_field.key_field)
if pydantic_field.key_field
else None
)
if pydantic_field.field_info.discriminator:
return _convert_pydantic_descriminated_union_field(pydantic_field)

if safe_is_subclass(pydantic_field.type_, Config):
field_type = pydantic_field.annotation
if safe_is_subclass(field_type, Config):
inferred_field = infer_schema_from_config_class(
pydantic_field.type_,
field_type,
description=pydantic_field.field_info.description,
)
wrapped_config_type = _wrap_config_type(
shape_type=pydantic_field.shape,
config_type=inferred_field.config_type,
key_type=key_type,
)
return Field(
config=(
Noneable(wrapped_config_type) if pydantic_field.allow_none else wrapped_config_type
),
description=inferred_field.description,
is_required=_is_pydantic_field_required(pydantic_field),
)
return inferred_field
else:
# For certain data structure types, we need to grab the inner Pydantic field (e.g. List type)
inner_field = _get_inner_field_if_exists(pydantic_field.shape, pydantic_field)
if inner_field:
config_type = _convert_pydantic_field(inner_field, model_cls=model_cls).config_type
else:
config_type = _config_type_for_pydantic_field(pydantic_field)

wrapped_config_type = _wrap_config_type(
shape_type=pydantic_field.shape, config_type=config_type, key_type=key_type
)
if not pydantic_field.required and not is_optional(field_type):
field_type = Optional[field_type]

config_type = _config_type_for_type_on_pydantic_field(field_type)

return Field(
config=(
Noneable(wrapped_config_type) if pydantic_field.allow_none else wrapped_config_type
),
config=config_type,
description=pydantic_field.field_info.description,
is_required=_is_pydantic_field_required(pydantic_field),
is_required=pydantic_field.required and not is_optional(field_type),
default_value=(
pydantic_field.default
if pydantic_field.default is not None
Expand All @@ -232,17 +167,6 @@ def _convert_pydantic_field(pydantic_field: ModelField, model_cls: Optional[Type
)


def _config_type_for_pydantic_field(pydantic_field: ModelField) -> ConfigType:
"""Generates a Dagster ConfigType from a Pydantic field.
Args:
pydantic_field (ModelField): The Pydantic field to convert.
"""
return _config_type_for_type_on_pydantic_field(
pydantic_field.type_,
)


def _config_type_for_type_on_pydantic_field(
potential_dagster_type: Any,
) -> ConfigType:
Expand All @@ -251,6 +175,9 @@ def _config_type_for_type_on_pydantic_field(
Args:
potential_dagster_type (Any): The Python type of the Pydantic field.
"""
while get_origin(potential_dagster_type) == Annotated:
potential_dagster_type = get_args(potential_dagster_type)[0]

# special case pydantic constrained types to their source equivalents
if safe_is_subclass(potential_dagster_type, ConstrainedStr):
return StringSource
Expand All @@ -260,6 +187,31 @@ def _config_type_for_type_on_pydantic_field(
elif safe_is_subclass(potential_dagster_type, ConstrainedInt):
return IntSource

if safe_is_subclass(get_origin(potential_dagster_type), List):
list_inner_type = get_args(potential_dagster_type)[0]
return Array(_config_type_for_type_on_pydantic_field(list_inner_type))
elif is_optional(potential_dagster_type):
optional_inner_type = next(
arg for arg in get_args(potential_dagster_type) if arg is not type(None)
)
return Noneable(_config_type_for_type_on_pydantic_field(optional_inner_type))
elif safe_is_subclass(get_origin(potential_dagster_type), Dict) or safe_is_subclass(
get_origin(potential_dagster_type), Mapping
):
key_type, value_type = get_args(potential_dagster_type)
return Map(
key_type,
_config_type_for_type_on_pydantic_field(value_type),
)

from .config import Config, infer_schema_from_config_class

if safe_is_subclass(potential_dagster_type, Config):
inferred_field = infer_schema_from_config_class(
potential_dagster_type,
)
return inferred_field.config_type

if safe_is_subclass(potential_dagster_type, Enum):
return DagsterEnum.from_python_enum_direct_values(potential_dagster_type)

Expand All @@ -274,19 +226,6 @@ def _config_type_for_type_on_pydantic_field(
return convert_potential_field(potential_dagster_type).config_type


def _is_pydantic_field_required(pydantic_field: ModelField) -> bool:
# required is of type BoolUndefined = Union[bool, UndefinedType] in Pydantic

if isinstance(pydantic_field.required, bool):
return pydantic_field.required

raise Exception(
"pydantic.field.required is their UndefinedType sentinel value which we "
"do not fully understand the semantics of right now. For the time being going "
"to throw an error to figure see when we actually encounter this state."
)


def _convert_pydantic_descriminated_union_field(pydantic_field: ModelField) -> Field:
"""Builds a Selector config field from a Pydantic field which is a descriminated union.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from .config import Config
from .conversion_utils import TResValue
from .inheritance_utils import safe_is_subclass
from .resource import (
AllowDelayedDependencies,
ConfigurableResourceFactory,
Expand All @@ -39,6 +38,7 @@
ResourceWithKeyMapping,
Self,
)
from .type_check_utils import safe_is_subclass

try:
from functools import cached_property # type: ignore # (py37 compat)
Expand Down
Loading

0 comments on commit b6e711d

Please sign in to comment.