diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py index 2f3fca60a0a13..e6a75d3ab64e3 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py @@ -7,6 +7,8 @@ from airflow.models.operator import BaseOperator from airflow.operators.python import PythonOperator +from dagster_airlift.core.utils import DAG_ID_TAG, TASK_ID_TAG + from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION logger = logging.getLogger(__name__) @@ -23,11 +25,16 @@ def compute_fn() -> None: def launch_runs_for_task(dag_id: str, task_id: str, dagster_url: str) -> None: expected_op_name = f"{dag_id}__{task_id}" + assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys # create graphql client response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3) for asset_node in response.json()["data"]["assetNodes"]: - if asset_node["opName"] == expected_op_name: + tags = {tag["key"]: tag["value"] for tag in asset_node["tags"]} + # match assets based on conventional dag_id__task_id naming or based on explicit tags + if asset_node["opName"] == expected_op_name or ( + tags.get(DAG_ID_TAG) == dag_id and tags.get(TASK_ID_TAG) == task_id + ): repo_location = asset_node["jobs"][0]["repository"]["location"]["name"] repo_name = asset_node["jobs"][0]["repository"]["name"] job_name = asset_node["jobs"][0]["name"] diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py index 6447612204395..f4de7bfab76a6 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py @@ -5,6 +5,10 @@ assetKey { path } + tags { + key + value + } opName jobs { id diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_convention/dags/switcheroo_dag.py similarity index 100% rename from examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py rename to examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_convention/dags/switcheroo_dag.py diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_convention/dagster_defs.py similarity index 100% rename from examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py rename to examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_convention/dagster_defs.py diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dags/switcheroo_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dags/switcheroo_dag.py new file mode 100644 index 0000000000000..65c59ccb91aaf --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dags/switcheroo_dag.py @@ -0,0 +1,68 @@ +import logging +import os +from datetime import datetime + +from airflow import DAG +from airflow.operators.python import PythonOperator +from dagster_airlift.in_airflow import mark_as_dagster_migrating +from dagster_airlift.migration_state import ( + AirflowMigrationState, + DagMigrationState, + TaskMigrationState, +) + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) +requests_log = logging.getLogger("requests.packages.urllib3") +requests_log.setLevel(logging.INFO) +requests_log.propagate = True + + +def write_to_file_in_airflow_home() -> None: + airflow_home = os.environ["AIRFLOW_HOME"] + with open(os.path.join(airflow_home, "airflow_home_file.txt"), "w") as f: + f.write("Hello") + + +def write_to_other_file_in_airflow_home() -> None: + airflow_home = os.environ["AIRFLOW_HOME"] + with open(os.path.join(airflow_home, "other_airflow_home_file.txt"), "w") as f: + f.write("Hello") + + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2023, 1, 1), + "retries": 1, +} + +dag = DAG( + "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False +) +op_to_migrate = PythonOperator( + task_id="some_task", python_callable=write_to_file_in_airflow_home, dag=dag +) +op_doesnt_migrate = PythonOperator( + task_id="other_task", python_callable=write_to_other_file_in_airflow_home, dag=dag +) +# Add a dependency between the two tasks +op_doesnt_migrate.set_upstream(op_to_migrate) + +# # set up the debugger +# print("Waiting for debugger to attach...") +# debugpy.listen(("localhost", 7778)) +# debugpy.wait_for_client() +mark_as_dagster_migrating( + global_vars=globals(), + migration_state=AirflowMigrationState( + dags={ + "the_dag": DagMigrationState( + tasks={ + "some_task": TaskMigrationState(task_id="some_task", migrated=True), + "other_task": TaskMigrationState(task_id="other_task", migrated=True), + } + ) + } + ), +) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dagster_defs.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dagster_defs.py new file mode 100644 index 0000000000000..0f4b03261e96d --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo_tags/dagster_defs.py @@ -0,0 +1,13 @@ +from dagster import Definitions, asset +from dagster_airlift.core import dag_defs, task_defs + + +@asset +def my_asset_for_some_task(): + return "asset_value" + + +defs = dag_defs( + "the_dag", + task_defs("some_task", Definitions(assets=[my_asset_for_some_task])), +) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py index 9d8f72ec6621b..1dd3731f69418 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py @@ -7,14 +7,30 @@ from dagster._time import get_current_timestamp +# Two different test targets +# The first uses convention-based binding of assets to tasks, e.g. +# op named the_dag__some_task +# The second uses `dag_defs` and `task_defs` to attach tags to assets, which +# in turn are used to bind assets to tasks. +@pytest.fixture( + name="test_dir", + params=[ + "airflow_op_switcheroo_convention", + "airflow_op_switcheroo_tags", + ], +) +def test_dir_fixture(request: pytest.FixtureRequest) -> Path: + return Path(__file__).parent / request.param + + @pytest.fixture(name="dags_dir") -def setup_dags_dir() -> Path: - return Path(__file__).parent / "airflow_op_switcheroo" / "dags" +def setup_dags_dir(test_dir: Path) -> Path: + return test_dir / "dags" @pytest.fixture(name="dagster_defs_path") -def setup_dagster_defs_path() -> str: - return str(Path(__file__).parent / "airflow_op_switcheroo" / "dagster_defs.py") +def setup_dagster_defs_path(test_dir: Path) -> str: + return str(test_dir / "dagster_defs.py") def test_migrated_operator(airflow_instance: None, dagster_dev: None) -> None: @@ -50,5 +66,6 @@ def test_migrated_operator(airflow_instance: None, dagster_dev: None) -> None: run for run in runs if set(list(run.asset_selection)) == {AssetKey(["the_dag__some_task"])} # type: ignore + or set(list(run.asset_selection)) == {AssetKey(["my_asset_for_some_task"])} # type: ignore ][0] assert some_task_run.status == DagsterRunStatus.SUCCESS