Skip to content

Commit

Permalink
[RFC] don't enforce output values to be passed for multi_asset
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Aug 8, 2023
1 parent f813d72 commit 360c246
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from dagster import In, OpDefinition, Out, Output, job
import pytest

from dagster import DagsterInvariantViolationError, In, Nothing, OpDefinition, Out, Output, job, op


def test_op_def_direct():
Expand All @@ -16,3 +18,68 @@ def the_job(x):

result = the_job.execute_in_process(input_values={"x": 5})
assert result.success


def test_multi_out_implicit_none():
    """Check how an implicit ``None`` return is handled by multi-output ops.

    Three output configurations are covered, each exercised both by invoking
    the op directly and by running it inside a job via ``execute_in_process``:

    * all outputs are required ``Nothing`` -> execution succeeds,
    * one ``Nothing`` output is optional -> invariant violation,
    * outputs are untyped -> invariant violation.
    """
    # Error fragment expected whenever the implicit None is rejected.
    single_output_error = "has multiple outputs, but only one output was returned"

    #
    # non-optional Nothing
    #
    @op(out={"a": Out(Nothing), "b": Out(Nothing)})
    def implicit():
        pass

    implicit()

    @job
    def implicit_job():
        implicit()

    assert implicit_job.execute_in_process().success

    #
    # optional (fails)
    #
    @op(out={"a": Out(Nothing), "b": Out(Nothing, is_required=False)})
    def optional():
        pass

    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        optional()

    @job
    def optional_job():
        optional()

    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        optional_job.execute_in_process()

    #
    # untyped (fails)
    #
    @op(out={"a": Out(), "b": Out()})
    def untyped():
        pass

    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        untyped()

    @job
    def untyped_job():
        untyped()

    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        untyped_job.execute_in_process()
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ def get_mapping_key(self) -> Optional[str]:
# asset related methods
#############################################################################################

@public
@property
def has_assets_def(self) -> bool:
    """Whether the asset layer has an AssetsDefinition for the executing node.

    Looks up ``self.node_handle`` in the job's asset layer and reports whether
    that lookup produced a definition (i.e. did not return ``None``).
    """
    return self.job_def.asset_layer.assets_def_for_node(self.node_handle) is not None

@public
@property
def assets_def(self) -> AssetsDefinition:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,10 @@ def op(self) -> Node:
def op_def(self) -> OpDefinition:
return self._op_def

@property
def has_assets_def(self) -> bool:
    """Whether an assets definition is stored on this object (``_assets_def`` is set)."""
    backing_def = self._assets_def
    return backing_def is not None

@property
def assets_def(self) -> AssetsDefinition:
if self._assets_def is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _zip_and_iterate_op_result(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Tuple[int, Any, OutputDefinition]]:
if len(output_defs) > 1:
_validate_multi_return(context, result, output_defs)
result = _validate_multi_return(context, result, output_defs)
for position, (output_def, element) in enumerate(zip(output_defs, result)):
yield position, output_def, element
else:
Expand All @@ -143,7 +143,16 @@ def _validate_multi_return(
context: OpExecutionContext,
result: Any,
output_defs: Sequence[OutputDefinition],
) -> None:
) -> Any:
# special cases for implicit/explicit returned None
if result is None:
# extrapolate None -> (None, None, ...) when appropriate
if all(
output_def.dagster_type.is_nothing and output_def.is_required
for output_def in output_defs
):
return [None for _ in output_defs]

# When returning from an op with multiple outputs, the returned object must be a tuple of the same length as the number of outputs. At the time of the op's construction, we verify that a provided annotation is a tuple with the same length as the number of outputs, so if the result matches the number of output defs on the op, it will transitively also match the annotation.
if not isinstance(result, tuple):
raise DagsterInvariantViolationError(
Expand All @@ -162,6 +171,7 @@ def _validate_multi_return(
f"{len(output_tuple)} outputs, while "
f"{context.op_def.node_type_str} has {len(output_defs)} outputs."
)
return result


def _get_annotation_for_output_position(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
DagsterInvalidPropertyError,
DagsterInvariantViolationError,
)
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.instance import DagsterInstance
from dagster._core.storage.mem_io_manager import InMemoryIOManager
from dagster._core.test_utils import instance_for_test
from dagster._core.types.dagster_type import Nothing


def test_with_replaced_asset_keys():
Expand Down Expand Up @@ -1539,3 +1541,55 @@ def test_asset_key_with_prefix():

with pytest.raises(CheckError):
AssetKey("foo").with_prefix(1)


def _exec_asset(asset_def):
    """Execute ``asset_def`` in a one-off in-process job and return its materializations.

    A job named "testing" is defined over the single asset, resolved against
    an asset graph built from that same asset, and executed in process. The
    run must succeed; the materializations recorded for the asset's node are
    returned.
    """
    unresolved_job = define_asset_job("testing", [asset_def])
    graph = AssetGraph.from_assets([asset_def])
    resolved_job = unresolved_job.resolve(asset_graph=graph)

    run = resolved_job.execute_in_process()
    assert run.success

    node_name = asset_def.node_def.name
    return run.asset_materializations_for_node(node_name)


def test_multi_asset_return_none():
    """Check implicit ``None`` returns from multi-assets with ``Nothing`` outputs.

    With all outputs required, returning nothing is accepted both through a
    job and on direct invocation. With optional outputs on a subsettable
    multi-asset, the same implicit return raises an invariant violation.
    """
    # Error fragment expected when the implicit None is rejected.
    single_output_error = "has multiple outputs, but only one output was returned"

    @multi_asset(
        outs={
            "asset1": AssetOut(dagster_type=Nothing),
            "asset2": AssetOut(dagster_type=Nothing),
        },
    )
    def my_function():
        # ...materialize assets without IO manager
        pass

    # via job
    _exec_asset(my_function)

    # direct invoke
    my_function()

    @multi_asset(
        outs={
            "asset1": AssetOut(dagster_type=Nothing, is_required=False),
            "asset2": AssetOut(dagster_type=Nothing, is_required=False),
        },
        can_subset=True,
    )
    def subset(context: AssetExecutionContext):
        # ...use context.selected_asset_keys materialize subset of assets without IO manager
        pass

    # via job
    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        _exec_asset(subset)

    # direct invoke
    with pytest.raises(DagsterInvariantViolationError, match=single_output_error):
        subset(build_asset_context())
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dagster import In, OpDefinition, Out, Output, job


def test_op_def_direct():
    """Build an ``OpDefinition`` directly (no decorator) and run it inside a job.

    The compute function receives its inputs as a mapping, checks the value
    supplied for ``x``, and yields ``x + 1`` on the ``the_output`` output.
    """

    def the_op_fn(_, inputs):
        x_value = inputs["x"]
        assert x_value == 5
        yield Output(x_value + 1, output_name="the_output")

    op_ins = {"x": In(dagster_type=int)}
    op_outs = {"the_output": Out(int)}
    op_def = OpDefinition(the_op_fn, "the_op", ins=op_ins, outs=op_outs)

    @job
    def the_job(x):
        op_def(x)

    run = the_job.execute_in_process(input_values={"x": 5})
    assert run.success

0 comments on commit 360c246

Please sign in to comment.