Skip to content

Commit

Permalink
pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Oct 10, 2023
1 parent 9c22aa6 commit 35f9b86
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class OpConfigWithUnion(Config):
from .config import Config, infer_schema_from_config_class

field_type = pydantic_field.annotation
discriminator = pydantic_field.discriminator if pydantic_field.discriminator else None

if not get_origin(field_type) == Union:
raise DagsterInvalidDefinitionError("Discriminated union must be a Union type.")
Expand All @@ -277,21 +278,21 @@ class OpConfigWithUnion(Config):
raise NotImplementedError("Discriminated unions with non-Config types are not supported.")

sub_fields_mapping = {}
for sub_field in sub_fields:
sub_field_annotation = model_fields(sub_field)[pydantic_field.discriminator].annotation
if discriminator:
for sub_field in sub_fields:
sub_field_annotation = model_fields(sub_field)[discriminator].annotation

for sub_field_key in get_args(sub_field_annotation):
sub_fields_mapping[sub_field_key] = sub_field

for sub_field_key in get_args(sub_field_annotation):
sub_fields_mapping[sub_field_key] = sub_field
# First, we generate a mapping between the various discriminator values and the
# Dagster config fields that correspond to them. We strip the discriminator key
# from the fields, since the user should not have to specify it.

dagster_config_field_mapping = {
discriminator_value: infer_schema_from_config_class(
field,
fields_to_omit=(
{pydantic_field.discriminator} if pydantic_field.discriminator else None
),
fields_to_omit=({discriminator} if discriminator else None),
)
for discriminator_value, field in sub_fields_mapping.items()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Expand All @@ -13,6 +14,9 @@
IAttachDifferentObjectToOpContext as IAttachDifferentObjectToOpContext,
)

if TYPE_CHECKING:
from pydantic.fields import ModelField

USING_PYDANTIC_2 = int(pydantic.__version__.split(".")[0]) >= 2


Expand All @@ -21,8 +25,8 @@ class ModelFieldCompat:
metadata and annotations between Pydantic 1 and 2.
"""

def __init__(self, field):
self.field = field
def __init__(self, field) -> None:
self.field: "ModelField" = field

@property
def annotation(self) -> Type:
Expand Down Expand Up @@ -50,16 +54,16 @@ def description(self) -> Optional[str]:

def is_required(self) -> bool:
if USING_PYDANTIC_2:
if hasattr(self.field, "is_required"):
return self.field.is_required()
return self.field.is_required() # type: ignore
else:
return self.field.required
# required is of type 'BoolUndefined', which is a Union of bool and pydantic 1.x's UndefinedType
return self.field.required if isinstance(self.field.required, bool) else False

@property
def discriminator(self) -> Optional[str]:
if USING_PYDANTIC_2:
if hasattr(self.field, "discriminator"):
return self.field.discriminator if hasattr(self.field, "discriminator") else None
return self.field.discriminator if hasattr(self.field, "discriminator") else None # type: ignore
else:
return getattr(self.field, "discriminator_key", None)

Expand Down

0 comments on commit 35f9b86

Please sign in to comment.