Skip to content

Commit

Permalink
api with subobject that contains all the methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Jan 23, 2024
1 parent 3016287 commit e0845f3
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 103 deletions.
10 changes: 5 additions & 5 deletions examples/partition_example/partition_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def relativedelta(*args, **kwargs):
metadata={"partition_expr": "LastModifiedDate"},
)
def salesforce_customers(context: AssetExecutionContext) -> pd.DataFrame:
start_date_str = context.asset_partition_key_for_output()
start_date_str = context.partition_context.partition_key

timezone = pytz.timezone("GMT") # Replace 'Your_Timezone' with the desired timezone
start_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone)
Expand Down Expand Up @@ -65,7 +65,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram
The volatility is calculated using various methods such as close-to-close, Parkinson, Hodges-Tompkins, and Yang-Zhang.
The function returns a DataFrame with the calculated volatilities.
"""
trade_date = context.asset_partition_key_for_output()
trade_date = context.partition_context.partition_key
ticker_id = 1

df = all_realvols(orats_daily_prices, ticker_id, trade_date)
Expand All @@ -80,7 +80,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram

@asset(io_manager_def="parquet_io_manager", partitions_def=hourly_partitions)
def my_custom_df(context: AssetExecutionContext) -> pd.DataFrame:
start, end = context.asset_partitions_time_window_for_output()
start, end = context.partition_context.partition_time_window

df = pd.DataFrame({"timestamp": pd.date_range(start, end, freq="5T")})
df["count"] = df["timestamp"].map(lambda a: random.randint(1, 1000))
Expand All @@ -93,7 +93,7 @@ def fetch_blog_posts_from_external_api(*args, **kwargs):

@asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00"))
def blog_posts(context: AssetExecutionContext) -> List[Dict]:
partition_datetime_str = context.asset_partition_key_for_output()
partition_datetime_str = context.partition_context.partition_key
hour = datetime.datetime.fromisoformat(partition_datetime_str)
posts = fetch_blog_posts_from_external_api(hour_when_posted=hour)
return posts
Expand All @@ -106,7 +106,7 @@ def blog_posts(context: AssetExecutionContext) -> List[Dict]:
key_prefix=["snowflake", "eldermark_proxy"],
)
def resident(context: AssetExecutionContext) -> Output[pd.DataFrame]:
start, end = context.asset_partitions_time_window_for_output()
start, end = context.partition_context.partition_time_window
filter_str = f"LastMod_Stamp >= {start.timestamp()} AND LastMod_Stamp < {end.timestamp()}"

records = context.resources.eldermark.fetch_obj(obj="Resident", filter=filter_str)
Expand Down
280 changes: 189 additions & 91 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,6 +1374,16 @@ def _copy_docs_from_op_execution_context(obj):
"dagster_run": "run",
"run_config": "run.run_config",
"run_tags": "run.tags",
"asset_partition_key_for_output": "partition_context.partition_key",
"asset_partitions_time_window_for_output": "partition_context.partition_time_window",
"asset_partition_key_range_for_output": "partition_context.partition_key_range",
"asset_partition_key_range_for_input": "partition_context.upstream_partition_key_range",
"asset_partition_key_for_input": "partition_context.upstream_partition_key",
"asset_partitions_def_for_output": "assets_def.partitions_def",
"asset_partitions_def_for_input": "partition_context.upstream_partitions_def",
"asset_partition_keys_for_output": "partition_context.partition_keys",
"asset_partition_keys_for_input": "partition_context.upstream_partition_keys",
"asset_partitions_time_window_for_input": "partition_context.upstream_partitions_time_window",
}

ALTERNATE_EXPRESSIONS = {
Expand All @@ -1400,11 +1410,82 @@ def _get_deprecation_kwargs(attr: str):
return deprecation_kwargs


class PartitionContext: # TODO - better name?
def __init__(self, op_execution_context: OpExecutionContext):
self.op_execution_context = op_execution_context

@public
@property
def is_partitioned_materialization(self) -> bool:
return self.op_execution_context.has_partition_key

@public
@property
def partition_key(self) -> str:
return self.op_execution_context.partition_key

@public
@property
def partition_keys(self) -> Sequence[str]:
return self.op_execution_context.partition_keys

@public
@property
def partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.partition_key_range

@public
@property
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
def upstream_partition_key(self, key: CoercibleToAssetKey) -> str:
return self.op_execution_context._step_execution_context.asset_partition_key_for_upstream( # noqa: SLF001
AssetKey.from_coercible(key)
)

@public
def upstream_partition_keys(self, key: CoercibleToAssetKey) -> Sequence[str]:
return list(
self.op_execution_context._step_execution_context.asset_partitions_subset_for_upstream( # noqa: SLF001
AssetKey.from_coercible(key)
).get_partition_keys()
)

@public
def upstream_partition_key_range(self, key: CoercibleToAssetKey) -> PartitionKeyRange:
return self.op_execution_context._step_execution_context.asset_partition_key_range_for_upstream( # noqa: SLF001
AssetKey.from_coercible(key)
)

@public
def upstream_partitions_time_window(
self, key: CoercibleToAssetKey
) -> TimeWindow: # TODO align on plurality of partition(s)
return self.op_execution_context._step_execution_context.asset_partitions_time_window_for_upstream( # noqa: SLF001
AssetKey.from_coercible(key)
)

@public
def upstream_partitions_def(self, key: CoercibleToAssetKey) -> PartitionsDefinition:
result = self.op_execution_context._step_execution_context.job_def.asset_layer.partitions_def_for_asset( # noqa: SLF001
AssetKey.from_coercible(key)
)
if result is None:
raise DagsterInvariantViolationError(
f"Attempting to access partitions def for asset {key}, but it is not" " partitioned"
)

return result


class AssetExecutionContext(OpExecutionContext):
def __init__(self, op_execution_context: OpExecutionContext) -> None:
self._op_execution_context = check.inst_param(
op_execution_context, "op_execution_context", OpExecutionContext
)
self._partition_context = PartitionContext(op_execution_context=self._op_execution_context)

@staticmethod
def get() -> "AssetExecutionContext":
Expand All @@ -1417,6 +1498,10 @@ def get() -> "AssetExecutionContext":
def op_execution_context(self) -> OpExecutionContext:
return self._op_execution_context

@property
def partition_context(self) -> PartitionContext:
return self._partition_context

####### Top-level properties/methods on AssetExecutionContext

@public
Expand Down Expand Up @@ -1529,6 +1614,110 @@ def has_tag(self, key: str) -> bool:
def get_tag(self, key: str) -> Optional[str]:
return self.op_execution_context.get_tag(key)

@deprecated(**_get_deprecation_kwargs("has_partition_key"))
@public
@property
@_copy_docs_from_op_execution_context
def has_partition_key(self) -> bool:
return self.op_execution_context.has_partition_key

@deprecated(**_get_deprecation_kwargs("partition_key"))
@public
@property
@_copy_docs_from_op_execution_context
def partition_key(self) -> str:
return self.op_execution_context.partition_key

@deprecated(**_get_deprecation_kwargs("partition_keys"))
@public
@property
@_copy_docs_from_op_execution_context
def partition_keys(self) -> Sequence[str]:
return self.op_execution_context.partition_keys

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@public
@property
@_copy_docs_from_op_execution_context
def asset_partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range

@deprecated(**_get_deprecation_kwargs("partition_key_range"))
@public
@property
@_copy_docs_from_op_execution_context
def partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.partition_key_range

@deprecated(**_get_deprecation_kwargs("partition_time_window"))
@public
@property
@_copy_docs_from_op_execution_context
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_input(input_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_input(input_name)

########## pass-through to op context

#### op related
Expand Down Expand Up @@ -1646,97 +1835,6 @@ def step_launcher(self) -> Optional[StepLauncher]:
def get_step_execution_context(self) -> StepExecutionContext:
return self.op_execution_context.get_step_execution_context()

#### partition_related

@public
@property
@_copy_docs_from_op_execution_context
def has_partition_key(self) -> bool:
return self.op_execution_context.has_partition_key

@public
@property
@_copy_docs_from_op_execution_context
def partition_key(self) -> str:
return self.op_execution_context.partition_key

@public
@property
@_copy_docs_from_op_execution_context
def partition_keys(self) -> Sequence[str]:
return self.op_execution_context.partition_keys

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@public
@property
@_copy_docs_from_op_execution_context
def asset_partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range

@public
@property
@_copy_docs_from_op_execution_context
def partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.partition_key_range

@public
@property
@_copy_docs_from_op_execution_context
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_input(input_name)

#### Event log related

@_copy_docs_from_op_execution_context
Expand Down
Loading

0 comments on commit e0845f3

Please sign in to comment.