From 793517c23c8c46102be99472707c35e57eb91aef Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Tue, 21 Nov 2023 11:41:40 -0800 Subject: [PATCH] generalize for non-scalar values --- .../dagster/_core/definitions/asset_key.py | 189 ------------------ .../dagster/_core/definitions/events.py | 163 ++++++++++++++- .../dagster/dagster/_serdes/serdes.py | 49 +++-- .../general_tests/test_serdes.py | 60 ++++-- 4 files changed, 224 insertions(+), 237 deletions(-) delete mode 100644 python_modules/dagster/dagster/_core/definitions/asset_key.py diff --git a/python_modules/dagster/dagster/_core/definitions/asset_key.py b/python_modules/dagster/dagster/_core/definitions/asset_key.py deleted file mode 100644 index 9c597cf35bf32..0000000000000 --- a/python_modules/dagster/dagster/_core/definitions/asset_key.py +++ /dev/null @@ -1,189 +0,0 @@ -import re -from typing import ( - TYPE_CHECKING, - Mapping, - NamedTuple, - Optional, - Sequence, - Union, -) - -import dagster._check as check -import dagster._seven as seven -from dagster._annotations import PublicAttr -from dagster._serdes import whitelist_for_serdes - -if TYPE_CHECKING: - from dagster._core.definitions.assets import AssetsDefinition - from dagster._core.definitions.source_asset import SourceAsset - - -ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]") -ASSET_KEY_DELIMITER = "/" - - -def parse_asset_key_string(s: str) -> Sequence[str]: - return list(filter(lambda x: x, re.split(ASSET_KEY_SPLIT_REGEX, s))) - - -@whitelist_for_serdes -class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): - """Object representing the structure of an asset key. Takes in a sanitized string, list of - strings, or tuple of strings. - - Example usage: - - .. code-block:: python - - from dagster import op - - @op - def emit_metadata(context, df): - yield AssetMaterialization( - asset_key=AssetKey('flat_asset_key'), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - @op - def structured_asset_key(context, df): - yield AssetMaterialization( - asset_key=AssetKey(['parent', 'child', 'grandchild']), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - @op - def structured_asset_key_2(context, df): - yield AssetMaterialization( - asset_key=AssetKey(('parent', 'child', 'grandchild')), - metadata={"text_metadata": "Text-based metadata for this event"}, - ) - - Args: - path (Sequence[str]): String, list of strings, or tuple of strings. A list of strings - represent the hierarchical structure of the asset_key. - """ - - def __new__(cls, path: Sequence[str]): - if isinstance(path, str): - path = [path] - else: - path = list(check.sequence_param(path, "path", of_type=str)) - - return super(AssetKey, cls).__new__(cls, path=path) - - def __str__(self): - return f"AssetKey({self.path})" - - def __repr__(self): - return f"AssetKey({self.path})" - - def __hash__(self): - return hash(tuple(self.path)) - - def __eq__(self, other): - if not isinstance(other, AssetKey): - return False - if len(self.path) != len(other.path): - return False - for i in range(0, len(self.path)): - if self.path[i] != other.path[i]: - return False - return True - - def to_string(self) -> str: - """E.g. '["first_component", "second_component"]'.""" - return seven.json.dumps(self.path) - - def to_user_string(self) -> str: - """E.g. "first_component/second_component".""" - return ASSET_KEY_DELIMITER.join(self.path) - - def to_python_identifier(self, suffix: Optional[str] = None) -> str: - """Build a valid Python identifier based on the asset key that can be used for - operation names or I/O manager keys. - """ - path = list(self.path) - - if suffix is not None: - path.append(suffix) - - return "__".join(path).replace("-", "_") - - @staticmethod - def from_user_string(asset_key_string: str) -> "AssetKey": - return AssetKey(asset_key_string.split(ASSET_KEY_DELIMITER)) - - @staticmethod - def from_db_string(asset_key_string: Optional[str]) -> Optional["AssetKey"]: - if not asset_key_string: - return None - if asset_key_string[0] == "[": - # is a json string - try: - path = seven.json.loads(asset_key_string) - except seven.JSONDecodeError: - path = parse_asset_key_string(asset_key_string) - else: - path = parse_asset_key_string(asset_key_string) - return AssetKey(path) - - @staticmethod - def get_db_prefix(path: Sequence[str]): - check.sequence_param(path, "path", of_type=str) - return seven.json.dumps(path)[:-2] # strip trailing '"]' from json string - - @staticmethod - def from_graphql_input(graphql_input_asset_key: Mapping[str, Sequence[str]]) -> "AssetKey": - return AssetKey(graphql_input_asset_key["path"]) - - def to_graphql_input(self) -> Mapping[str, Sequence[str]]: - return {"path": self.path} - - @staticmethod - def from_coercible(arg: "CoercibleToAssetKey") -> "AssetKey": - if isinstance(arg, AssetKey): - return check.inst_param(arg, "arg", AssetKey) - elif isinstance(arg, str): - return AssetKey([arg]) - elif isinstance(arg, list): - check.list_param(arg, "arg", of_type=str) - return AssetKey(arg) - elif isinstance(arg, tuple): - check.tuple_param(arg, "arg", of_type=str) - return AssetKey(arg) - else: - check.failed(f"Unexpected type for AssetKey: {type(arg)}") - - @staticmethod - def from_coercible_or_definition( - arg: Union["CoercibleToAssetKey", "AssetsDefinition", "SourceAsset"] - ) -> "AssetKey": - from dagster._core.definitions.assets import AssetsDefinition - from dagster._core.definitions.source_asset import SourceAsset - - if isinstance(arg, AssetsDefinition): - return arg.key - elif isinstance(arg, SourceAsset): - return arg.key - else: - return AssetKey.from_coercible(arg) - - def has_prefix(self, prefix: Sequence[str]) -> bool: - return len(self.path) >= len(prefix) and self.path[: len(prefix)] == prefix - - def with_prefix(self, prefix: "CoercibleToAssetKeyPrefix") -> "AssetKey": - prefix = key_prefix_from_coercible(prefix) - return AssetKey(list(prefix) + list(self.path)) - - -CoercibleToAssetKey = Union[AssetKey, str, Sequence[str]] -CoercibleToAssetKeyPrefix = Union[str, Sequence[str]] - - -def key_prefix_from_coercible(key_prefix: CoercibleToAssetKeyPrefix) -> Sequence[str]: - if isinstance(key_prefix, str): - return [key_prefix] - elif isinstance(key_prefix, list): - return key_prefix - else: - check.failed(f"Unexpected type for key_prefix: {type(key_prefix)}") diff --git a/python_modules/dagster/dagster/_core/definitions/events.py b/python_modules/dagster/dagster/_core/definitions/events.py index 1856e07b3b640..43e4f35db6888 100644 --- a/python_modules/dagster/dagster/_core/definitions/events.py +++ b/python_modules/dagster/dagster/_core/definitions/events.py @@ -1,3 +1,4 @@ +import re from enum import Enum from typing import ( TYPE_CHECKING, @@ -16,13 +17,13 @@ ) import dagster._check as check +import dagster._seven as seven from dagster._annotations import PublicAttr, deprecated, experimental_param, public from dagster._core.definitions.data_version import DATA_VERSION_TAG, DataVersion from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX, SYSTEM_TAG_PREFIX from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import NamedTupleSerializer -from .asset_key import AssetKey, parse_asset_key_string from .metadata import ( MetadataFieldSerializer, MetadataMapping, @@ -33,9 +34,169 @@ from .utils import DEFAULT_OUTPUT, check_valid_name if TYPE_CHECKING: + from dagster._core.definitions.assets import AssetsDefinition + from dagster._core.definitions.source_asset import SourceAsset from dagster._core.execution.context.output import OutputContext +ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]") +ASSET_KEY_DELIMITER = "/" + + +def parse_asset_key_string(s: str) -> Sequence[str]: + return list(filter(lambda x: x, re.split(ASSET_KEY_SPLIT_REGEX, s))) + + +@whitelist_for_serdes +class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): + """Object representing the structure of an asset key. Takes in a sanitized string, list of + strings, or tuple of strings. + + Example usage: + + .. code-block:: python + + from dagster import op + + @op + def emit_metadata(context, df): + yield AssetMaterialization( + asset_key=AssetKey('flat_asset_key'), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + @op + def structured_asset_key(context, df): + yield AssetMaterialization( + asset_key=AssetKey(['parent', 'child', 'grandchild']), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + @op + def structured_asset_key_2(context, df): + yield AssetMaterialization( + asset_key=AssetKey(('parent', 'child', 'grandchild')), + metadata={"text_metadata": "Text-based metadata for this event"}, + ) + + Args: + path (Sequence[str]): String, list of strings, or tuple of strings. A list of strings + represent the hierarchical structure of the asset_key. + """ + + def __new__(cls, path: Sequence[str]): + if isinstance(path, str): + path = [path] + else: + path = list(check.sequence_param(path, "path", of_type=str)) + + return super(AssetKey, cls).__new__(cls, path=path) + + def __str__(self): + return f"AssetKey({self.path})" + + def __repr__(self): + return f"AssetKey({self.path})" + + def __hash__(self): + return hash(tuple(self.path)) + + def __eq__(self, other): + if not isinstance(other, AssetKey): + return False + if len(self.path) != len(other.path): + return False + for i in range(0, len(self.path)): + if self.path[i] != other.path[i]: + return False + return True + + def to_string(self) -> str: + """E.g. '["first_component", "second_component"]'.""" + return seven.json.dumps(self.path) + + def to_user_string(self) -> str: + """E.g. "first_component/second_component".""" + return ASSET_KEY_DELIMITER.join(self.path) + + def to_python_identifier(self, suffix: Optional[str] = None) -> str: + """Build a valid Python identifier based on the asset key that can be used for + operation names or I/O manager keys. + """ + path = list(self.path) + + if suffix is not None: + path.append(suffix) + + return "__".join(path).replace("-", "_") + + @staticmethod + def from_user_string(asset_key_string: str) -> "AssetKey": + return AssetKey(asset_key_string.split(ASSET_KEY_DELIMITER)) + + @staticmethod + def from_db_string(asset_key_string: Optional[str]) -> Optional["AssetKey"]: + if not asset_key_string: + return None + if asset_key_string[0] == "[": + # is a json string + try: + path = seven.json.loads(asset_key_string) + except seven.JSONDecodeError: + path = parse_asset_key_string(asset_key_string) + else: + path = parse_asset_key_string(asset_key_string) + return AssetKey(path) + + @staticmethod + def get_db_prefix(path: Sequence[str]): + check.sequence_param(path, "path", of_type=str) + return seven.json.dumps(path)[:-2] # strip trailing '"]' from json string + + @staticmethod + def from_graphql_input(graphql_input_asset_key: Mapping[str, Sequence[str]]) -> "AssetKey": + return AssetKey(graphql_input_asset_key["path"]) + + def to_graphql_input(self) -> Mapping[str, Sequence[str]]: + return {"path": self.path} + + @staticmethod + def from_coercible(arg: "CoercibleToAssetKey") -> "AssetKey": + if isinstance(arg, AssetKey): + return check.inst_param(arg, "arg", AssetKey) + elif isinstance(arg, str): + return AssetKey([arg]) + elif isinstance(arg, list): + check.list_param(arg, "arg", of_type=str) + return AssetKey(arg) + elif isinstance(arg, tuple): + check.tuple_param(arg, "arg", of_type=str) + return AssetKey(arg) + else: + check.failed(f"Unexpected type for AssetKey: {type(arg)}") + + @staticmethod + def from_coercible_or_definition( + arg: Union["CoercibleToAssetKey", "AssetsDefinition", "SourceAsset"] + ) -> "AssetKey": + from dagster._core.definitions.assets import AssetsDefinition + from dagster._core.definitions.source_asset import SourceAsset + + if isinstance(arg, AssetsDefinition): + return arg.key + elif isinstance(arg, SourceAsset): + return arg.key + else: + return AssetKey.from_coercible(arg) + + def has_prefix(self, prefix: Sequence[str]) -> bool: + return len(self.path) >= len(prefix) and self.path[: len(prefix)] == prefix + + def with_prefix(self, prefix: "CoercibleToAssetKeyPrefix") -> "AssetKey": + prefix = key_prefix_from_coercible(prefix) + return AssetKey(list(prefix) + list(self.path)) + + class AssetKeyPartitionKey(NamedTuple): """An AssetKey with an (optional) partition key. Refers either to a non-partitioned asset or a partition of a partitioned asset. diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index a83bbd94f1f4a..1a70a40fe2e75 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -18,7 +18,6 @@ from functools import partial from inspect import Parameter, signature from typing import ( - TYPE_CHECKING, AbstractSet, Any, Callable, @@ -49,9 +48,6 @@ from .errors import DeserializationError, SerdesUsageError, SerializationError -if TYPE_CHECKING: - from dagster._core.definitions.asset_key import AssetKey - ################################################################################################### # Types ################################################################################################### @@ -95,26 +91,28 @@ "UnknownSerdesValue", ] +_K = TypeVar("_K") _V = TypeVar("_V") -class AssetKeyMap(Mapping["AssetKey", _V]): - def __init__(self, mapping: Mapping["AssetKey", _V] = {}) -> None: - from dagster._core.definitions.asset_key import AssetKey +class SerializableNonScalarKeyMapping(Mapping[_K, _V]): + """Wrapper class for non-scalar key mappings, used to performantly type check when serializing + without having to access types of specific keys. + """ - check.mapping_param(mapping, "mapping", key_type=AssetKey) - self.mapping: Mapping[AssetKey, _V] = mapping + def __init__(self, mapping: Mapping[_K, _V] = {}) -> None: + self.mapping: Mapping[_K, _V] = mapping - def __setitem__(self, key, item): - raise NotImplementedError("AssetKeyMap is immutable") + def __setitem__(self, key: _K, item: _V): + raise NotImplementedError("SerializableNonScalarKeyMapping is immutable") - def __getitem__(self, item: "AssetKey") -> _V: + def __getitem__(self, item: _K) -> _V: return self.mapping[item] def __len__(self) -> int: return len(self.mapping) - def __iter__(self) -> Iterator["AssetKey"]: + def __iter__(self) -> Iterator[_K]: return iter(self.mapping) @@ -708,19 +706,22 @@ def _pack_value( _pack_value(item, whitelist_map, f"{descent_path}[{idx}]") for idx, item in enumerate(cast(list, val)) ] - if tval is AssetKeyMap: - return { - "__asset_key_map__": { - key.to_string(): _pack_value(value, whitelist_map, f"{descent_path}.{key}") - for key, value in cast(dict, val).items() - } - } if tval is dict: dict_val = cast(dict, val) return { key: _pack_value(value, whitelist_map, f"{descent_path}.{key}") for key, value in dict_val.items() } + if tval is SerializableNonScalarKeyMapping: + return { + "__non_scalar_key_mapping_items__": [ + [ + _pack_value(k, whitelist_map, f"{descent_path}.{k}"), + _pack_value(v, whitelist_map, f"{descent_path}.{k}"), + ] + for k, v in cast(dict, val).items() + ] + } # inlined is_named_tuple_instance if isinstance(val, tuple) and hasattr(val, "_fields"): @@ -891,12 +892,10 @@ def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContex items = cast(List[JsonSerializableValue], val["__frozenset__"]) return frozenset(items) - if "__asset_key_map__" in val: - from dagster._core.events import AssetKey - + if "__non_scalar_key_mapping_items__" in val: return { - AssetKey.from_db_string(key): _unpack_value(value, whitelist_map, context) - for key, value in cast(dict, val["__asset_key_map__"]).items() + _unpack_value(k, whitelist_map, context): _unpack_value(v, whitelist_map, context) + for k, v in val["__non_scalar_key_mapping_items__"] } return val diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index 6d747c1168667..7f7f91a51acd0 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -5,14 +5,13 @@ from typing import AbstractSet, Any, Dict, Mapping, NamedTuple, Optional, Sequence import pytest -from dagster._check import CheckError, ParameterCheckError, inst_param, set_param -from dagster._core.events import AssetKey +from dagster._check import ParameterCheckError, inst_param, set_param from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError from dagster._serdes.serdes import ( - AssetKeyMap, EnumSerializer, FieldSerializer, NamedTupleSerializer, + SerializableNonScalarKeyMapping, SetToSequenceFieldSerializer, UnpackContext, WhitelistMap, @@ -725,38 +724,55 @@ class Foo(Enum): assert deserialized == Foo.RED -def test_serialize_mapping_keyed_by_asset_key(): +def test_serialize_non_scalar_key_mapping(): test_env = WhitelistMap.create() - asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1}) - serialized = serialize_value(asset_key_map, whitelist_map=test_env) - assert serialized == '{"__asset_key_map__": {"[\\"a\\", \\"a_2\\"]": 1}}' - assert asset_key_map == deserialize_value(serialized, whitelist_map=test_env) + @_whitelist_for_serdes(whitelist_map=test_env) + class Bar(NamedTuple): + color: str + non_scalar_key_mapping = SerializableNonScalarKeyMapping({Bar("red"): 1}) -def test_asset_key_map(): - asset_key_map = AssetKeyMap({AssetKey(["a", "a_2"]): 1}) + serialized = serialize_value(non_scalar_key_mapping, whitelist_map=test_env) + assert ( + serialized + == """{"__non_scalar_key_mapping_items__": [[{"__class__": "Bar", "color": "red"}, 1]]}""" + ) + assert non_scalar_key_mapping == deserialize_value(serialized, whitelist_map=test_env) - assert len(asset_key_map) == 1 - assert asset_key_map[AssetKey(["a", "a_2"])] == 1 - assert list(iter(asset_key_map)) == list(iter([AssetKey(["a", "a_2"])])) - with pytest.raises(NotImplementedError, match="AssetKeyMap is immutable"): - asset_key_map["foo"] = None +def test_serializable_non_scalar_key_mapping(): + test_env = WhitelistMap.create() - with pytest.raises(CheckError, match="Key in Mapping mismatches type"): - AssetKeyMap({"a": 1}) + @_whitelist_for_serdes(test_env) + class Bar(NamedTuple): + color: str + non_scalar_key_mapping = SerializableNonScalarKeyMapping({Bar("red"): 1}) -def test_mapping_keyed_by_asset_key_in_named_tuple(): + assert len(non_scalar_key_mapping) == 1 + assert non_scalar_key_mapping[Bar("red")] == 1 + assert list(iter(non_scalar_key_mapping)) == list(iter([Bar("red")])) + + with pytest.raises(NotImplementedError, match="SerializableNonScalarKeyMapping is immutable"): + non_scalar_key_mapping["foo"] = None + + +def test_serializable_non_scalar_key_mapping_in_named_tuple(): test_env = WhitelistMap.create() @_whitelist_for_serdes(test_env) - class Foo(NamedTuple("_Foo", [("keyed_by_asset_key", Mapping[AssetKey, int])])): - def __new__(cls, keyed_by_asset_key): - return super(Foo, cls).__new__(cls, AssetKeyMap(keyed_by_asset_key)) + class Bar(NamedTuple): + color: str + + @_whitelist_for_serdes(test_env) + class Foo(NamedTuple("_Foo", [("keyed_by_non_scalar", Mapping[Bar, int])])): + def __new__(cls, keyed_by_non_scalar): + return super(Foo, cls).__new__( + cls, SerializableNonScalarKeyMapping(keyed_by_non_scalar) + ) - named_tuple = Foo(keyed_by_asset_key={AssetKey(["a", "a_2"]): 1}) + named_tuple = Foo(keyed_by_non_scalar={Bar("red"): 1}) assert ( deserialize_value( serialize_value(named_tuple, whitelist_map=test_env), whitelist_map=test_env