Skip to content

Commit

Permalink
amend to use custom AssetKeyMap type
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 20, 2023
1 parent 9a23533 commit 656cd32
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
51 changes: 37 additions & 14 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from functools import partial
from inspect import Parameter, signature
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
FrozenSet,
Generic,
Iterator,
List,
Mapping,
NamedTuple,
Expand All @@ -47,6 +49,9 @@

from .errors import DeserializationError, SerdesUsageError, SerializationError

if TYPE_CHECKING:
from dagster._core.definitions.asset_key import AssetKey

###################################################################################################
# Types
###################################################################################################
Expand Down Expand Up @@ -90,6 +95,28 @@
"UnknownSerdesValue",
]

_V = TypeVar("_V")


class AssetKeyMap(Mapping["AssetKey", _V]):
def __init__(self, mapping: Mapping["AssetKey", _V] = {}) -> None:
from dagster._core.definitions.asset_key import AssetKey

check.mapping_param(mapping, "mapping", key_type=AssetKey)
self.mapping: Mapping[AssetKey, _V] = mapping

def __setitem__(self, key, item):
raise NotImplementedError("AssetKeyMap is immutable")

def __getitem__(self, item: "AssetKey") -> _V:
return self.mapping[item]

def __len__(self) -> int:
return len(self.mapping)

def __iter__(self) -> Iterator["AssetKey"]:
return iter(self.mapping)


###################################################################################################
# Whitelisting
Expand Down Expand Up @@ -681,19 +708,15 @@ def _pack_value(
_pack_value(item, whitelist_map, f"{descent_path}[{idx}]")
for idx, item in enumerate(cast(list, val))
]
if tval is AssetKeyMap:
return {
"__asset_key_map__": {
key.to_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}")
for key, value in cast(dict, val).items()
}
}
if tval is dict:
from dagster._core.definitions.asset_key import AssetKey

dict_val = cast(dict, val)

if dict_val and all(type(key) is AssetKey for key in dict_val.keys()):
return {
"__dict_keyed_by_asset_key_": {
key.to_user_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}")
for key, value in dict_val.items()
}
}

return {
key: _pack_value(value, whitelist_map, f"{descent_path}.{key}")
for key, value in dict_val.items()
Expand Down Expand Up @@ -868,12 +891,12 @@ def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContex
items = cast(List[JsonSerializableValue], val["__frozenset__"])
return frozenset(items)

if "__dict_keyed_by_asset_key_" in val:
if "__asset_key_map__" in val:
from dagster._core.events import AssetKey

return {
AssetKey.from_user_string(key): _unpack_value(value, whitelist_map, context)
for key, value in cast(dict, val["__dict_keyed_by_asset_key_"]).items()
AssetKey.from_db_string(key): _unpack_value(value, whitelist_map, context)
for key, value in cast(dict, val["__asset_key_map__"]).items()
}

return val
Expand Down
42 changes: 23 additions & 19 deletions python_modules/dagster/dagster_tests/general_tests/test_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import AbstractSet, Any, Dict, Mapping, NamedTuple, Optional, Sequence

import pytest
from dagster._check import ParameterCheckError, inst_param, set_param
from dagster._check import CheckError, ParameterCheckError, inst_param, set_param
from dagster._core.events import AssetKey
from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError
from dagster._serdes.serdes import (
AssetKeyMap,
EnumSerializer,
FieldSerializer,
NamedTupleSerializer,
Expand Down Expand Up @@ -724,33 +725,36 @@ class Foo(Enum):
assert deserialized == Foo.RED


def test_serialize_dict_keyed_by_asset_key():
def test_serialize_mapping_keyed_by_asset_key():
test_env = WhitelistMap.create()
asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1})

@_whitelist_for_serdes(test_env)
class Fizz(NamedTuple):
buzz: int
serialized = serialize_value(asset_key_map, whitelist_map=test_env)
assert serialized == '{"__asset_key_map__": {"[\\"a\\", \\"a_2\\"]": 1}}'
assert asset_key_map == deserialize_value(serialized, whitelist_map=test_env)

mapping = {
AssetKey(["a", "a_2"]): Fizz(1),
AssetKey("b"): 1,
AssetKey("c"): {AssetKey("d"): "1"},
}

serialized = serialize_value(mapping, whitelist_map=test_env)
assert (
serialized
== '{"__dict_keyed_by_asset_key_": {"a/a_2": {"__class__": "Fizz", "buzz": 1}, "b": 1, "c": {"__dict_keyed_by_asset_key_": {"d": "1"}}}}'
)
assert deserialize_value(serialized, whitelist_map=test_env) == mapping
def test_asset_key_map():
asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1})

assert len(asset_key_map) == 1
assert asset_key_map[AssetKey(["a", "a_2"])] == 1
assert list(iter(asset_key_map)) == list(iter([AssetKey(["a", "a_2"])]))

def test_serialize_mapping_keyed_by_asset_key():
with pytest.raises(NotImplementedError, match="AssetKeyMap is immutable"):
asset_key_map["foo"] = None

with pytest.raises(CheckError, match="Key in Mapping mismatches type"):
AssetKeyMap({"a": 1})


def test_mapping_keyed_by_asset_key_in_named_tuple():
test_env = WhitelistMap.create()

@_whitelist_for_serdes(test_env)
class Foo(NamedTuple):
keyed_by_asset_key: Mapping[AssetKey, int]
class Foo(NamedTuple("_Foo", [("keyed_by_asset_key", Mapping[AssetKey, int])])):
def __new__(cls, keyed_by_asset_key):
return super(Foo, cls).__new__(cls, AssetKeyMap(keyed_by_asset_key))

named_tuple = Foo(keyed_by_asset_key={AssetKey(["a", "a_2"]): 1})
assert (
Expand Down

0 comments on commit 656cd32

Please sign in to comment.