Skip to content

Commit

Permalink
[MaterializeResult] more direct invocation tests & fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Oct 3, 2023
1 parent 56994a4 commit b0cd32b
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 45 deletions.
21 changes: 16 additions & 5 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,19 @@ def _key_for_result(result: MaterializeResult, context: "BoundOpExecutionContext
return next(iter(context.assets_def.keys))

raise DagsterInvariantViolationError(
"Unable to resolve unset asset_key for MaterializeResult, set explicitly when constructing."
"MaterializeResult did not include asset_key and it can not be inferred. Specify which"
f" asset_key, options are: {context.assets_def.keys}"
)


def _output_name_for_result_obj(
event: MaterializeResult,
context: "BoundOpExecutionContext",
):
asset_key = _key_for_result(event, context)
return context.assets_def.get_output_name_for_asset_key(asset_key)


def _handle_gen_event(
event: T,
op_def: "OpDefinition",
Expand All @@ -351,8 +360,7 @@ def _handle_gen_event(
):
return event
elif isinstance(event, MaterializeResult):
asset_key = _key_for_result(event, context)
output_name = context.assets_def.get_output_name_for_asset_key(asset_key)
output_name = _output_name_for_result_obj(event, context)
outputs_seen.add(output_name)
return event
else:
Expand Down Expand Up @@ -439,8 +447,8 @@ def type_check_gen(gen):
yield Output(output_name=output_def.name, value=None)
else:
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' did not"
f" return an output for non-optional output '{output_def.name}'"
f'Invocation of {op_def.node_type_str} "{context.alias}" did not'
f' return an output for non-optional output "{output_def.name}"'
)

return type_check_gen(result)
Expand All @@ -458,6 +466,9 @@ def _type_check_function_output(
for event in validate_and_coerce_op_result_to_iterator(result, context, op_def.output_defs):
if isinstance(event, (Output, DynamicOutput)):
_type_check_output(output_defs_by_name[event.output_name], event, context)
elif isinstance(event, (MaterializeResult)):
# ensure result objects are contextually valid
_output_name_for_result_obj(event, context)

return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,7 @@ def asset_key(self) -> AssetKey:
"Cannot call `context.asset_key` in a multi_asset with more than one asset. Use"
" `context.asset_key_for_output` instead."
)
# pass in the output name to handle the case when a multi_asset has a single AssetOut
return self.asset_key_for_output(
output_name=next(iter(self.assets_def.keys_by_output_name.keys()))
)
return next(iter(self.assets_def.keys_by_output_name.values()))

@public
@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,8 +597,8 @@ def op_multiple_outputs_not_sent():
with pytest.raises(
DagsterInvariantViolationError,
match=(
"Invocation of op 'op_multiple_outputs_not_sent' did not return an output "
"for non-optional output '1'"
'Invocation of op "op_multiple_outputs_not_sent" did not return an output '
'for non-optional output "1"'
),
):
list(op_multiple_outputs_not_sent())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Generator, Tuple

import pytest
Expand All @@ -19,6 +20,7 @@
)
from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.errors import DagsterInvariantViolationError, DagsterStepOutputNotFoundError
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus


Expand Down Expand Up @@ -48,17 +50,32 @@ def ret_mismatch(context: AssetExecutionContext):
metadata={"one": 1},
)

# core execution
with pytest.raises(
DagsterInvariantViolationError,
match="Asset key random not found in AssetsDefinition",
):
materialize([ret_mismatch])

# direct invocation
with pytest.raises(
DagsterInvariantViolationError,
match="Asset key random not found in AssetsDefinition",
):
ret_mismatch(build_asset_context())

# tuple
@asset
def ret_two():
return MaterializeResult(metadata={"one": 1}), MaterializeResult(metadata={"two": 2})

materialize([ret_two])
# core execution
result = materialize([ret_two])
assert result.success

# direct invocation
direct_results = ret_two()
assert len(direct_results) == 2


def test_return_materialization_with_asset_checks():
Expand All @@ -72,6 +89,7 @@ def ret_checks(context: AssetExecutionContext):
]
)

# core execution
materialize([ret_checks], instance=instance)
asset_check_executions = instance.event_log_storage.get_asset_check_execution_history(
AssetCheckKey(asset_key=ret_checks.key, name="foo_check"),
Expand All @@ -80,6 +98,11 @@ def ret_checks(context: AssetExecutionContext):
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED

# direct invocation
context = build_asset_context()
direct_results = ret_checks(context)
assert direct_results


def test_multi_asset():
@multi_asset(outs={"one": AssetOut(), "two": AssetOut()})
Expand All @@ -88,15 +111,23 @@ def outs_multi_asset():
asset_key="two", metadata={"baz": "qux"}
)

materialize([outs_multi_asset])
assert materialize([outs_multi_asset]).success

res = outs_multi_asset()
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"

@multi_asset(specs=[AssetSpec(["prefix", "one"]), AssetSpec(["prefix", "two"])])
def specs_multi_asset():
return MaterializeResult(
asset_key=["prefix", "one"], metadata={"foo": "bar"}
), MaterializeResult(asset_key=["prefix", "two"], metadata={"baz": "qux"})

materialize([specs_multi_asset])
assert materialize([specs_multi_asset]).success

res = specs_multi_asset()
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"


def test_return_materialization_multi_asset():
Expand All @@ -122,6 +153,9 @@ def multi():
assert "two" in mats[1].metadata
assert mats[1].tags

direct_results = list(multi())
assert len(direct_results) == 2

#
# missing a non optional out
#
Expand All @@ -141,6 +175,12 @@ def missing():
):
_exec_asset(missing)

with pytest.raises(
DagsterInvariantViolationError,
match='Invocation of op "missing" did not return an output for non-optional output "two"',
):
list(missing())

#
# missing asset_key
#
Expand All @@ -162,6 +202,15 @@ def no_key():
):
_exec_asset(no_key)

with pytest.raises(
DagsterInvariantViolationError,
match=(
"MaterializeResult did not include asset_key and it can not be inferred. Specify which"
" asset_key, options are:"
),
):
list(no_key())

#
# return tuple success
#
Expand All @@ -186,6 +235,9 @@ def ret_multi():
assert "two" in mats[1].metadata
assert mats[1].tags

res = ret_multi()
assert len(res) == 2

#
# return list error
#
Expand All @@ -212,6 +264,15 @@ def ret_list():
):
_exec_asset(ret_list)

with pytest.raises(
DagsterInvariantViolationError,
match=(
"When using multiple outputs, either yield each output, or return a tuple containing a"
" value for each output."
),
):
ret_list()


def test_materialize_result_output_typing():
# Test that the return annotation MaterializeResult is interpreted as a Nothing type, since we
Expand Down Expand Up @@ -279,49 +340,31 @@ def generator_asset() -> Generator[MaterializeResult, None, None]:
materialize([generator_asset], resources={"io_manager": TestingIOManager()})


def test_direct_invocation_materialize_result():
@asset
def my_asset() -> MaterializeResult:
return MaterializeResult(metadata={"foo": "bar"})

res = my_asset()
assert res.metadata["foo"] == "bar"

@multi_asset(specs=[AssetSpec("one"), AssetSpec("two")])
def specs_multi_asset():
return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult(
asset_key="two", metadata={"baz": "qux"}
)

res = specs_multi_asset()
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"

@multi_asset(outs={"one": AssetOut(), "two": AssetOut()})
def outs_multi_asset():
return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult(
asset_key="two", metadata={"baz": "qux"}
)

res = outs_multi_asset()
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"


def test_direct_invocation_for_generators():
def test_materialize_result_generators():
@asset
def generator_asset() -> Generator[MaterializeResult, None, None]:
yield MaterializeResult(metadata={"foo": "bar"})

res = _exec_asset(generator_asset)
assert len(res) == 1
assert res[0].metadata["foo"].value == "bar"

res = list(generator_asset())
assert len(res) == 1
assert res[0].metadata["foo"] == "bar"

@multi_asset(specs=[AssetSpec("one"), AssetSpec("two")])
def generator_specs_multi_asset():
yield MaterializeResult(asset_key="one", metadata={"foo": "bar"})
yield MaterializeResult(asset_key="two", metadata={"baz": "qux"})

res = _exec_asset(generator_specs_multi_asset)
assert len(res) == 2
assert res[0].metadata["foo"].value == "bar"
assert res[1].metadata["baz"].value == "qux"

res = list(generator_specs_multi_asset())
assert len(res) == 2
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"

Expand All @@ -330,11 +373,52 @@ def generator_outs_multi_asset():
yield MaterializeResult(asset_key="one", metadata={"foo": "bar"})
yield MaterializeResult(asset_key="two", metadata={"baz": "qux"})

res = _exec_asset(generator_outs_multi_asset)
assert len(res) == 2
assert res[0].metadata["foo"].value == "bar"
assert res[1].metadata["baz"].value == "qux"

res = list(generator_outs_multi_asset())
assert len(res) == 2
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"

@multi_asset(specs=[AssetSpec("one"), AssetSpec("two")])
async def async_specs_multi_asset():
return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult(
asset_key="two", metadata={"baz": "qux"}
)

res = _exec_asset(async_specs_multi_asset)
assert len(res) == 2
assert res[0].metadata["foo"].value == "bar"
assert res[1].metadata["baz"].value == "qux"

res = asyncio.run(async_specs_multi_asset())
assert len(res) == 2
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"

# need to test async generator case and coroutine see op_invocation.py:_type_check_output_wrapper for all cases
@multi_asset(specs=[AssetSpec("one"), AssetSpec("two")])
async def async_gen_specs_multi_asset():
yield MaterializeResult(asset_key="one", metadata={"foo": "bar"})
yield MaterializeResult(asset_key="two", metadata={"baz": "qux"})

res = _exec_asset(async_gen_specs_multi_asset)
assert len(res) == 2
assert res[0].metadata["foo"].value == "bar"
assert res[1].metadata["baz"].value == "qux"

async def _run_async_gen():
results = []
async for result in async_gen_specs_multi_asset():
results.append(result)
return results

res = asyncio.run(_run_async_gen())
assert len(res) == 2
assert res[0].metadata["foo"] == "bar"
assert res[1].metadata["baz"] == "qux"


def test_materialize_result_with_partitions():
Expand Down

0 comments on commit b0cd32b

Please sign in to comment.