From 656cd320570e72b3361f2d8bd69fde9c206a17b6 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Mon, 20 Nov 2023 15:28:26 -0800 Subject: [PATCH] amend to use custom AssetKeyMap type --- .../dagster/dagster/_serdes/serdes.py | 51 ++++++++++++++----- .../general_tests/test_serdes.py | 42 ++++++++------- 2 files changed, 60 insertions(+), 33 deletions(-) diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index c53f832efc09a..a83bbd94f1f4a 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -18,12 +18,14 @@ from functools import partial from inspect import Parameter, signature from typing import ( + TYPE_CHECKING, AbstractSet, Any, Callable, Dict, FrozenSet, Generic, + Iterator, List, Mapping, NamedTuple, @@ -47,6 +49,9 @@ from .errors import DeserializationError, SerdesUsageError, SerializationError +if TYPE_CHECKING: + from dagster._core.definitions.asset_key import AssetKey + ################################################################################################### # Types ################################################################################################### @@ -90,6 +95,28 @@ "UnknownSerdesValue", ] +_V = TypeVar("_V") + + +class AssetKeyMap(Mapping["AssetKey", _V]): + def __init__(self, mapping: Mapping["AssetKey", _V] = {}) -> None: + from dagster._core.definitions.asset_key import AssetKey + + check.mapping_param(mapping, "mapping", key_type=AssetKey) + self.mapping: Mapping[AssetKey, _V] = mapping + + def __setitem__(self, key, item): + raise NotImplementedError("AssetKeyMap is immutable") + + def __getitem__(self, item: "AssetKey") -> _V: + return self.mapping[item] + + def __len__(self) -> int: + return len(self.mapping) + + def __iter__(self) -> Iterator["AssetKey"]: + return iter(self.mapping) + ################################################################################################### # Whitelisting @@ -681,19 +708,15 @@ def _pack_value( _pack_value(item, whitelist_map, f"{descent_path}[{idx}]") for idx, item in enumerate(cast(list, val)) ] + if tval is AssetKeyMap: + return { + "__asset_key_map__": { + key.to_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}") + for key, value in cast(dict, val).items() + } + } if tval is dict: - from dagster._core.definitions.asset_key import AssetKey - dict_val = cast(dict, val) - - if dict_val and all(type(key) is AssetKey for key in dict_val.keys()): - return { - "__dict_keyed_by_asset_key_": { - key.to_user_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}") - for key, value in dict_val.items() - } - } - return { key: _pack_value(value, whitelist_map, f"{descent_path}.{key}") for key, value in dict_val.items() @@ -868,12 +891,12 @@ def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContex items = cast(List[JsonSerializableValue], val["__frozenset__"]) return frozenset(items) - if "__dict_keyed_by_asset_key_" in val: + if "__asset_key_map__" in val: from dagster._core.events import AssetKey return { - AssetKey.from_user_string(key): _unpack_value(value, whitelist_map, context) - for key, value in cast(dict, val["__dict_keyed_by_asset_key_"]).items() + AssetKey.from_db_string(key): _unpack_value(value, whitelist_map, context) + for key, value in cast(dict, val["__asset_key_map__"]).items() } return val diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index 36ab50540ebfe..6d747c1168667 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -5,10 +5,11 @@ from typing import AbstractSet, Any, Dict, Mapping, NamedTuple, Optional, Sequence import pytest -from dagster._check import ParameterCheckError, inst_param, set_param +from dagster._check import CheckError, ParameterCheckError, inst_param, set_param from dagster._core.events import AssetKey from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError from dagster._serdes.serdes import ( + AssetKeyMap, EnumSerializer, FieldSerializer, NamedTupleSerializer, @@ -724,33 +725,36 @@ class Foo(Enum): assert deserialized == Foo.RED -def test_serialize_dict_keyed_by_asset_key(): +def test_serialize_mapping_keyed_by_asset_key(): test_env = WhitelistMap.create() + asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1}) - @_whitelist_for_serdes(test_env) - class Fizz(NamedTuple): - buzz: int + serialized = serialize_value(asset_key_map, whitelist_map=test_env) + assert serialized == '{"__asset_key_map__": {"[\\"a\\", \\"a_2\\"]": 1}}' + assert asset_key_map == deserialize_value(serialized, whitelist_map=test_env) - mapping = { - AssetKey(["a", "a_2"]): Fizz(1), - AssetKey("b"): 1, - AssetKey("c"): {AssetKey("d"): "1"}, - } - serialized = serialize_value(mapping, whitelist_map=test_env) - assert ( - serialized - == '{"__dict_keyed_by_asset_key_": {"a/a_2": {"__class__": "Fizz", "buzz": 1}, "b": 1, "c": {"__dict_keyed_by_asset_key_": {"d": "1"}}}}' - ) - assert deserialize_value(serialized, whitelist_map=test_env) == mapping +def test_asset_key_map(): + asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1}) + assert len(asset_key_map) == 1 + assert asset_key_map[AssetKey(["a", "a_2"])] == 1 + assert list(iter(asset_key_map)) == list(iter([AssetKey(["a", "a_2"])])) -def test_serialize_mapping_keyed_by_asset_key(): + with pytest.raises(NotImplementedError, match="AssetKeyMap is immutable"): + asset_key_map["foo"] = None + + with pytest.raises(CheckError, match="Key in Mapping mismatches type"): + AssetKeyMap({"a": 1}) + + +def test_mapping_keyed_by_asset_key_in_named_tuple(): test_env = WhitelistMap.create() @_whitelist_for_serdes(test_env) - class Foo(NamedTuple): - keyed_by_asset_key: Mapping[AssetKey, int] + class Foo(NamedTuple("_Foo", [("keyed_by_asset_key", Mapping[AssetKey, int])])): + def __new__(cls, keyed_by_asset_key): + return super(Foo, cls).__new__(cls, AssetKeyMap(keyed_by_asset_key)) named_tuple = Foo(keyed_by_asset_key={AssetKey(["a", "a_2"]): 1}) assert (