From 58f609730e487d7d9bd25a336a2a4a59413f52aa Mon Sep 17 00:00:00 2001 From: alangenfeld Date: Mon, 2 Oct 2023 13:32:19 -0500 Subject: [PATCH] [MaterializeResult] more direct invocation tests & fixes --- .../_core/definitions/op_invocation.py | 21 ++- .../_core/execution/context/compute.py | 5 +- .../core_tests/test_op_invocation.py | 4 +- .../test_materialize_result.py | 152 ++++++++++++++---- 4 files changed, 137 insertions(+), 45 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index 9282f21558b21..984dd8003626b 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -334,10 +334,19 @@ def _key_for_result(result: MaterializeResult, context: "BoundOpExecutionContext return next(iter(context.assets_def.keys)) raise DagsterInvariantViolationError( - "Unable to resolve unset asset_key for MaterializeResult, set explicitly when constructing." + "MaterializeResult did not include asset_key and it can not be inferred. Specify which" + f" asset_key, options are: {context.assets_def.keys}" ) +def _output_name_for_result_obj( + event: MaterializeResult, + context: "BoundOpExecutionContext", +): + asset_key = _key_for_result(event, context) + return context.assets_def.get_output_name_for_asset_key(asset_key) + + def _handle_gen_event( event: T, op_def: "OpDefinition", @@ -351,8 +360,7 @@ def _handle_gen_event( ): return event elif isinstance(event, MaterializeResult): - asset_key = _key_for_result(event, context) - output_name = context.assets_def.get_output_name_for_asset_key(asset_key) + output_name = _output_name_for_result_obj(event, context) outputs_seen.add(output_name) return event else: @@ -439,8 +447,8 @@ def type_check_gen(gen): yield Output(output_name=output_def.name, value=None) else: raise DagsterInvariantViolationError( - f"Invocation of {op_def.node_type_str} '{context.alias}' did not" - f" return an output for non-optional output '{output_def.name}'" + f'Invocation of {op_def.node_type_str} "{context.alias}" did not' + f' return an output for non-optional output "{output_def.name}"' ) return type_check_gen(result) @@ -458,6 +466,9 @@ def _type_check_function_output( for event in validate_and_coerce_op_result_to_iterator(result, context, op_def.output_defs): if isinstance(event, (Output, DynamicOutput)): _type_check_output(output_defs_by_name[event.output_name], event, context) + elif isinstance(event, (MaterializeResult)): + # ensure result objects are contextually valid + _output_name_for_result_obj(event, context) return result diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index 1dcb2834e521e..6d23ea700e6ad 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -531,10 +531,7 @@ def asset_key(self) -> AssetKey: "Cannot call `context.asset_key` in a multi_asset with more than one asset. Use" " `context.asset_key_for_output` instead." ) - # pass in the output name to handle the case when a multi_asset has a single AssetOut - return self.asset_key_for_output( - output_name=next(iter(self.assets_def.keys_by_output_name.keys())) - ) + return next(iter(self.assets_def.keys_by_output_name.values())) @public @property diff --git a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py index 8a52ae5417672..6c6353ef27f26 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py @@ -597,8 +597,8 @@ def op_multiple_outputs_not_sent(): with pytest.raises( DagsterInvariantViolationError, match=( - "Invocation of op 'op_multiple_outputs_not_sent' did not return an output " - "for non-optional output '1'" + 'Invocation of op "op_multiple_outputs_not_sent" did not return an output ' + 'for non-optional output "1"' ), ): list(op_multiple_outputs_not_sent()) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_materialize_result.py b/python_modules/dagster/dagster_tests/definitions_tests/test_materialize_result.py index 784ad7116e7cf..670b0cdf6cd0d 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_materialize_result.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_materialize_result.py @@ -1,3 +1,4 @@ +import asyncio from typing import Generator, Tuple import pytest @@ -18,6 +19,7 @@ multi_asset, ) from dagster._core.errors import DagsterInvariantViolationError, DagsterStepOutputNotFoundError +from dagster._core.execution.context.invocation import build_asset_context from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus @@ -47,17 +49,32 @@ def ret_mismatch(context: AssetExecutionContext): metadata={"one": 1}, ) + # core execution with pytest.raises( DagsterInvariantViolationError, match="Asset key random not found in AssetsDefinition", ): materialize([ret_mismatch]) + # direct invocation + with pytest.raises( + DagsterInvariantViolationError, + match="Asset key random not found in AssetsDefinition", + ): + ret_mismatch(build_asset_context()) + + # tuple @asset def ret_two(): return MaterializeResult(metadata={"one": 1}), MaterializeResult(metadata={"two": 2}) - materialize([ret_two]) + # core execution + result = materialize([ret_two]) + assert result.success + + # direct invocation + direct_results = ret_two() + assert len(direct_results) == 2 def test_return_materialization_with_asset_checks(): @@ -71,6 +88,7 @@ def ret_checks(context: AssetExecutionContext): ] ) + # core execution materialize([ret_checks], instance=instance) asset_check_executions = instance.event_log_storage.get_asset_check_executions( asset_key=ret_checks.key, @@ -80,6 +98,11 @@ def ret_checks(context: AssetExecutionContext): assert len(asset_check_executions) == 1 assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED + # direct invocation + context = build_asset_context() + direct_results = ret_checks(context) + assert direct_results + def test_multi_asset(): @multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) @@ -88,7 +111,11 @@ def outs_multi_asset(): asset_key="two", metadata={"baz": "qux"} ) - materialize([outs_multi_asset]) + assert materialize([outs_multi_asset]).success + + res = outs_multi_asset() + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" @multi_asset(specs=[AssetSpec(["prefix", "one"]), AssetSpec(["prefix", "two"])]) def specs_multi_asset(): @@ -96,7 +123,11 @@ def specs_multi_asset(): asset_key=["prefix", "one"], metadata={"foo": "bar"} ), MaterializeResult(asset_key=["prefix", "two"], metadata={"baz": "qux"}) - materialize([specs_multi_asset]) + assert materialize([specs_multi_asset]).success + + res = specs_multi_asset() + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" def test_return_materialization_multi_asset(): @@ -122,6 +153,9 @@ def multi(): assert "two" in mats[1].metadata assert mats[1].tags + direct_results = list(multi()) + assert len(direct_results) == 2 + # # missing a non optional out # @@ -141,6 +175,12 @@ def missing(): ): _exec_asset(missing) + with pytest.raises( + DagsterInvariantViolationError, + match='Invocation of op "missing" did not return an output for non-optional output "two"', + ): + list(missing()) + # # missing asset_key # @@ -162,6 +202,15 @@ def no_key(): ): _exec_asset(no_key) + with pytest.raises( + DagsterInvariantViolationError, + match=( + "MaterializeResult did not include asset_key and it can not be inferred. Specify which" + " asset_key, options are:" + ), + ): + list(no_key()) + # # return tuple success # @@ -186,6 +235,9 @@ def ret_multi(): assert "two" in mats[1].metadata assert mats[1].tags + res = ret_multi() + assert len(res) == 2 + # # return list error # @@ -212,6 +264,15 @@ def ret_list(): ): _exec_asset(ret_list) + with pytest.raises( + DagsterInvariantViolationError, + match=( + "When using multiple outputs, either yield each output, or return a tuple containing a" + " value for each output." + ), + ): + ret_list() + def test_materialize_result_output_typing(): # Test that the return annotation MaterializeResult is interpreted as a Nothing type, since we @@ -279,41 +340,17 @@ def generator_asset() -> Generator[MaterializeResult, None, None]: materialize([generator_asset], resources={"io_manager": TestingIOManager()}) -def test_direct_invocation_materialize_result(): - @asset - def my_asset() -> MaterializeResult: - return MaterializeResult(metadata={"foo": "bar"}) - - res = my_asset() - assert res.metadata["foo"] == "bar" - - @multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) - def specs_multi_asset(): - return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult( - asset_key="two", metadata={"baz": "qux"} - ) - - res = specs_multi_asset() - assert res[0].metadata["foo"] == "bar" - assert res[1].metadata["baz"] == "qux" - - @multi_asset(outs={"one": AssetOut(), "two": AssetOut()}) - def outs_multi_asset(): - return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult( - asset_key="two", metadata={"baz": "qux"} - ) - - res = outs_multi_asset() - assert res[0].metadata["foo"] == "bar" - assert res[1].metadata["baz"] == "qux" - - -def test_direct_invocation_for_generators(): +def test_materialize_result_generators(): @asset def generator_asset() -> Generator[MaterializeResult, None, None]: yield MaterializeResult(metadata={"foo": "bar"}) + res = _exec_asset(generator_asset) + assert len(res) == 1 + assert res[0].metadata["foo"].value == "bar" + res = list(generator_asset()) + assert len(res) == 1 assert res[0].metadata["foo"] == "bar" @multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) @@ -321,7 +358,13 @@ def generator_specs_multi_asset(): yield MaterializeResult(asset_key="one", metadata={"foo": "bar"}) yield MaterializeResult(asset_key="two", metadata={"baz": "qux"}) + res = _exec_asset(generator_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + res = list(generator_specs_multi_asset()) + assert len(res) == 2 assert res[0].metadata["foo"] == "bar" assert res[1].metadata["baz"] == "qux" @@ -330,11 +373,52 @@ def generator_outs_multi_asset(): yield MaterializeResult(asset_key="one", metadata={"foo": "bar"}) yield MaterializeResult(asset_key="two", metadata={"baz": "qux"}) + res = _exec_asset(generator_outs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + res = list(generator_outs_multi_asset()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" + + @multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + async def async_specs_multi_asset(): + return MaterializeResult(asset_key="one", metadata={"foo": "bar"}), MaterializeResult( + asset_key="two", metadata={"baz": "qux"} + ) + + res = _exec_asset(async_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + res = asyncio.run(async_specs_multi_asset()) + assert len(res) == 2 assert res[0].metadata["foo"] == "bar" assert res[1].metadata["baz"] == "qux" - # need to test async generator case and coroutine see op_invocation.py:_type_check_output_wrapper for all cases + @multi_asset(specs=[AssetSpec("one"), AssetSpec("two")]) + async def async_gen_specs_multi_asset(): + yield MaterializeResult(asset_key="one", metadata={"foo": "bar"}) + yield MaterializeResult(asset_key="two", metadata={"baz": "qux"}) + + res = _exec_asset(async_gen_specs_multi_asset) + assert len(res) == 2 + assert res[0].metadata["foo"].value == "bar" + assert res[1].metadata["baz"].value == "qux" + + async def _run_async_gen(): + results = [] + async for result in async_gen_specs_multi_asset(): + results.append(result) + return results + + res = asyncio.run(_run_async_gen()) + assert len(res) == 2 + assert res[0].metadata["foo"] == "bar" + assert res[1].metadata["baz"] == "qux" def test_materialize_result_with_partitions():