Skip to content

Commit

Permalink
[dagster-airbyte] Add connection_meta_to_group_fn argument
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Nov 29, 2023
1 parent 6187f7c commit 3ca6b6d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def __init__(
self,
key_prefix: Sequence[str],
create_assets_for_normalization_tables: bool,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]],
connection_meta_to_group_fn: Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]],
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]],
connection_filter: Optional[Callable[[AirbyteConnectionMetadata], bool]],
connection_to_asset_key_fn: Optional[Callable[[AirbyteConnectionMetadata, str], AssetKey]],
Expand All @@ -534,7 +534,7 @@ def __init__(
):
self._key_prefix = key_prefix
self._create_assets_for_normalization_tables = create_assets_for_normalization_tables
self._connection_to_group_fn = connection_to_group_fn
self._connection_meta_to_group_fn = connection_meta_to_group_fn
self._connection_to_io_manager_key_fn = connection_to_io_manager_key_fn
self._connection_filter = connection_filter
self._connection_to_asset_key_fn: Callable[[AirbyteConnectionMetadata, str], AssetKey] = (
Expand Down Expand Up @@ -577,8 +577,8 @@ def compute_cacheable_data(self) -> Sequence[AssetsDefinitionCacheableData]:
},
asset_key_prefix=self._key_prefix,
group_name=(
self._connection_to_group_fn(connection.name)
if self._connection_to_group_fn
self._connection_meta_to_group_fn(connection)
if self._connection_meta_to_group_fn
else None
),
io_manager_key=(
Expand Down Expand Up @@ -616,7 +616,7 @@ def __init__(
workspace_id: Optional[str],
key_prefix: Sequence[str],
create_assets_for_normalization_tables: bool,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]],
connection_meta_to_group_fn: Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]],
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]],
connection_filter: Optional[Callable[[AirbyteConnectionMetadata], bool]],
connection_to_asset_key_fn: Optional[Callable[[AirbyteConnectionMetadata, str], AssetKey]],
Expand All @@ -630,7 +630,7 @@ def __init__(
super().__init__(
key_prefix=key_prefix,
create_assets_for_normalization_tables=create_assets_for_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connection_filter=connection_filter,
connection_to_asset_key_fn=connection_to_asset_key_fn,
Expand Down Expand Up @@ -715,7 +715,7 @@ def __init__(
workspace_id: Optional[str],
key_prefix: Sequence[str],
create_assets_for_normalization_tables: bool,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]],
connection_meta_to_group_fn: Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]],
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]],
connection_filter: Optional[Callable[[AirbyteConnectionMetadata], bool]],
connection_directories: Optional[Sequence[str]],
Expand All @@ -730,7 +730,7 @@ def __init__(
super().__init__(
key_prefix=key_prefix,
create_assets_for_normalization_tables=create_assets_for_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connection_filter=connection_filter,
connection_to_asset_key_fn=connection_to_asset_key_fn,
Expand Down Expand Up @@ -793,6 +793,9 @@ def load_assets_from_airbyte_instance(
key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
create_assets_for_normalization_tables: bool = True,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = _clean_name,
connection_meta_to_group_fn: Optional[
Callable[[AirbyteConnectionMetadata], Optional[str]]
] = None,
io_manager_key: Optional[str] = None,
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]] = None,
connection_filter: Optional[Callable[[AirbyteConnectionMetadata], bool]] = None,
Expand Down Expand Up @@ -822,6 +825,9 @@ def load_assets_from_airbyte_instance(
connection_to_group_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an asset
group name for a given Airbyte connection name. If None, no groups will be created. Defaults
to a basic sanitization function.
connection_meta_to_group_fn (Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]]): Function which
returns an asset group name for a given Airbyte connection metadata. If None and connection_to_group_fn
is None, no groups will be created
io_manager_key (Optional[str]): The I/O manager key to use for all assets. Defaults to "io_manager".
Use this if all assets should be loaded from the same source, otherwise use connection_to_io_manager_key_fn.
connection_to_io_manager_key_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an
Expand Down Expand Up @@ -888,12 +894,22 @@ def load_assets_from_airbyte_instance(
if not connection_to_io_manager_key_fn:
connection_to_io_manager_key_fn = lambda _: io_manager_key

check.invariant(
not connection_meta_to_group_fn
or not connection_to_group_fn
or connection_to_group_fn == _clean_name,
"Cannot specify both connection_meta_to_group_fn and connection_to_group_fn",
)

if not connection_meta_to_group_fn and connection_to_group_fn:
connection_meta_to_group_fn = lambda meta: connection_to_group_fn(meta.name)

return AirbyteInstanceCacheableAssetsDefinition(
airbyte_resource_def=airbyte,
workspace_id=workspace_id,
key_prefix=key_prefix,
create_assets_for_normalization_tables=create_assets_for_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connection_filter=connection_filter,
connection_to_asset_key_fn=connection_to_asset_key_fn,
Expand All @@ -908,6 +924,9 @@ def load_assets_from_airbyte_project(
key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
create_assets_for_normalization_tables: bool = True,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = _clean_name,
connection_meta_to_group_fn: Optional[
Callable[[AirbyteConnectionMetadata], Optional[str]]
] = None,
io_manager_key: Optional[str] = None,
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]] = None,
connection_filter: Optional[Callable[[AirbyteConnectionMetadata], bool]] = None,
Expand Down Expand Up @@ -939,6 +958,9 @@ def load_assets_from_airbyte_project(
connection_to_group_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an asset
group name for a given Airbyte connection name. If None, no groups will be created. Defaults
to a basic sanitization function.
connection_meta_to_group_fn (Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]]): Function
which returns an asset group name for a given Airbyte connection metadata. If None and connection_to_group_fn
is None, no groups will be created. Defaults to None.
io_manager_key (Optional[str]): The I/O manager key to use for all assets. Defaults to "io_manager".
Use this if all assets should be loaded from the same source, otherwise use connection_to_io_manager_key_fn.
connection_to_io_manager_key_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an
Expand Down Expand Up @@ -993,12 +1015,22 @@ def load_assets_from_airbyte_project(
if not connection_to_io_manager_key_fn:
connection_to_io_manager_key_fn = lambda _: io_manager_key

check.invariant(
not connection_meta_to_group_fn
or not connection_to_group_fn
or connection_to_group_fn == _clean_name,
"Cannot specify both connection_meta_to_group_fn and connection_to_group_fn",
)

if not connection_meta_to_group_fn and connection_to_group_fn:
connection_meta_to_group_fn = lambda meta: connection_to_group_fn(meta.name)

return AirbyteYAMLCacheableAssetsDefinition(
project_dir=project_dir,
workspace_id=workspace_id,
key_prefix=key_prefix,
create_assets_for_normalization_tables=create_assets_for_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connection_filter=connection_filter,
connection_directories=connection_directories,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def __init__(
airbyte_resource_def: AirbyteResource,
key_prefix: Sequence[str],
create_assets_for_normalization_tables: bool,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]],
connection_meta_to_group_fn: Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]],
connections: Iterable[AirbyteConnection],
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]],
connection_to_asset_key_fn: Optional[Callable[[AirbyteConnectionMetadata, str], AssetKey]],
Expand All @@ -719,7 +719,7 @@ def __init__(
workspace_id=None,
key_prefix=key_prefix,
create_assets_for_normalization_tables=create_assets_for_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connection_filter=lambda conn: conn.name in defined_conn_names,
connection_to_asset_key_fn=connection_to_asset_key_fn,
Expand Down Expand Up @@ -749,6 +749,9 @@ def load_assets_from_connections(
key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
create_assets_for_normalization_tables: bool = True,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = _clean_name,
connection_meta_to_group_fn: Optional[
Callable[[AirbyteConnectionMetadata], Optional[str]]
] = None,
io_manager_key: Optional[str] = None,
connection_to_io_manager_key_fn: Optional[Callable[[str], Optional[str]]] = None,
connection_to_asset_key_fn: Optional[
Expand All @@ -772,6 +775,9 @@ def load_assets_from_connections(
connection_to_group_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an asset
group name for a given Airbyte connection name. If None, no groups will be created. Defaults
to a basic sanitization function.
connection_meta_to_group_fn (Optional[Callable[[AirbyteConnectionMetadata], Optional[str]]]): Function which
returns an asset group name for a given Airbyte connection metadata. If None and connection_to_group_fn
is None, no groups will be created. Defaults to None.
io_manager_key (Optional[str]): The IO manager key to use for all assets. Defaults to "io_manager".
Use this if all assets should be loaded from the same source, otherwise use connection_to_io_manager_key_fn.
connection_to_io_manager_key_fn (Optional[Callable[[str], Optional[str]]]): Function which returns an
Expand Down Expand Up @@ -814,6 +820,16 @@ def load_assets_from_connections(
if not connection_to_io_manager_key_fn:
connection_to_io_manager_key_fn = lambda _: io_manager_key

check.invariant(
not connection_meta_to_group_fn
or not connection_to_group_fn
or connection_to_group_fn == _clean_name,
"Cannot specify both connection_meta_to_group_fn and connection_to_group_fn",
)

if not connection_meta_to_group_fn and connection_to_group_fn:
connection_meta_to_group_fn = lambda meta: connection_to_group_fn(meta.name)

return AirbyteManagedElementCacheableAssetsDefinition(
airbyte_resource_def=(
airbyte
Expand All @@ -824,8 +840,8 @@ def load_assets_from_connections(
create_assets_for_normalization_tables=check.bool_param(
create_assets_for_normalization_tables, "create_assets_for_normalization_tables"
),
connection_to_group_fn=check.opt_callable_param(
connection_to_group_fn, "connection_to_group_fn"
connection_meta_to_group_fn=check.opt_callable_param(
connection_meta_to_group_fn, "connection_meta_to_group_fn"
),
connection_to_io_manager_key_fn=connection_to_io_manager_key_fn,
connections=check.iterable_param(connections, "connections", of_type=AirbyteConnection),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def airbyte_instance_fixture(request):

@responses.activate
@pytest.mark.parametrize("use_normalization_tables", [True, False])
@pytest.mark.parametrize("connection_to_group_fn", [None, lambda x: f"{x[0]}_group"])
@pytest.mark.parametrize(
"connection_to_group_fn, connection_meta_to_group_fn",
[(None, lambda meta: f"{meta.name[0]}_group"), (None, None), (lambda x: f"{x[0]}_group", None)],
)
@pytest.mark.parametrize("filter_connection", [True, False])
@pytest.mark.parametrize(
"connection_to_asset_key_fn", [None, lambda conn, name: AssetKey([f"{conn.name[0]}_{name}"])]
Expand All @@ -64,6 +67,7 @@ def airbyte_instance_fixture(request):
def test_load_from_instance(
use_normalization_tables,
connection_to_group_fn,
connection_meta_to_group_fn,
filter_connection,
connection_to_asset_key_fn,
connection_to_freshness_policy_fn,
Expand Down Expand Up @@ -109,6 +113,7 @@ def load_input(self, context: InputContext) -> Any:
airbyte_instance,
create_assets_for_normalization_tables=use_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_filter=(lambda _: False) if filter_connection else None,
connection_to_io_manager_key_fn=(lambda _: "test_io_manager"),
connection_to_asset_key_fn=connection_to_asset_key_fn,
Expand All @@ -121,6 +126,7 @@ def load_input(self, context: InputContext) -> Any:
create_assets_for_normalization_tables=use_normalization_tables,
connection_filter=(lambda _: False) if filter_connection else None,
io_manager_key="test_io_manager",
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_to_asset_key_fn=connection_to_asset_key_fn,
connection_to_freshness_policy_fn=connection_to_freshness_policy_fn,
connection_to_auto_materialize_policy_fn=connection_to_auto_materialize_policy_fn,
Expand Down Expand Up @@ -213,9 +219,15 @@ def downstream_asset(dagster_tags):
[
ab_assets[0].group_names_by_key.get(AssetKey(t))
== (
connection_to_group_fn("GitHub <> snowflake-ben")
if connection_to_group_fn
else "github_snowflake_ben"
connection_meta_to_group_fn(
AirbyteConnectionMetadata("GitHub <> snowflake-ben", "", False, [])
)
if connection_meta_to_group_fn
else (
connection_to_group_fn("GitHub <> snowflake-ben")
if connection_to_group_fn
else "github_snowflake_ben"
)
)
for t in tables
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ def airbyte_instance_fixture(request) -> AirbyteResource:

@responses.activate
@pytest.mark.parametrize("use_normalization_tables", [True, False])
@pytest.mark.parametrize("connection_to_group_fn", [None, lambda x: f"{x[0]}_group"])
@pytest.mark.parametrize(
"connection_to_group_fn, connection_meta_to_group_fn",
[(None, lambda meta: f"{meta.name[0]}_group"), (None, None), (lambda x: f"{x[0]}_group", None)],
)
@pytest.mark.parametrize("filter_connection", [None, "filter_fn", "dirs"])
@pytest.mark.parametrize(
"connection_to_asset_key_fn", [None, lambda conn, name: AssetKey([f"{conn.name[0]}_{name}"])]
)
def test_load_from_project(
use_normalization_tables,
connection_to_group_fn,
connection_meta_to_group_fn,
filter_connection,
connection_to_asset_key_fn,
airbyte_instance,
Expand All @@ -37,6 +41,7 @@ def test_load_from_project(
file_relative_path(__file__, "./test_airbyte_project"),
create_assets_for_normalization_tables=use_normalization_tables,
connection_to_group_fn=connection_to_group_fn,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_filter=(lambda _: False) if filter_connection == "filter_fn" else None,
connection_directories=(
["github_snowflake_ben"] if filter_connection == "dirs" else None
Expand All @@ -47,6 +52,7 @@ def test_load_from_project(
ab_cacheable_assets = load_assets_from_airbyte_project(
file_relative_path(__file__, "./test_airbyte_project"),
create_assets_for_normalization_tables=use_normalization_tables,
connection_meta_to_group_fn=connection_meta_to_group_fn,
connection_filter=(lambda _: False) if filter_connection == "filter_fn" else None,
connection_directories=(
["github_snowflake_ben"] if filter_connection == "dirs" else None
Expand Down Expand Up @@ -93,9 +99,17 @@ def test_load_from_project(
[
ab_assets[0].group_names_by_key.get(AssetKey(t))
== (
connection_to_group_fn("GitHub <> snowflake-ben")
if connection_to_group_fn
else "github_snowflake_ben"
connection_meta_to_group_fn(
AirbyteConnectionMetadata(
"GitHub <> snowflake-ben", "", use_normalization_tables, []
)
)
if connection_meta_to_group_fn
else (
connection_to_group_fn("GitHub <> snowflake-ben")
if connection_to_group_fn
else "github_snowflake_ben"
)
)
for t in tables
]
Expand Down

0 comments on commit 3ca6b6d

Please sign in to comment.