Skip to content

Commit

Permalink
[prototype] return AssetMaterialization
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Jun 26, 2023
1 parent 37be28c commit 51870da
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 8 deletions.
19 changes: 17 additions & 2 deletions python_modules/dagster/dagster/_core/definitions/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dagster._seven as seven
from dagster._annotations import PublicAttr, public
from dagster._core.definitions.data_version import DataVersion
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX, SYSTEM_TAG_PREFIX
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.serdes import NamedTupleSerializer
Expand Down Expand Up @@ -476,7 +477,7 @@ class AssetMaterialization(
Args:
asset_key (Union[str, List[str], AssetKey]): A key to identify the materialized asset across
job runs
job runs. Optional in cases when the key can be inferred from the current context.
description (Optional[str]): A longer human-readable description of the materialized value.
partition (Optional[str]): The name of the partition
that was materialized.
Expand All @@ -490,18 +491,32 @@ class AssetMaterialization(

def __new__(
cls,
asset_key: CoercibleToAssetKey,
asset_key: Optional[CoercibleToAssetKey] = None,
description: Optional[str] = None,
metadata: Optional[Mapping[str, RawMetadataValue]] = None,
partition: Optional[str] = None,
tags: Optional[Mapping[str, str]] = None,
):
from dagster._core.definitions.multi_dimensional_partitions import MultiPartitionKey
from dagster._core.execution.context.compute import get_execution_context

if isinstance(asset_key, AssetKey):
check.inst_param(asset_key, "asset_key", AssetKey)
elif isinstance(asset_key, str):
asset_key = AssetKey(parse_asset_key_string(asset_key))
elif asset_key is None:
current_ctx = get_execution_context()
if current_ctx is None:
raise DagsterInvariantViolationError(
"Could not infer asset_key, not currently in the context of an execution."
)
keys = current_ctx.selected_asset_keys
if len(keys) != 1:
raise DagsterInvariantViolationError(
f"Could not infer asset_key, there are {len(keys)} in the current execution"
" context. Specify the appropriate asset_key."
)
asset_key = next(iter(keys))
else:
check.sequence_param(asset_key, "asset_key", of_type=str)
asset_key = AssetKey(asset_key)
Expand Down
18 changes: 18 additions & 0 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def __init__(

self._output_metadata: Dict[str, Any] = {}
self._seen_outputs: Dict[str, Union[str, Set[str]]] = {}
self._seen_user_asset_mats = {}

self._input_asset_records: Dict[AssetKey, Optional["EventLogRecord"]] = {}
self._is_external_input_asset_records_loaded = False
Expand Down Expand Up @@ -846,6 +847,23 @@ def is_sda_step(self) -> bool:
return True
return False

def asset_key_for_output(self, output_name: str) -> AssetKey:
# note: duped on AssetExecutionContext
asset_output_info = self.job_def.asset_layer.asset_info_for_output(
node_handle=self.node_handle, output_name=output_name
)
if asset_output_info is None:
check.failed(f"Output '{output_name}' has no asset")
else:
return asset_output_info.key

def observe_user_asset_mat(self, asset_key, event):
# will need to store N events for partitions
self._seen_user_asset_mats[asset_key] = event

def get_observed_user_asset_mat(self, asset_key):
return self._seen_user_asset_mats.get(asset_key)

def set_data_version(self, asset_key: AssetKey, data_version: "DataVersion") -> None:
self._data_version_cache[asset_key] = data_version

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from typing_extensions import get_args

import dagster._check as check
from dagster._config.pythonic_config import Config
from dagster._core.definitions import (
AssetMaterialization,
Expand Down Expand Up @@ -195,17 +196,22 @@ def validate_and_coerce_op_result_to_iterator(
# this happens when a user explicitly returns a generator in the op
for event in result:
yield event
elif isinstance(result, (AssetMaterialization, ExpectationResult)):

# [A] yield it here...
elif isinstance(result, AssetMaterialization):
yield result

elif isinstance(result, (ExpectationResult)):
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: If you are "
"returning an AssetMaterialization "
"or an ExpectationResult from "
f"{context.op_def.node_type_str} you must yield them "
"returning an ExpectationResult from "
f"{context.op_def.node_type_str} you must yield it "
"directly, or log them using the OpExecutionContext.log_event method to avoid "
"ambiguity with an implied result from returning a "
"value. Check out the docs on logging events here: "
"https://docs.dagster.io/concepts/ops-jobs-graphs/op-events#op-events-and-exceptions"
)

elif result is not None and not output_defs:
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: Unexpectedly returned output of type"
Expand Down Expand Up @@ -285,3 +291,5 @@ def validate_and_coerce_op_result_to_iterator(
"https://docs.dagster.io/concepts/ops-jobs-graphs/graphs#with-conditional-branching"
)
yield Output(output_name=output_def.name, value=element)
else:
check.failed("do we ever hit this unhandled case?")
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,16 @@ def _step_output_error_checked_user_event_sequence(
output_names = list([output_def.name for output_def in step.step_outputs])

for user_event in user_event_sequence:
if not isinstance(user_event, (Output, DynamicOutput)):
if not isinstance(user_event, (Output, DynamicOutput, AssetMaterialization)):
yield user_event
continue

# [A] ... to swallow here
if isinstance(user_event, AssetMaterialization):
step_context.observe_user_asset_mat(user_event.asset_key, user_event)
# defer yielding til post resolve to apply tags
continue

# do additional processing on Outputs
output = user_event
if not step.has_step_output(cast(str, output.output_name)):
Expand Down Expand Up @@ -165,6 +171,15 @@ def _step_output_error_checked_user_event_sequence(
f'Emitting implicit Nothing for output "{step_output_def.name}" on {op_label}'
)
yield Output(output_name=step_output_def.name, value=None)

if step_context.is_sda_step and step_context.get_observed_user_asset_mat(
step_context.asset_key_for_output(step_output_def.name)
):
# think its fine to omit log
# step_context.log.info(
# f"Emitting implicit Nothing for materialized asset {op_label}"
# )
yield Output(output_name=step_output_def.name, value=None)
elif not step_output_def.is_dynamic:
raise DagsterStepOutputNotFoundError(
(
Expand Down Expand Up @@ -376,6 +391,7 @@ def core_dagster_event_sequence_for_step(
yield evt
# for now, I'm ignoring AssetMaterializations yielded manually, but we might want
# to do something with these in the above path eventually
# ^ wat?
elif isinstance(user_event, AssetMaterialization):
yield DagsterEvent.asset_materialization(step_context, user_event)
elif isinstance(user_event, AssetObservation):
Expand Down Expand Up @@ -490,7 +506,11 @@ def _get_output_asset_materializations(
if backfill_id:
tags[BACKFILL_ID_TAG] = backfill_id

if asset_partitions:
user_event = step_context.get_observed_user_asset_mat(asset_key)
if asset_partitions and user_event:
# this will be a bit involved
check.failed("unhandled")
elif asset_partitions:
for partition in asset_partitions:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
Expand All @@ -507,6 +527,20 @@ def _get_output_asset_materializations(
metadata=all_metadata,
tags=tags,
)
elif user_event:
if tags:
yield AssetMaterialization(
**{ # dirty mergey
**user_event._asdict(),
"tags": {
**(user_event.tags if user_event.tags else {}),
**tags,
},
},
)
else:
yield user_event

else:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
from dagster._check import CheckError
from dagster._core.definitions import AssetIn, SourceAsset, asset, multi_asset
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.events import AssetMaterialization
from dagster._core.errors import (
DagsterInvalidDefinitionError,
DagsterInvalidInvocationError,
DagsterInvalidPropertyError,
)
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.instance import DagsterInstance
from dagster._core.storage.mem_io_manager import InMemoryIOManager
from dagster._core.test_utils import instance_for_test
Expand Down Expand Up @@ -1476,3 +1478,68 @@ def blah(context):
defs = Definitions(assets=[blah])
defs.get_implicit_global_asset_job_def().execute_in_process()
assert executed["yes"]


def test_return_materialization():
#
# status quo - use add add_output_metadata
#
@asset
def add(context: AssetExecutionContext):
context.add_output_metadata(
metadata={"one": 1},
)

asset_job = define_asset_job("bar", [add]).resolve([add], [])

result = asset_job.execute_in_process()
assert result.success

mats = result.asset_materializations_for_node(add.node_def.name)
assert len(mats) == 1
# working with core metadata repr values sucks, ie IntMetadataValue
assert "one" in mats[0].metadata
assert mats[0].tags

#
# side quest: may want to update this pattern to work as well
#
@asset
def logged(context: AssetExecutionContext):
context.log_event(
AssetMaterialization(
metadata={"one": 1},
)
)

asset_job = define_asset_job("bar", [logged]).resolve([logged], [])

result = asset_job.execute_in_process()
assert result.success

mats = result.asset_materializations_for_node(logged.node_def.name)
# should we change this? currently get implicit materialization for output + logged event
assert len(mats) == 2
assert "one" in mats[0].metadata
# assert mats[0].tags # fails
# assert "one" in mats[1].metadata # fails
assert mats[1].tags

#
# main exploration
#
@asset
def ret_untyped(context: AssetExecutionContext):
return AssetMaterialization(
metadata={"one": 1},
)

asset_job = define_asset_job("bar", [ret_untyped]).resolve([ret_untyped], [])

result = asset_job.execute_in_process()
assert result.success

mats = result.asset_materializations_for_node(ret_untyped.node_def.name)
assert len(mats) == 1, mats
assert "one" in mats[0].metadata
assert mats[0].tags

0 comments on commit 51870da

Please sign in to comment.