diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index fecb4f07b609c..ecfd86739522a 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -672,11 +672,26 @@ def constructor_param_names(self) -> Sequence[str]: class PydanticModelSerializer(ObjectSerializer[T_PydanticModel]): def object_as_mapping(self, value: T_PydanticModel) -> Mapping[str, Any]: - return value.__dict__ + value_dict = value.__dict__ + + result = {} + for key, field in self.klass.__fields__.items(): + if field.alias is None and ( + getattr(field, "serialization_alias", None) is not None + or getattr(field, "validation_alias", None) is not None + ): + raise SerializationError( + "Can't serialize pydantic models with serialization or validation aliases. Use " + "the storage_field_names argument to whitelist_for_serdes instead." + ) + result_key = field.alias if field.alias else key + result[result_key] = value_dict[key] + + return result @property def constructor_param_names(self) -> Sequence[str]: - return list(self.klass.__fields__.keys()) + return [field.alias or key for key, field in self.klass.__fields__.items()] class FieldSerializer(Serializer): diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index c4f08d2b3d7d5..e48b4cf0ad236 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -25,6 +25,7 @@ unpack_value, ) from dagster._serdes.utils import hash_str +from pydantic import Field def test_deserialize_value_ok(): @@ -890,3 +891,84 @@ class MyEnt(pydantic.BaseModel): # can deserialize previous NamedTuples in to future pydantic models py_dc_ent = deserialize_value(ser_nt_ent, whitelist_map=py_m_env) assert py_dc_ent + + +def test_pydantic_alias(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class SomeDagsterModel(DagsterModel): + unaliased_id: int = Field(..., alias="id_alias") + name: str + + o = SomeDagsterModel(id_alias=5, name="fdsk") + packed_o = pack_value(o, whitelist_map=test_env) + assert packed_o == {"__class__": "SomeDagsterModel", "id_alias": 5, "name": "fdsk"} + assert unpack_value(packed_o, whitelist_map=test_env, as_type=SomeDagsterModel) == o + + ser_o = serialize_value(o, whitelist_map=test_env) + assert deserialize_value(ser_o, whitelist_map=test_env) == o + + +def test_pydantic_serialization_alias(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class SomeDagsterModel(DagsterModel): + unaliased_id: int = Field(..., serialization_alias="id_alias") + name: str + + o = SomeDagsterModel(unaliased_id=5, name="fdsk") + with pytest.raises( + SerializationError, + match="Can't serialize pydantic models with serialization or validation aliases.", + ): + serialize_value(o, whitelist_map=test_env) + + with pytest.raises( + SerializationError, + match="Can't serialize pydantic models with serialization or validation aliases.", + ): + pack_value(o, whitelist_map=test_env) + + +def test_pydantic_validation_alias(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class SomeDagsterModel(DagsterModel): + unaliased_id: int = Field(..., validation_alias="id_alias") + name: str + + o = SomeDagsterModel(id_alias=5, name="fdsk") + with pytest.raises( + SerializationError, + match="Can't serialize pydantic models with serialization or validation aliases.", + ): + serialize_value(o, whitelist_map=test_env) + + with pytest.raises( + SerializationError, + match="Can't serialize pydantic models with serialization or validation aliases.", + ): + pack_value(o, whitelist_map=test_env) + + +def test_pydantic_alias_generator(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class SomeDagsterModel(DagsterModel): + id: int = Field(...) + name: str + + class Config: + alias_generator = lambda field_name: f"{field_name}_alias" + + o = SomeDagsterModel(id_alias=5, name_alias="fdsk") + packed_o = pack_value(o, whitelist_map=test_env) + assert packed_o == {"__class__": "SomeDagsterModel", "id_alias": 5, "name_alias": "fdsk"} + assert unpack_value(packed_o, whitelist_map=test_env, as_type=SomeDagsterModel) == o + + ser_o = serialize_value(o, whitelist_map=test_env) + assert deserialize_value(ser_o, whitelist_map=test_env) == o