Skip to content

Commit

Permalink
use field aliases when serializing pydantic objects (dagster-io#21325)
Browse files Browse the repository at this point in the history
## Summary & Motivation

["An alias is an alternative name for a field, used when serializing and
deserializing data."](https://docs.pydantic.dev/latest/concepts/alias/)

The motivation for this change is to allow `MetadataValue` subclasses to be
migrated to `DagsterModel`s, as is done in
dagster-io#21324.

## How I Tested These Changes
  • Loading branch information
sryza authored and nikomancy committed May 1, 2024
1 parent 2935842 commit 9ad13c3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
19 changes: 17 additions & 2 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,26 @@ def constructor_param_names(self) -> Sequence[str]:

class PydanticModelSerializer(ObjectSerializer[T_PydanticModel]):
    """Object serializer for pydantic models that honors field aliases.

    A field that declares a pydantic ``alias`` is stored under that alias; a
    field without one is stored under its attribute name. The same names are
    reported by ``constructor_param_names`` so that the two stay in agreement.
    """

    def object_as_mapping(self, value: T_PydanticModel) -> Mapping[str, Any]:
        """Return the field values of ``value`` keyed by alias (or field name).

        Args:
            value: The pydantic model instance to convert.

        Returns:
            A new dict with one entry per field in ``self.klass.__fields__``.

        Raises:
            SerializationError: If a field has no ``alias`` but declares a
                ``serialization_alias`` or a ``validation_alias``.
        """
        value_dict = value.__dict__

        result = {}
        for key, field in self.klass.__fields__.items():
            # getattr with a default is used for the two extra alias kinds;
            # presumably they are absent on some pydantic versions - TODO confirm.
            if field.alias is None and (
                getattr(field, "serialization_alias", None) is not None
                or getattr(field, "validation_alias", None) is not None
            ):
                raise SerializationError(
                    "Can't serialize pydantic models with serialization or validation aliases. Use "
                    "the storage_field_names argument to whitelist_for_serdes instead."
                )
            # Same "alias, else field name" rule as constructor_param_names.
            result[field.alias or key] = value_dict[key]

        return result

    @property
    def constructor_param_names(self) -> Sequence[str]:
        """Names of the model's fields, using the alias wherever one is set."""
        return [field.alias or key for key, field in self.klass.__fields__.items()]


class FieldSerializer(Serializer):
Expand Down
82 changes: 82 additions & 0 deletions python_modules/dagster/dagster_tests/general_tests/test_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
unpack_value,
)
from dagster._serdes.utils import hash_str
from pydantic import Field


def test_deserialize_value_ok():
Expand Down Expand Up @@ -890,3 +891,84 @@ class MyEnt(pydantic.BaseModel):
# can deserialize previous NamedTuples in to future pydantic models
py_dc_ent = deserialize_value(ser_nt_ent, whitelist_map=py_m_env)
assert py_dc_ent


def test_pydantic_alias():
    """A field declared with ``alias`` is packed under that alias and round-trips."""
    whitelist = WhitelistMap.create()

    @_whitelist_for_serdes(whitelist)
    class SomeDagsterModel(DagsterModel):
        unaliased_id: int = Field(..., alias="id_alias")
        name: str

    model = SomeDagsterModel(id_alias=5, name="fdsk")

    # Packed form is keyed by the alias, not by the attribute name.
    packed = pack_value(model, whitelist_map=whitelist)
    expected_packed = {"__class__": "SomeDagsterModel", "id_alias": 5, "name": "fdsk"}
    assert packed == expected_packed
    assert unpack_value(packed, whitelist_map=whitelist, as_type=SomeDagsterModel) == model

    # Full serialize/deserialize round trip yields an equal model.
    serialized = serialize_value(model, whitelist_map=whitelist)
    assert deserialize_value(serialized, whitelist_map=whitelist) == model


def test_pydantic_serialization_alias():
    """Serializing or packing a model with a ``serialization_alias`` field raises."""
    whitelist = WhitelistMap.create()

    @_whitelist_for_serdes(whitelist)
    class SomeDagsterModel(DagsterModel):
        unaliased_id: int = Field(..., serialization_alias="id_alias")
        name: str

    model = SomeDagsterModel(unaliased_id=5, name="fdsk")
    expected_message = "Can't serialize pydantic models with serialization or validation aliases."

    # Both entry points are exercised, in the same order as before.
    for to_storage in (serialize_value, pack_value):
        with pytest.raises(SerializationError, match=expected_message):
            to_storage(model, whitelist_map=whitelist)


def test_pydantic_validation_alias():
    """Serializing or packing a model with a ``validation_alias`` field raises."""
    whitelist = WhitelistMap.create()

    @_whitelist_for_serdes(whitelist)
    class SomeDagsterModel(DagsterModel):
        unaliased_id: int = Field(..., validation_alias="id_alias")
        name: str

    model = SomeDagsterModel(id_alias=5, name="fdsk")
    expected_message = "Can't serialize pydantic models with serialization or validation aliases."

    # Both entry points are exercised, in the same order as before.
    for to_storage in (serialize_value, pack_value):
        with pytest.raises(SerializationError, match=expected_message):
            to_storage(model, whitelist_map=whitelist)


def test_pydantic_alias_generator():
    """Aliases produced by a config-level ``alias_generator`` are used as storage keys."""
    whitelist = WhitelistMap.create()

    @_whitelist_for_serdes(whitelist)
    class SomeDagsterModel(DagsterModel):
        id: int = Field(...)
        name: str

        class Config:
            alias_generator = lambda field_name: f"{field_name}_alias"

    model = SomeDagsterModel(id_alias=5, name_alias="fdsk")

    # Every field is packed under its generated alias.
    packed = pack_value(model, whitelist_map=whitelist)
    expected_packed = {"__class__": "SomeDagsterModel", "id_alias": 5, "name_alias": "fdsk"}
    assert packed == expected_packed
    assert unpack_value(packed, whitelist_map=whitelist, as_type=SomeDagsterModel) == model

    # Full serialize/deserialize round trip yields an equal model.
    serialized = serialize_value(model, whitelist_map=whitelist)
    assert deserialize_value(serialized, whitelist_map=whitelist) == model

0 comments on commit 9ad13c3

Please sign in to comment.