Skip to content

Commit

Permalink
Implement _resolve_back_compat_method
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Jan 17, 2025
1 parent 24a469a commit 71ad256
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_load_from_path(sling_path: Path) -> None:
assert len(components) == 1
assert get_asset_keys(components[0]) == {
AssetKey("input_csv"),
AssetKey(["input_duckdb"]),
AssetKey(["foo", "input_duckdb"]),
}

assert_assets(components[0], 2)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Callable, Optional

from dagster import AssetKey, AssetSpec, AutoMaterializePolicy, FreshnessPolicy, MetadataValue
from dagster._annotations import public
Expand All @@ -20,17 +20,50 @@ def get_asset_spec(self, stream_definition: Mapping[str, Any]) -> AssetSpec:
the value is a dictionary representing the Sling Replication Stream Config.
"""
return AssetSpec(
key=self._default_asset_key_fn(stream_definition),
deps=self._default_deps_fn(stream_definition),
description=self._default_description_fn(stream_definition),
metadata=self._default_metadata_fn(stream_definition),
tags=self._default_tags_fn(stream_definition),
kinds=self._default_kinds_fn(stream_definition),
group_name=self._default_group_name_fn(stream_definition),
freshness_policy=self._default_freshness_policy_fn(stream_definition),
auto_materialize_policy=self._default_auto_materialize_policy_fn(stream_definition),
key=self._resolve_back_compat_method(
"get_asset_key", self._default_asset_key_fn, stream_definition
),
deps=self._resolve_back_compat_method(
"get_deps_asset_key", self._default_deps_fn, stream_definition
),
description=self._resolve_back_compat_method(
"get_description", self._default_description_fn, stream_definition
),
metadata=self._resolve_back_compat_method(
"get_metadata", self._default_metadata_fn, stream_definition
),
tags=self._resolve_back_compat_method(
"get_tags", self._default_tags_fn, stream_definition
),
kinds=self._resolve_back_compat_method(
"get_kinds", self._default_kinds_fn, stream_definition
),
group_name=self._resolve_back_compat_method(
"get_group_name", self._default_group_name_fn, stream_definition
),
freshness_policy=self._resolve_back_compat_method(
"get_freshness_policy", self._default_freshness_policy_fn, stream_definition
),
auto_materialize_policy=self._resolve_back_compat_method(
"get_auto_materialize_policy",
self._default_auto_materialize_policy_fn,
stream_definition,
),
)

def _resolve_back_compat_method(
self,
method_name: str,
default_fn: Callable[[Mapping[str, Any]], Any],
stream_definition: Mapping[str, Any],
):
method = getattr(type(self), method_name)
base_method = getattr(DagsterSlingTranslator, method_name)
if method is not base_method: # user defined this
return method(self, stream_definition)
else:
return default_fn(stream_definition)

@public
def sanitize_stream_name(self, stream_name: str) -> str:
"""A function that takes a stream name from a Sling replication config and returns a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,29 @@ class CustomSlingTranslator(DagsterSlingTranslator):
def get_asset_spec(self, stream_definition: Mapping[str, Any]) -> AssetSpec:
default_spec = super().get_asset_spec(stream_definition)
return default_spec.replace_attributes(
kinds=["sling", "foo"], tags={"custom_tag": "custom_value"}
kinds={"sling", "foo"}, tags={"custom_tag": "custom_value"}
)

@sling_assets(
replication_config=replication_config_path,
dagster_sling_translator=CustomSlingTranslator(),
)
def my_sling_assets(): ...

for asset_key in my_sling_assets.keys:
assert my_sling_assets.tags_by_key[asset_key] == {
"custom_tag": "custom_value",
**build_kind_tag("sling"),
**build_kind_tag("foo"),
}


def test_base_with_custom_tags_translator_legacy() -> None:
replication_config_path = file_relative_path(
__file__, "replication_configs/base_with_default_meta/replication.yaml"
)

class CustomSlingTranslator(DagsterSlingTranslator):
def get_tags(self, stream_definition):
return {"custom_tag": "custom_value"}

Expand Down

0 comments on commit 71ad256

Please sign in to comment.