From 3b9df099ee33529d43cbf09a11a034b7532f3bb5 Mon Sep 17 00:00:00 2001
From: JamieDeMaria
Date: Tue, 16 Jan 2024 13:27:16 -0500
Subject: [PATCH] api with subobject that contains all the methods

---
 .../partition_example/partition_example.py    |  10 +-
 .../_core/execution/context/compute.py        | 280 ++++++++++++------
 .../test_asset_partition_mappings.py          |  14 +-
 3 files changed, 201 insertions(+), 103 deletions(-)

diff --git a/examples/partition_example/partition_example.py b/examples/partition_example/partition_example.py
index 5fc8e80c2499c..7a47ee081dcb4 100644
--- a/examples/partition_example/partition_example.py
+++ b/examples/partition_example/partition_example.py
@@ -33,7 +33,7 @@ def relativedelta(*args, **kwargs):
     metadata={"partition_expr": "LastModifiedDate"},
 )
 def salesforce_customers(context: AssetExecutionContext) -> pd.DataFrame:
-    start_date_str = context.asset_partition_key_for_output()
+    start_date_str = context.partition_context.partition_key
     timezone = pytz.timezone("GMT")  # Replace 'Your_Timezone' with the desired timezone

     start_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone)
@@ -65,7 +65,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram
     The volatility is calculated using various methods such as close-to-close, Parkinson, Hodges-Tompkins,
     and Yang-Zhang. The function returns a DataFrame with the calculated volatilities.
     """
-    trade_date = context.asset_partition_key_for_output()
+    trade_date = context.partition_context.partition_key
     ticker_id = 1

     df = all_realvols(orats_daily_prices, ticker_id, trade_date)
@@ -80,7 +80,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram

 @asset(io_manager_def="parquet_io_manager", partitions_def=hourly_partitions)
 def my_custom_df(context: AssetExecutionContext) -> pd.DataFrame:
-    start, end = context.asset_partitions_time_window_for_output()
+    start, end = context.partition_context.partition_time_window

     df = pd.DataFrame({"timestamp": pd.date_range(start, end, freq="5T")})
     df["count"] = df["timestamp"].map(lambda a: random.randint(1, 1000))
@@ -93,7 +93,7 @@ def fetch_blog_posts_from_external_api(*args, **kwargs):

 @asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00"))
 def blog_posts(context: AssetExecutionContext) -> List[Dict]:
-    partition_datetime_str = context.asset_partition_key_for_output()
+    partition_datetime_str = context.partition_context.partition_key
     hour = datetime.datetime.fromisoformat(partition_datetime_str)
     posts = fetch_blog_posts_from_external_api(hour_when_posted=hour)
     return posts
@@ -106,7 +106,7 @@ def blog_posts(context: AssetExecutionContext) -> List[Dict]:
     key_prefix=["snowflake", "eldermark_proxy"],
 )
 def resident(context: AssetExecutionContext) -> Output[pd.DataFrame]:
-    start, end = context.asset_partitions_time_window_for_output()
+    start, end = context.partition_context.partition_time_window
     filter_str = f"LastMod_Stamp >= {start.timestamp()} AND LastMod_Stamp < {end.timestamp()}"

     records = context.resources.eldermark.fetch_obj(obj="Resident", filter=filter_str)
diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py
index fab5486935c53..84695128ba8a3 100644
--- a/python_modules/dagster/dagster/_core/execution/context/compute.py
+++ b/python_modules/dagster/dagster/_core/execution/context/compute.py
@@ -1374,6 +1374,16 @@ def _copy_docs_from_op_execution_context(obj):
     "dagster_run": "run",
     "run_config": "run.run_config",
     "run_tags": "run.tags",
+    "asset_partition_key_for_output": "partition_context.partition_key",
+    "asset_partitions_time_window_for_output": "partition_context.partition_time_window",
+    "asset_partition_key_range_for_output": "partition_context.partition_key_range",
+    "asset_partition_key_range_for_input": "partition_context.upstream_partition_key_range",
+    "asset_partition_key_for_input": "partition_context.upstream_partition_key",
+    "asset_partitions_def_for_output": "assets_def.partitions_def",
+    "asset_partitions_def_for_input": "partition_context.upstream_partitions_def",
+    "asset_partition_keys_for_output": "partition_context.partition_keys",
+    "asset_partition_keys_for_input": "partition_context.upstream_partition_keys",
+    "asset_partitions_time_window_for_input": "partition_context.upstream_partitions_time_window",
 }

 ALTERNATE_EXPRESSIONS = {
@@ -1400,12 +1410,83 @@ def _get_deprecation_kwargs(attr: str):
     return deprecation_kwargs


+class PartitionContext:  # TODO - better name?
+    def __init__(self, op_execution_context: OpExecutionContext):
+        self.op_execution_context = op_execution_context
+
+    @public
+    @property
+    def is_partitioned_materialization(self) -> bool:
+        return self.op_execution_context.has_partition_key
+
+    @public
+    @property
+    def partition_key(self) -> str:
+        return self.op_execution_context.partition_key
+
+    @public
+    @property
+    def partition_keys(self) -> Sequence[str]:
+        return self.op_execution_context.partition_keys
+
+    @public
+    @property
+    def partition_key_range(self) -> PartitionKeyRange:
+        return self.op_execution_context.partition_key_range
+
+    @public
+    @property
+    def partition_time_window(self) -> TimeWindow:
+        return self.op_execution_context.partition_time_window
+
+    @public
+    def upstream_partition_key(self, key: CoercibleToAssetKey) -> str:
+        return self.op_execution_context._step_execution_context.asset_partition_key_for_upstream(  # noqa: SLF001
+            AssetKey.from_coercible(key)
+        )
+
+    @public
+    def upstream_partition_keys(self, key: CoercibleToAssetKey) -> Sequence[str]:
+        return list(
+            self.op_execution_context._step_execution_context.asset_partitions_subset_for_upstream(  # noqa: SLF001
+                AssetKey.from_coercible(key)
+            ).get_partition_keys()
+        )
+
+    @public
+    def upstream_partition_key_range(self, key: CoercibleToAssetKey) -> PartitionKeyRange:
+        return self.op_execution_context._step_execution_context.asset_partition_key_range_for_upstream(  # noqa: SLF001
+            AssetKey.from_coercible(key)
+        )
+
+    @public
+    def upstream_partitions_time_window(
+        self, key: CoercibleToAssetKey
+    ) -> TimeWindow:  # TODO align on plurality of partition(s)
+        return self.op_execution_context._step_execution_context.asset_partitions_time_window_for_upstream(  # noqa: SLF001
+            AssetKey.from_coercible(key)
+        )
+
+    @public
+    def upstream_partitions_def(self, key: CoercibleToAssetKey) -> PartitionsDefinition:
+        result = self.op_execution_context._step_execution_context.job_def.asset_layer.partitions_def_for_asset(  # noqa: SLF001
+            AssetKey.from_coercible(key)
+        )
+        if result is None:
+            raise DagsterInvariantViolationError(
+                f"Attempting to access partitions def for asset {key}, but it is not" " partitioned"
+            )
+
+        return result
+
+
 class AssetExecutionContext(OpExecutionContext):
     def __init__(self, op_execution_context: OpExecutionContext) -> None:
         self._op_execution_context = check.inst_param(
             op_execution_context, "op_execution_context", OpExecutionContext
         )
         self._step_execution_context = self._op_execution_context._step_execution_context  # noqa: SLF001
+        self._partition_context = PartitionContext(op_execution_context=self._op_execution_context)

     @staticmethod
     def get() -> "AssetExecutionContext":
@@ -1418,6 +1499,10 @@ def get() -> "AssetExecutionContext":
     def op_execution_context(self) -> OpExecutionContext:
         return self._op_execution_context

+    @property
+    def partition_context(self) -> PartitionContext:
+        return self._partition_context
+
     ####### Top-level properties/methods on AssetExecutionContext

     @public
@@ -1530,6 +1615,110 @@ def has_tag(self, key: str) -> bool:
     def get_tag(self, key: str) -> Optional[str]:
         return self.op_execution_context.get_tag(key)

+    @deprecated(**_get_deprecation_kwargs("has_partition_key"))
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def has_partition_key(self) -> bool:
+        return self.op_execution_context.has_partition_key
+
+    @deprecated(**_get_deprecation_kwargs("partition_key"))
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def partition_key(self) -> str:
+        return self.op_execution_context.partition_key
+
+    @deprecated(**_get_deprecation_kwargs("partition_keys"))
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def partition_keys(self) -> Sequence[str]:
+        return self.op_execution_context.partition_keys
+
+    @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def asset_partition_key_range(self) -> PartitionKeyRange:
+        return self.op_execution_context.asset_partition_key_range
+
+    @deprecated(**_get_deprecation_kwargs("partition_key_range"))
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def partition_key_range(self) -> PartitionKeyRange:
+        return self.op_execution_context.partition_key_range
+
+    @deprecated(**_get_deprecation_kwargs("partition_time_window"))
+    @public
+    @property
+    @_copy_docs_from_op_execution_context
+    def partition_time_window(self) -> TimeWindow:
+        return self.op_execution_context.partition_time_window
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_key_for_output(self, output_name: str = "result") -> str:
+        return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
+        return self.op_execution_context.asset_partitions_time_window_for_output(output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_key_range_for_output(
+        self, output_name: str = "result"
+    ) -> PartitionKeyRange:
+        return self.op_execution_context.asset_partition_key_range_for_output(output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
+        return self.op_execution_context.asset_partition_key_range_for_input(input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_key_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_key_for_input(self, input_name: str) -> str:
+        return self.op_execution_context.asset_partition_key_for_input(input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
+        return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
+        return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
+        return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
+        return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)
+
+    @deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_input"))
+    @public
+    @_copy_docs_from_op_execution_context
+    def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
+        return self.op_execution_context.asset_partitions_time_window_for_input(input_name)
+
     ########## pass-through to op context

     #### op related
@@ -1647,97 +1836,6 @@ def step_launcher(self) -> Optional[StepLauncher]:
     def get_step_execution_context(self) -> StepExecutionContext:
         return self.op_execution_context.get_step_execution_context()

-    #### partition_related
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def has_partition_key(self) -> bool:
-        return self.op_execution_context.has_partition_key
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_key(self) -> str:
-        return self.op_execution_context.partition_key
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_keys(self) -> Sequence[str]:
-        return self.op_execution_context.partition_keys
-
-    @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range(self) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_key_range(self) -> PartitionKeyRange:
-        return self.op_execution_context.partition_key_range
-
-    @public
-    @property
-    @_copy_docs_from_op_execution_context
-    def partition_time_window(self) -> TimeWindow:
-        return self.op_execution_context.partition_time_window
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_for_output(self, output_name: str = "result") -> str:
-        return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
-        return self.op_execution_context.asset_partitions_time_window_for_output(output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range_for_output(
-        self, output_name: str = "result"
-    ) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range_for_output(output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
-        return self.op_execution_context.asset_partition_key_range_for_input(input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_key_for_input(self, input_name: str) -> str:
-        return self.op_execution_context.asset_partition_key_for_input(input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
-        return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
-        return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
-        return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
-        return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)
-
-    @public
-    @_copy_docs_from_op_execution_context
-    def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
-        return self.op_execution_context.asset_partitions_time_window_for_input(input_name)
-
     #### Event log related

     @_copy_docs_from_op_execution_context
diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
index 7c1a07f8916a7..52abd085a5b7c 100644
--- a/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
+++ b/python_modules/dagster/dagster_tests/asset_defs_tests/partition_mapping_tests/test_asset_partition_mappings.py
@@ -601,7 +601,7 @@ def upstream():
     )
     def downstream(context: AssetExecutionContext):
         upstream_key = datetime.strptime(
-            context.asset_partition_key_for_input("upstream"), "%Y-%m-%d"
+            context.partition_context.upstream_partition_key("upstream"), "%Y-%m-%d"
         )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -654,10 +654,10 @@ def multi_asset_1():
     @multi_asset(specs=[asset_3, asset_4], partitions_def=partitions_def)
     def multi_asset_2(context: AssetExecutionContext):
         asset_1_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_1"), "%Y-%m-%d"
+            context.partition_context.upstream_partition_key("asset_1"), "%Y-%m-%d"
         )
         asset_2_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_2"), "%Y-%m-%d"
+            context.partition_context.upstream_partition_key("asset_2"), "%Y-%m-%d"
        )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -760,7 +760,7 @@ def test_self_dependent_partition_mapping_with_asset_deps():
     )
     def self_dependent(context: AssetExecutionContext):
         upstream_key = datetime.strptime(
-            context.asset_partition_key_for_input("self_dependent"), "%Y-%m-%d"
+            context.partition_context.upstream_partition_key("self_dependent"), "%Y-%m-%d"
         )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -787,7 +787,7 @@ def self_dependent(context: AssetExecutionContext):
     @multi_asset(specs=[asset_1], partitions_def=partitions_def)
     def the_multi_asset(context: AssetExecutionContext):
         asset_1_key = datetime.strptime(
-            context.asset_partition_key_for_input("asset_1"), "%Y-%m-%d"
+            context.partition_context.upstream_partition_key("asset_1"), "%Y-%m-%d"
         )

         current_partition_key = datetime.strptime(context.partition_key, "%Y-%m-%d")
@@ -810,7 +810,7 @@ def upstream():
         deps=[AssetDep(upstream, partition_mapping=SpecificPartitionsPartitionMapping(["apple"]))],
     )
     def downstream(context: AssetExecutionContext):
-        assert context.asset_partition_key_for_input("upstream") == "apple"
+        assert context.partition_context.upstream_partition_key("upstream") == "apple"
         assert context.partition_key == "orange"

     with instance_for_test() as instance:
@@ -840,7 +840,7 @@ def asset_1_multi_asset():

     @multi_asset(specs=[asset_2], partitions_def=partitions_def)
     def asset_2_multi_asset(context: AssetExecutionContext):
-        assert context.asset_partition_key_for_input("asset_1") == "apple"
+        assert context.partition_context.upstream_partition_key("asset_1") == "apple"
         assert context.partition_key == "orange"

     with instance_for_test() as instance:
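
Below is a minimal sketch of how asset code reads against the new `partition_context` subobject this patch introduces. The `events` / `events_rollup` assets and the `DailyPartitionsDefinition` are hypothetical illustrations; only `context.partition_context.partition_key`, `context.partition_context.upstream_partition_key(...)`, and the deprecated spellings they replace come from the diff above.

    from dagster import AssetExecutionContext, DailyPartitionsDefinition, asset

    # Hypothetical partitions definition, used only to illustrate the API change.
    daily = DailyPartitionsDefinition(start_date="2024-01-01")


    @asset(partitions_def=daily)
    def events(context: AssetExecutionContext) -> None:
        # Previously: context.asset_partition_key_for_output()
        context.log.info(f"materializing partition {context.partition_context.partition_key}")


    @asset(partitions_def=daily, deps=[events])
    def events_rollup(context: AssetExecutionContext) -> None:
        # Previously: context.asset_partition_key_for_input("events")
        upstream_key = context.partition_context.upstream_partition_key("events")
        context.log.info(f"rolling up upstream partition {upstream_key}")

The subobject keeps the top-level AssetExecutionContext surface small, while the old accessors remain available behind @deprecated wrappers for the transition period.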