From 8be2bba2e57922260e04428903458488f17eea83 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Wed, 15 Nov 2023 17:21:08 -0800 Subject: [PATCH] enable serializing dicts keyed by asset key --- .../dagster/_core/definitions/asset_key.py | 189 ++++++++++++++++++ .../dagster/_core/definitions/events.py | 163 +-------------- .../dagster/dagster/_serdes/serdes.py | 22 +- .../general_tests/test_serdes.py | 38 ++++ 4 files changed, 249 insertions(+), 163 deletions(-) create mode 100644 python_modules/dagster/dagster/_core/definitions/asset_key.py diff --git a/python_modules/dagster/dagster/_core/definitions/asset_key.py b/python_modules/dagster/dagster/_core/definitions/asset_key.py new file mode 100644 index 0000000000000..9c597cf35bf32 --- /dev/null +++ b/python_modules/dagster/dagster/_core/definitions/asset_key.py @@ -0,0 +1,189 @@ +import re +from typing import ( + TYPE_CHECKING, + Mapping, + NamedTuple, + Optional, + Sequence, + Union, +) + +import dagster._check as check +import dagster._seven as seven +from dagster._annotations import PublicAttr +from dagster._serdes import whitelist_for_serdes + +if TYPE_CHECKING: + from dagster._core.definitions.assets import AssetsDefinition + from dagster._core.definitions.source_asset import SourceAsset + + +ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]") +ASSET_KEY_DELIMITER = "/" + + +def parse_asset_key_string(s: str) -> Sequence[str]: + return list(filter(lambda x: x, re.split(ASSET_KEY_SPLIT_REGEX, s))) + + +@whitelist_for_serdes +class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): + """Object representing the structure of an asset key. Takes in a sanitized string, list of + strings, or tuple of strings. + + Example usage: + + .. code-block:: python + + from dagster import op + + @op + def emit_metadata(context, df): + yield AssetMaterialization( + asset_key=AssetKey('flat_asset_key'), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + @op + def structured_asset_key(context, df): + yield AssetMaterialization( + asset_key=AssetKey(['parent', 'child', 'grandchild']), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + @op + def structured_asset_key_2(context, df): + yield AssetMaterialization( + asset_key=AssetKey(('parent', 'child', 'grandchild')), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + Args: + path (Sequence[str]): String, list of strings, or tuple of strings. A list of strings + represent the hierarchical structure of the asset_key. + """ + + def __new__(cls, path: Sequence[str]): + if isinstance(path, str): + path = [path] + else: + path = list(check.sequence_param(path, "path", of_type=str)) + + return super(AssetKey, cls).__new__(cls, path=path) + + def __str__(self): + return f"AssetKey({self.path})" + + def __repr__(self): + return f"AssetKey({self.path})" + + def __hash__(self): + return hash(tuple(self.path)) + + def __eq__(self, other): + if not isinstance(other, AssetKey): + return False + if len(self.path) != len(other.path): + return False + for i in range(0, len(self.path)): + if self.path[i] != other.path[i]: + return False + return True + + def to_string(self) -> str: + """E.g. '["first_component", "second_component"]'.""" + return seven.json.dumps(self.path) + + def to_user_string(self) -> str: + """E.g. "first_component/second_component".""" + return ASSET_KEY_DELIMITER.join(self.path) + + def to_python_identifier(self, suffix: Optional[str] = None) -> str: + """Build a valid Python identifier based on the asset key that can be used for + operation names or I/O manager keys. + """ + path = list(self.path) + + if suffix is not None: + path.append(suffix) + + return "__".join(path).replace("-", "_") + + @staticmethod + def from_user_string(asset_key_string: str) -> "AssetKey": + return AssetKey(asset_key_string.split(ASSET_KEY_DELIMITER)) + + @staticmethod + def from_db_string(asset_key_string: Optional[str]) -> Optional["AssetKey"]: + if not asset_key_string: + return None + if asset_key_string[0] == "[": + # is a json string + try: + path = seven.json.loads(asset_key_string) + except seven.JSONDecodeError: + path = parse_asset_key_string(asset_key_string) + else: + path = parse_asset_key_string(asset_key_string) + return AssetKey(path) + + @staticmethod + def get_db_prefix(path: Sequence[str]): + check.sequence_param(path, "path", of_type=str) + return seven.json.dumps(path)[:-2] # strip trailing '"]' from json string + + @staticmethod + def from_graphql_input(graphql_input_asset_key: Mapping[str, Sequence[str]]) -> "AssetKey": + return AssetKey(graphql_input_asset_key["path"]) + + def to_graphql_input(self) -> Mapping[str, Sequence[str]]: + return {"path": self.path} + + @staticmethod + def from_coercible(arg: "CoercibleToAssetKey") -> "AssetKey": + if isinstance(arg, AssetKey): + return check.inst_param(arg, "arg", AssetKey) + elif isinstance(arg, str): + return AssetKey([arg]) + elif isinstance(arg, list): + check.list_param(arg, "arg", of_type=str) + return AssetKey(arg) + elif isinstance(arg, tuple): + check.tuple_param(arg, "arg", of_type=str) + return AssetKey(arg) + else: + check.failed(f"Unexpected type for AssetKey: {type(arg)}") + + @staticmethod + def from_coercible_or_definition( + arg: Union["CoercibleToAssetKey", "AssetsDefinition", "SourceAsset"] + ) -> "AssetKey": + from dagster._core.definitions.assets import AssetsDefinition + from dagster._core.definitions.source_asset import SourceAsset + + if isinstance(arg, AssetsDefinition): + return arg.key + elif isinstance(arg, SourceAsset): + return arg.key + else: + return AssetKey.from_coercible(arg) + + def has_prefix(self, prefix: Sequence[str]) -> bool: + return len(self.path) >= len(prefix) and self.path[: len(prefix)] == prefix + + def with_prefix(self, prefix: "CoercibleToAssetKeyPrefix") -> "AssetKey": + prefix = key_prefix_from_coercible(prefix) + return AssetKey(list(prefix) + list(self.path)) + + +CoercibleToAssetKey = Union[AssetKey, str, Sequence[str]] +CoercibleToAssetKeyPrefix = Union[str, Sequence[str]] + + +def key_prefix_from_coercible(key_prefix: CoercibleToAssetKeyPrefix) -> Sequence[str]: + if isinstance(key_prefix, str): + return [key_prefix] + elif isinstance(key_prefix, list): + return key_prefix + else: + check.failed(f"Unexpected type for key_prefix: {type(key_prefix)}") diff --git a/python_modules/dagster/dagster/_core/definitions/events.py b/python_modules/dagster/dagster/_core/definitions/events.py index 43e4f35db6888..1856e07b3b640 100644 --- a/python_modules/dagster/dagster/_core/definitions/events.py +++ b/python_modules/dagster/dagster/_core/definitions/events.py @@ -1,4 +1,3 @@ -import re from enum import Enum from typing import ( TYPE_CHECKING, @@ -17,13 +16,13 @@ ) import dagster._check as check -import dagster._seven as seven from dagster._annotations import PublicAttr, deprecated, experimental_param, public from dagster._core.definitions.data_version import DATA_VERSION_TAG, DataVersion from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX, SYSTEM_TAG_PREFIX from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import NamedTupleSerializer +from .asset_key import AssetKey, parse_asset_key_string from .metadata import ( MetadataFieldSerializer, MetadataMapping, @@ -34,169 +33,9 @@ from .utils import DEFAULT_OUTPUT, check_valid_name if TYPE_CHECKING: - from dagster._core.definitions.assets import AssetsDefinition - from dagster._core.definitions.source_asset import SourceAsset from dagster._core.execution.context.output import OutputContext -ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]") -ASSET_KEY_DELIMITER = "/" - - -def parse_asset_key_string(s: str) -> Sequence[str]: - return list(filter(lambda x: x, re.split(ASSET_KEY_SPLIT_REGEX, s))) - - -@whitelist_for_serdes -class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): - """Object representing the structure of an asset key. Takes in a sanitized string, list of - strings, or tuple of strings. - - Example usage: - - .. code-block:: python - - from dagster import op - - @op - def emit_metadata(context, df): - yield AssetMaterialization( - asset_key=AssetKey('flat_asset_key'), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - @op - def structured_asset_key(context, df): - yield AssetMaterialization( - asset_key=AssetKey(['parent', 'child', 'grandchild']), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - @op - def structured_asset_key_2(context, df): - yield AssetMaterialization( - asset_key=AssetKey(('parent', 'child', 'grandchild')), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - Args: - path (Sequence[str]): String, list of strings, or tuple of strings. A list of strings - represent the hierarchical structure of the asset_key. - """ - - def __new__(cls, path: Sequence[str]): - if isinstance(path, str): - path = [path] - else: - path = list(check.sequence_param(path, "path", of_type=str)) - - return super(AssetKey, cls).__new__(cls, path=path) - - def __str__(self): - return f"AssetKey({self.path})" - - def __repr__(self): - return f"AssetKey({self.path})" - - def __hash__(self): - return hash(tuple(self.path)) - - def __eq__(self, other): - if not isinstance(other, AssetKey): - return False - if len(self.path) != len(other.path): - return False - for i in range(0, len(self.path)): - if self.path[i] != other.path[i]: - return False - return True - - def to_string(self) -> str: - """E.g. '["first_component", "second_component"]'.""" - return seven.json.dumps(self.path) - - def to_user_string(self) -> str: - """E.g. "first_component/second_component".""" - return ASSET_KEY_DELIMITER.join(self.path) - - def to_python_identifier(self, suffix: Optional[str] = None) -> str: - """Build a valid Python identifier based on the asset key that can be used for - operation names or I/O manager keys. - """ - path = list(self.path) - - if suffix is not None: - path.append(suffix) - - return "__".join(path).replace("-", "_") - - @staticmethod - def from_user_string(asset_key_string: str) -> "AssetKey": - return AssetKey(asset_key_string.split(ASSET_KEY_DELIMITER)) - - @staticmethod - def from_db_string(asset_key_string: Optional[str]) -> Optional["AssetKey"]: - if not asset_key_string: - return None - if asset_key_string[0] == "[": - # is a json string - try: - path = seven.json.loads(asset_key_string) - except seven.JSONDecodeError: - path = parse_asset_key_string(asset_key_string) - else: - path = parse_asset_key_string(asset_key_string) - return AssetKey(path) - - @staticmethod - def get_db_prefix(path: Sequence[str]): - check.sequence_param(path, "path", of_type=str) - return seven.json.dumps(path)[:-2] # strip trailing '"]' from json string - - @staticmethod - def from_graphql_input(graphql_input_asset_key: Mapping[str, Sequence[str]]) -> "AssetKey": - return AssetKey(graphql_input_asset_key["path"]) - - def to_graphql_input(self) -> Mapping[str, Sequence[str]]: - return {"path": self.path} - - @staticmethod - def from_coercible(arg: "CoercibleToAssetKey") -> "AssetKey": - if isinstance(arg, AssetKey): - return check.inst_param(arg, "arg", AssetKey) - elif isinstance(arg, str): - return AssetKey([arg]) - elif isinstance(arg, list): - check.list_param(arg, "arg", of_type=str) - return AssetKey(arg) - elif isinstance(arg, tuple): - check.tuple_param(arg, "arg", of_type=str) - return AssetKey(arg) - else: - check.failed(f"Unexpected type for AssetKey: {type(arg)}") - - @staticmethod - def from_coercible_or_definition( - arg: Union["CoercibleToAssetKey", "AssetsDefinition", "SourceAsset"] - ) -> "AssetKey": - from dagster._core.definitions.assets import AssetsDefinition - from dagster._core.definitions.source_asset import SourceAsset - - if isinstance(arg, AssetsDefinition): - return arg.key - elif isinstance(arg, SourceAsset): - return arg.key - else: - return AssetKey.from_coercible(arg) - - def has_prefix(self, prefix: Sequence[str]) -> bool: - return len(self.path) >= len(prefix) and self.path[: len(prefix)] == prefix - - def with_prefix(self, prefix: "CoercibleToAssetKeyPrefix") -> "AssetKey": - prefix = key_prefix_from_coercible(prefix) - return AssetKey(list(prefix) + list(self.path)) - - class AssetKeyPartitionKey(NamedTuple): """An AssetKey with an (optional) partition key. Refers either to a non-partitioned asset or a partition of a partitioned asset. diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index a7cbb9e7b25b4..c53f832efc09a 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -682,9 +682,21 @@ def _pack_value( for idx, item in enumerate(cast(list, val)) ] if tval is dict: + from dagster._core.definitions.asset_key import AssetKey + + dict_val = cast(dict, val) + + if dict_val and all(type(key) is AssetKey for key in dict_val.keys()): + return { + "__dict_keyed_by_asset_key_": { + key.to_user_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}") + for key, value in dict_val.items() + } + } + return { key: _pack_value(value, whitelist_map, f"{descent_path}.{key}") - for key, value in cast(dict, val).items() + for key, value in dict_val.items() } # inlined is_named_tuple_instance @@ -856,6 +868,14 @@ def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContex items = cast(List[JsonSerializableValue], val["__frozenset__"]) return frozenset(items) + if "__dict_keyed_by_asset_key_" in val: + from dagster._core.events import AssetKey + + return { + AssetKey.from_user_string(key): _unpack_value(value, whitelist_map, context) + for key, value in cast(dict, val["__dict_keyed_by_asset_key_"]).items() + } + return val diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index 1f84aa2505c12..36ab50540ebfe 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -6,6 +6,7 @@ import pytest from dagster._check import ParameterCheckError, inst_param, set_param +from dagster._core.events import AssetKey from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError from dagster._serdes.serdes import ( EnumSerializer, @@ -721,3 +722,40 @@ class Foo(Enum): assert serialized == '{"__enum__": "Foo.BLUE"}' deserialized = deserialize_value(serialized, whitelist_map=test_env) assert deserialized == Foo.RED + + +def test_serialize_dict_keyed_by_asset_key(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class Fizz(NamedTuple): + buzz: int + + mapping = { + AssetKey(["a", "a_2"]): Fizz(1), + AssetKey("b"): 1, + AssetKey("c"): {AssetKey("d"): "1"}, + } + + serialized = serialize_value(mapping, whitelist_map=test_env) + assert ( + serialized + == '{"__dict_keyed_by_asset_key_": {"a/a_2": {"__class__": "Fizz", "buzz": 1}, "b": 1, "c": {"__dict_keyed_by_asset_key_": {"d": "1"}}}}' + ) + assert deserialize_value(serialized, whitelist_map=test_env) == mapping + + +def test_serialize_mapping_keyed_by_asset_key(): + test_env = WhitelistMap.create() + + @_whitelist_for_serdes(test_env) + class Foo(NamedTuple): + keyed_by_asset_key: Mapping[AssetKey, int] + + named_tuple = Foo(keyed_by_asset_key={AssetKey(["a", "a_2"]): 1}) + assert ( + deserialize_value( + serialize_value(named_tuple, whitelist_map=test_env), whitelist_map=test_env + ) + == named_tuple + )