Skip to content

Commit

Permalink
[dagster-airlift][partitions] add a default implementation with time …
Browse files Browse the repository at this point in the history
…partition handling to airlift (#25096)

## Summary & Motivation
Take 3 of adding time window partition handling to airlift.
Most of the implementation here has already been gone over. The only new
bits are how it relates to the top-level pluggable API we decided on. It
makes things a little more weird because we're transforming the asset
materializations after the fact.

## How I Tested These Changes
Same unit test battery as last time, minus the timezone test. We no
longer require the same timezone as long as timestamps match up.
## Changelog

Insert changelog entry or "NOCHANGELOG" here.
NOCHANGELOG
  • Loading branch information
dpeng817 authored Oct 15, 2024
1 parent 3234086 commit dfb8e60
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
TASK_MAPPING_METADATA_KEY = "dagster-airlift/task-mapping"
AUTOMAPPED_TASK_METADATA_KEY = "dagster-airlift/automapped-task"
# This represents the timestamp used in ordering the materializations.
EFFECTIVE_TIMESTAMP_METADATA_KEY = "dagster-airlift/effective_timestamp"
EFFECTIVE_TIMESTAMP_METADATA_KEY = "dagster-airlift/effective-timestamp"
AIRFLOW_TASK_INSTANCE_LOGICAL_DATE_METADATA_KEY = (
"dagster-airlift/airflow-task-instance-logical-date"
)
AIRFLOW_RUN_ID_METADATA_KEY = "dagster-airlift/airflow-run-id"
DAG_RUN_ID_TAG_KEY = "dagster-airlift/airflow-dag-run-id"
DAG_ID_TAG_KEY = "dagster-airlift/airflow-dag-id"
TASK_ID_TAG_KEY = "dagster-airlift/airflow-task-id"
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property
from typing import AbstractSet, Mapping, Set

from dagster import AssetKey, Definitions
from dagster import AssetKey, AssetSpec, Definitions
from dagster._record import record

from dagster_airlift.core.airflow_instance import AirflowInstance
Expand Down Expand Up @@ -31,6 +31,10 @@ def instance_name(self) -> str:
def mapping_info(self) -> AirliftMetadataMappingInfo:
return AirliftMetadataMappingInfo(asset_specs=list(self.mapped_defs.get_all_asset_specs()))

@cached_property
def all_asset_specs_by_key(self) -> Mapping[AssetKey, AssetSpec]:
    """Return a mapping from each asset key to its spec across all mapped definitions.

    Computed once and cached, since the mapped definitions do not change after construction.
    """
    specs_by_key = {}
    for asset_spec in self.mapped_defs.get_all_asset_specs():
        specs_by_key[asset_spec.key] = asset_spec
    return specs_by_key

def task_ids_in_dag(self, dag_id: str) -> Set[str]:
    """Return the set of task ids mapped under the given dag id."""
    task_id_map = self.mapping_info.task_id_map
    return task_id_map[dag_id]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from dagster._utils.warnings import suppress_dagster_warnings

from dagster_airlift.core.airflow_instance import AirflowInstance
from dagster_airlift.core.sensor.event_translation import DagsterEventTransformerFn
from dagster_airlift.core.sensor.event_translation import (
DagsterEventTransformerFn,
default_event_transformer,
)
from dagster_airlift.core.sensor.sensor_builder import (
DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
build_airflow_polling_sensor_defs,
Expand Down Expand Up @@ -68,7 +71,7 @@ def build_defs_from_airflow_instance(
airflow_instance: AirflowInstance,
defs: Optional[Definitions] = None,
sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
) -> Definitions:
mapped_defs = build_airflow_mapped_defs(airflow_instance=airflow_instance, defs=defs)
return Definitions.merge(
Expand Down Expand Up @@ -123,7 +126,6 @@ def build_full_automapped_dags_from_airflow_instance(
minimum_interval_seconds=sensor_minimum_interval_seconds,
mapped_defs=resolved_defs,
airflow_instance=airflow_instance,
event_transformer_fn=None,
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Sequence, Union
from collections import defaultdict
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Sequence, Union, cast

from dagster import (
AssetMaterialization,
Expand All @@ -11,9 +12,14 @@
)
from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
from dagster._core.definitions.asset_key import AssetKey
from dagster._time import get_current_timestamp
from dagster._core.definitions.time_window_partitions import TimeWindowPartitionsDefinition
from dagster._time import datetime_from_timestamp, get_current_timestamp

from dagster_airlift.constants import EFFECTIVE_TIMESTAMP_METADATA_KEY
from dagster_airlift.constants import (
AIRFLOW_RUN_ID_METADATA_KEY,
AIRFLOW_TASK_INSTANCE_LOGICAL_DATE_METADATA_KEY,
EFFECTIVE_TIMESTAMP_METADATA_KEY,
)
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData
from dagster_airlift.core.airflow_instance import DagRun, TaskInstance
from dagster_airlift.core.serialization.serialized_data import DagHandle
Expand All @@ -25,6 +31,66 @@
]


def default_event_transformer(
    context: SensorEvaluationContext,
    airflow_data: AirflowDefinitionsData,
    materializations: Sequence[AssetMaterialization],
) -> Iterable[AssetEvent]:
    """The default event transformer function, which attaches a partition key to materializations which are from time-window partitioned assets."""
    # Memoize partition-key resolution per (partitions_def, logical-date timestamp):
    # many task instances from the same schedule tick share the same logical date,
    # so each key only needs to be computed once.
    partition_key_cache = defaultdict(dict)
    for materialization in materializations:
        spec = airflow_data.all_asset_specs_by_key[materialization.asset_key]
        partitions_def = spec.partitions_def
        # Non-partitioned and non-time-window-partitioned assets pass through untouched.
        if not partitions_def or not isinstance(partitions_def, TimeWindowPartitionsDefinition):
            yield materialization
            continue
        # The Airflow logical date was stamped onto the materialization's metadata upstream.
        logical_date_timestamp: float = cast(
            TimestampMetadataValue,
            materialization.metadata[AIRFLOW_TASK_INSTANCE_LOGICAL_DATE_METADATA_KEY],
        ).value
        time_window_def = cast(TimeWindowPartitionsDefinition, partitions_def)
        per_def_cache = partition_key_cache[time_window_def]
        if logical_date_timestamp not in per_def_cache:
            per_def_cache[logical_date_timestamp] = get_partition_key_from_timestamp(
                partitions_def=time_window_def,
                timestamp=logical_date_timestamp,
            )
        yield materialization._replace(partition=per_def_cache[logical_date_timestamp])


def get_partition_key_from_timestamp(
    partitions_def: TimeWindowPartitionsDefinition,
    timestamp: float,
) -> str:
    """Resolve the partition key whose time window ends at the given timestamp.

    Args:
        partitions_def: The time-window partitions definition to resolve against.
        timestamp: A POSIX timestamp (the Airflow logical date) expected to fall
            exactly on a partition boundary.

    Returns:
        The partition key corresponding to ``timestamp``.

    Raises:
        A check failure if the timestamp does not align with a partition boundary,
        or if it precedes the start of the partitions definition.
    """
    datetime_in_tz = datetime_from_timestamp(timestamp, partitions_def.timezone)
    # Assuming that "logical_date" lies on a partition, the previous partition window
    # (where upper bound can be the passed-in date, which is why we set respect_bounds=False)
    # will end on the logical date. This would indicate that there is a partition for the logical date.
    partition_window = check.not_none(
        partitions_def.get_prev_partition_window(datetime_in_tz, respect_bounds=False),
        f"Could not find partition for airflow logical date {datetime_in_tz.isoformat()}. This likely means that your partition range is too small to cover the logical date.",
    )
    check.invariant(
        datetime_in_tz.timestamp() == partition_window.end.timestamp(),
        (
            "Expected logical date to match a partition in the partitions definition. This likely means that "
            "the partition range is not aligned with the scheduling interval in airflow."
        ),
    )
    check.invariant(
        datetime_in_tz.timestamp() >= partitions_def.start.timestamp(),
        (
            "Provided date is before the start of the partitions definition. "
            # BUG FIX: this segment was not an f-string, so the placeholder below was
            # previously emitted literally instead of being interpolated.
            f"Ensure that the start date of your PartitionsDefinition is early enough to capture the provided date {datetime_in_tz.isoformat()}."
        ),
    )
    return partitions_def.get_partition_key_for_timestamp(timestamp)


def get_timestamp_from_materialization(event: AssetEvent) -> float:
return check.float_param(
event.metadata[EFFECTIVE_TIMESTAMP_METADATA_KEY].value,
Expand Down Expand Up @@ -71,6 +137,7 @@ def get_dag_run_metadata(dag_run: DagRun) -> Mapping[str, Any]:
def get_common_metadata(dag_run: DagRun) -> Mapping[str, Any]:
return {
"Airflow Run ID": dag_run.run_id,
AIRFLOW_RUN_ID_METADATA_KEY: dag_run.run_id,
"Run Metadata (raw)": JsonMetadataValue(dag_run.metadata),
"Run Type": dag_run.run_type,
"Airflow Config": JsonMetadataValue(dag_run.config),
Expand All @@ -88,6 +155,9 @@ def get_task_instance_metadata(dag_run: DagRun, task_instance: TaskInstance) ->
EFFECTIVE_TIMESTAMP_METADATA_KEY: TimestampMetadataValue(
task_instance.end_date.timestamp()
),
AIRFLOW_TASK_INSTANCE_LOGICAL_DATE_METADATA_KEY: TimestampMetadataValue(
task_instance.logical_date.timestamp()
),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from dagster_airlift.core.sensor.event_translation import (
AssetEvent,
DagsterEventTransformerFn,
default_event_transformer,
get_timestamp_from_materialization,
synthetic_mats_for_mapped_asset_keys,
synthetic_mats_for_mapped_dag_asset_keys,
Expand Down Expand Up @@ -81,7 +82,7 @@ def build_airflow_polling_sensor_defs(
*,
mapped_defs: Definitions,
airflow_instance: AirflowInstance,
event_transformer_fn: Optional[DagsterEventTransformerFn],
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
) -> Definitions:
"""The constructed sensor polls the Airflow instance for activity, and inserts asset events into Dagster's event log.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def make_instance(
- timedelta(
seconds=1
), # Ensure that the task ends before the full "dag" completes.
logical_date=dag_run.logical_date,
)
for task_id in dag_and_task_structure[dag_run.dag_id]
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from dagster_airlift.core import (
build_defs_from_airflow_instance as build_defs_from_airflow_instance,
)
from dagster_airlift.core.sensor.event_translation import DagsterEventTransformerFn
from dagster_airlift.core.sensor.event_translation import (
DagsterEventTransformerFn,
default_event_transformer,
)
from dagster_airlift.core.utils import metadata_for_dag_mapping, metadata_for_task_mapping
from dagster_airlift.test import make_dag_run, make_instance

Expand All @@ -42,7 +45,7 @@ def fully_loaded_repo_from_airflow_asset_graph(
additional_defs: Definitions = Definitions(),
create_runs: bool = True,
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
) -> RepositoryDefinition:
defs = load_definitions_airflow_asset_graph(
assets_per_task,
Expand All @@ -62,7 +65,7 @@ def load_definitions_airflow_asset_graph(
create_runs: bool = True,
create_assets_defs: bool = True,
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
) -> Definitions:
assets = []
dag_and_task_structure = defaultdict(list)
Expand Down Expand Up @@ -133,7 +136,7 @@ def build_and_invoke_sensor(
instance: DagsterInstance,
additional_defs: Definitions = Definitions(),
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
) -> Tuple[SensorResult, SensorEvaluationContext]:
repo_def = fully_loaded_repo_from_airflow_asset_graph(
assets_per_task,
Expand Down
Loading

0 comments on commit dfb8e60

Please sign in to comment.