From 4c102b0666988ede37352a8a72996edbe2429c8a Mon Sep 17 00:00:00 2001 From: benpankow Date: Sun, 1 Oct 2023 16:35:04 -0700 Subject: [PATCH] pyright --- .../_config/pythonic_config/conversion_utils.py | 15 ++++++++------- .../pythonic_config/pydantic_compat_layer.py | 16 ++++++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py b/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py index ee2e4e3860d56..b5afbd4f934ba 100644 --- a/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py +++ b/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py @@ -268,6 +268,7 @@ class OpConfigWithUnion(Config): from .config import Config, infer_schema_from_config_class field_type = pydantic_field.annotation + discriminator = pydantic_field.discriminator if pydantic_field.discriminator else None if not get_origin(field_type) == Union: raise DagsterInvalidDefinitionError("Discriminated union must be a Union type.") @@ -277,11 +278,13 @@ class OpConfigWithUnion(Config): raise NotImplementedError("Discriminated unions with non-Config types are not supported.") sub_fields_mapping = {} - for sub_field in sub_fields: - sub_field_annotation = model_fields(sub_field)[pydantic_field.discriminator].annotation + if discriminator: + for sub_field in sub_fields: + sub_field_annotation = model_fields(sub_field)[discriminator].annotation + + for sub_field_key in get_args(sub_field_annotation): + sub_fields_mapping[sub_field_key] = sub_field - for sub_field_key in get_args(sub_field_annotation): - sub_fields_mapping[sub_field_key] = sub_field # First, we generate a mapping between the various discriminator values and the # Dagster config fields that correspond to them. We strip the discriminator key # from the fields, since the user should not have to specify it. @@ -289,9 +292,7 @@ class OpConfigWithUnion(Config): dagster_config_field_mapping = { discriminator_value: infer_schema_from_config_class( field, - fields_to_omit=( - {pydantic_field.discriminator} if pydantic_field.discriminator else None - ), + fields_to_omit=({discriminator} if discriminator else None), ) for discriminator_value, field in sub_fields_mapping.items() } diff --git a/python_modules/dagster/dagster/_config/pythonic_config/pydantic_compat_layer.py b/python_modules/dagster/dagster/_config/pythonic_config/pydantic_compat_layer.py index 5aff1977c63bb..d9786740848c9 100644 --- a/python_modules/dagster/dagster/_config/pythonic_config/pydantic_compat_layer.py +++ b/python_modules/dagster/dagster/_config/pythonic_config/pydantic_compat_layer.py @@ -1,4 +1,5 @@ from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -13,6 +14,9 @@ IAttachDifferentObjectToOpContext as IAttachDifferentObjectToOpContext, ) +if TYPE_CHECKING: + from pydantic.fields import ModelField + USING_PYDANTIC_2 = int(pydantic.__version__.split(".")[0]) >= 2 @@ -21,8 +25,8 @@ class ModelFieldCompat: metadata and annotations between Pydantic 1 and 2. """ - def __init__(self, field): - self.field = field + def __init__(self, field) -> None: + self.field: "ModelField" = field @property def annotation(self) -> Type: @@ -50,16 +54,16 @@ def description(self) -> Optional[str]: def is_required(self) -> bool: if USING_PYDANTIC_2: - if hasattr(self.field, "is_required"): - return self.field.is_required() + return self.field.is_required() # type: ignore else: - return self.field.required + # required is of type 'BoolUndefined', which is a Union of bool and pydantic 1.x's UndefinedType + return self.field.required if isinstance(self.field.required, bool) else False @property def discriminator(self) -> Optional[str]: if USING_PYDANTIC_2: if hasattr(self.field, "discriminator"): - return self.field.discriminator if hasattr(self.field, "discriminator") else None + return self.field.discriminator if hasattr(self.field, "discriminator") else None # type: ignore else: return getattr(self.field, "discriminator_key", None)