thread concurrency key into execution step
prha committed Dec 3, 2024
1 parent 0fc225d commit 0c1183c
Showing 8 changed files with 76 additions and 26 deletions.
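In short: an op's concurrency key can now be passed straight through to the ExecutionStep built for it, instead of being read back out of the GLOBAL_CONCURRENCY_TAG op tag at execution time. The test changes below exercise both forms; here is a minimal sketch of the two ways to attach the key (hypothetical op and job names, assuming the `concurrency_key` argument on `@op` that the updated tests use):

from dagster import job, op
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG

# older style: the concurrency key rides along in the op's tags
@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
def tagged_op():
    pass

# newer style: the key is a first-class argument, which this commit threads
# through ExecutionPlan construction into the resulting ExecutionStep
@op(concurrency_key="foo")
def keyed_op():
    pass

@job
def example_job():
    keyed_op()
    tagged_op()

Both ops end up constrained by the same "foo" concurrency pool; the difference is only where the key is declared.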
9 changes: 4 additions & 5 deletions python_modules/dagster/dagster/_core/execution/plan/active.py
@@ -35,7 +35,7 @@
from dagster._core.execution.plan.state import KnownExecutionState
from dagster._core.execution.plan.step import ExecutionStep
from dagster._core.execution.retries import RetryMode, RetryState
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG, PRIORITY_TAG
from dagster._core.storage.tags import PRIORITY_TAG
from dagster._utils.interrupts import pop_captured_interrupt
from dagster._utils.tags import TagConcurrencyLimitsCounter

@@ -350,15 +350,14 @@ def get_steps_to_execute(
if run_scoped_concurrency_limits_counter:
run_scoped_concurrency_limits_counter.update_counters_with_launched_item(step)

step_concurrency_key = step.tags.get(GLOBAL_CONCURRENCY_TAG)
if step_concurrency_key and self._instance_concurrency_context:
if step.concurrency_key and self._instance_concurrency_context:
try:
step_priority = int(step.tags.get(PRIORITY_TAG, 0))
except ValueError:
step_priority = 0

if not self._instance_concurrency_context.claim(
step_concurrency_key, step.key, step_priority
step.concurrency_key, step.key, step_priority
):
continue

@@ -656,7 +655,7 @@ def concurrency_event_iterator(
):
step = self.get_step_by_key(step_key)
step_context = plan_context.for_step(step)
step_concurrency_key = cast(str, step.tags.get(GLOBAL_CONCURRENCY_TAG))
step_concurrency_key = cast(str, step.concurrency_key)
self._messaged_concurrency_slots[step_key] = time.time()
is_initial_message = last_messaged_timestamp is None
yield DagsterEvent.step_concurrency_blocked(
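The scheduling change above now reads the key off the step itself: when a step carries a concurrency key and an instance concurrency context is available, get_steps_to_execute tries to claim a slot for it (using the step's PRIORITY_TAG value as the priority) and simply skips the step for this pass if the claim fails. A rough standalone sketch of that pattern, using simplified stand-in classes rather than the real ActiveExecution / InstanceConcurrencyContext:

from dataclasses import dataclass, field
from typing import Dict, Mapping, Optional, Set

@dataclass
class FakeStep:
    # stand-in for ExecutionStep: only the fields the scheduling loop touches
    key: str
    concurrency_key: Optional[str] = None
    tags: Mapping[str, str] = field(default_factory=dict)

class FakeConcurrencyContext:
    # stand-in for InstanceConcurrencyContext: grants at most `limit` slots per key
    def __init__(self, limit: int = 1):
        self._limit = limit
        self._claimed: Dict[str, Set[str]] = {}

    def claim(self, concurrency_key: str, step_key: str, priority: int = 0) -> bool:
        held = self._claimed.setdefault(concurrency_key, set())
        if step_key in held or len(held) < self._limit:
            held.add(step_key)
            return True
        return False

def steps_to_launch(candidate_steps, concurrency_context):
    # mirrors the get_steps_to_execute change: read the key off the step itself,
    # claim a slot for each keyed step, and skip the step this pass if the claim fails
    launched = []
    for step in candidate_steps:
        if step.concurrency_key and concurrency_context:
            try:
                # PRIORITY_TAG ("dagster/priority") in the real code
                priority = int(step.tags.get("dagster/priority", 0))
            except ValueError:
                priority = 0
            if not concurrency_context.claim(step.concurrency_key, step.key, priority):
                continue
        launched.append(step)
    return launched

context = FakeConcurrencyContext(limit=1)
steps = [FakeStep("a", "foo"), FakeStep("b", "foo"), FakeStep("c")]
print([s.key for s in steps_to_launch(steps, context)])  # ['a', 'c'] -- 'b' waits for the "foo" slot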
3 changes: 3 additions & 0 deletions python_modules/dagster/dagster/_core/execution/plan/plan.py
@@ -316,6 +316,7 @@ def _build_from_sorted_nodes(
),
step_outputs=step_outputs,
tags=node.tags,
concurrency_key=node.definition.concurrency_key,
)
elif has_pending_input:
new_step = UnresolvedCollectExecutionStep(
@@ -326,6 +327,7 @@
),
step_outputs=step_outputs,
tags=node.tags,
concurrency_key=node.definition.concurrency_key,
)
else:
new_step = ExecutionStep(
@@ -334,6 +336,7 @@
step_inputs=cast(List[StepInput], step_inputs),
step_outputs=step_outputs,
tags=node.tags,
concurrency_key=node.definition.concurrency_key,
)

self.add_step(new_step)
16 changes: 16 additions & 0 deletions python_modules/dagster/dagster/_core/execution/plan/step.py
@@ -88,6 +88,11 @@ def kind(self) -> StepKind:
def tags(self) -> Optional[Mapping[str, str]]:
pass

@property
@abstractmethod
def concurrency_key(self) -> Optional[str]:
pass

@property
@abstractmethod
def step_inputs(
@@ -132,6 +137,7 @@ class ExecutionStep(
("tags", Mapping[str, str]),
("logging_tags", Mapping[str, str]),
("key", str),
("concurrency_key", Optional[str]),
],
),
IExecutionStep,
@@ -145,6 +151,7 @@ def __new__(
step_inputs: Sequence[StepInput],
step_outputs: Sequence[StepOutput],
tags: Optional[Mapping[str, str]],
concurrency_key: Optional[str],
logging_tags: Optional[Mapping[str, str]] = None,
key: Optional[str] = None,
):
@@ -161,6 +168,7 @@ def __new__(
for so in check.sequence_param(step_outputs, "step_outputs", of_type=StepOutput)
},
tags=tags or {},
concurrency_key=check.opt_str_param(concurrency_key, "concurrency_key"),
logging_tags=merge_dicts(
{
"step_key": handle.to_key(),
@@ -231,6 +239,7 @@ class UnresolvedMappedExecutionStep(
("step_input_dict", Mapping[str, Union[StepInput, UnresolvedMappedStepInput]]),
("step_output_dict", Mapping[str, StepOutput]),
("tags", Mapping[str, str]),
("concurrency_key", Optional[str]),
],
),
IExecutionStep,
@@ -244,6 +253,7 @@ def __new__(
step_inputs: Sequence[Union[StepInput, UnresolvedMappedStepInput]],
step_outputs: Sequence[StepOutput],
tags: Optional[Mapping[str, str]],
concurrency_key: Optional[str],
):
return super(UnresolvedMappedExecutionStep, cls).__new__(
cls,
@@ -260,6 +270,7 @@ def __new__(
for so in check.sequence_param(step_outputs, "step_outputs", of_type=StepOutput)
},
tags=check.opt_mapping_param(tags, "tags", key_type=str),
concurrency_key=check.opt_str_param(concurrency_key, "concurrency_key"),
)

@property
@@ -364,6 +375,7 @@ def resolve(
step_inputs=resolved_inputs,
step_outputs=self.step_outputs,
tags=self.tags,
concurrency_key=self.concurrency_key,
)
)

@@ -389,6 +401,7 @@ class UnresolvedCollectExecutionStep(
("step_input_dict", Mapping[str, Union[StepInput, UnresolvedCollectStepInput]]),
("step_output_dict", Mapping[str, StepOutput]),
("tags", Mapping[str, str]),
("concurrency_key", Optional[str]),
],
),
IExecutionStep,
@@ -402,6 +415,7 @@ def __new__(
step_inputs: Sequence[Union[StepInput, UnresolvedCollectStepInput]],
step_outputs: Sequence[StepOutput],
tags: Optional[Mapping[str, str]],
concurrency_key: Optional[str],
):
return super(UnresolvedCollectExecutionStep, cls).__new__(
cls,
@@ -418,6 +432,7 @@ def __new__(
for so in check.sequence_param(step_outputs, "step_outputs", of_type=StepOutput)
},
tags=check.opt_mapping_param(tags, "tags", key_type=str),
concurrency_key=check.opt_str_param(concurrency_key, "concurrency_key"),
)

@property
@@ -499,4 +514,5 @@ def resolve(
step_inputs=resolved_inputs,
step_outputs=self.step_outputs,
tags=self.tags,
concurrency_key=self.concurrency_key,
)
Changes in another file:
@@ -10,7 +10,6 @@
RunOpConcurrency,
RunRecord,
)
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
from dagster._time import get_current_timestamp


@@ -28,11 +27,10 @@ def compute_run_op_concurrency_info_for_snapshot(
for step in plan_snapshot.steps:
if step.key not in root_step_keys:
continue
concurrency_key = step.tags.get(GLOBAL_CONCURRENCY_TAG) if step.tags else None
if concurrency_key is None:
if step.concurrency_key is None:
has_unconstrained_root_nodes = True
else:
concurrency_key_counts[concurrency_key] += 1
concurrency_key_counts[step.concurrency_key] += 1

if len(concurrency_key_counts) == 0:
return None
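The helper above, which tallies concurrency demand across a run's root steps, now reads step.concurrency_key directly instead of digging through step tags. A simplified, self-contained version of the counting logic (plain namedtuple stand-ins, not the real snapshot classes):

from collections import defaultdict, namedtuple

StepSnapStandIn = namedtuple("StepSnapStandIn", ["key", "concurrency_key"])

def count_root_concurrency_keys(steps, root_step_keys):
    # mirrors compute_run_op_concurrency_info_for_snapshot: count keyed root steps,
    # and note whether any root step has no concurrency constraint at all
    concurrency_key_counts = defaultdict(int)
    has_unconstrained_root_nodes = False
    for step in steps:
        if step.key not in root_step_keys:
            continue
        if step.concurrency_key is None:
            has_unconstrained_root_nodes = True
        else:
            concurrency_key_counts[step.concurrency_key] += 1
    if not concurrency_key_counts:
        return None
    return dict(concurrency_key_counts), has_unconstrained_root_nodes

steps = [
    StepSnapStandIn("root_a", "foo"),
    StepSnapStandIn("root_b", None),
    StepSnapStandIn("downstream", "foo"),
]
print(count_root_concurrency_keys(steps, {"root_a", "root_b"}))  # ({'foo': 1}, True)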
Changes in another file:
@@ -20,6 +20,7 @@
UnresolvedCollectExecutionStep,
UnresolvedMappedExecutionStep,
)
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
from dagster._serdes import create_snapshot_id, whitelist_for_serdes
from dagster._utils.error import SerializableErrorInfo

@@ -148,6 +149,7 @@ class ExecutionStepSnap(
("metadata_items", Sequence["ExecutionPlanMetadataItemSnap"]),
("tags", Optional[Mapping[str, str]]),
("step_handle", Optional[StepHandleUnion]),
("step_concurrency_key", Optional[str]),
],
)
):
@@ -161,6 +163,7 @@ def __new__(
metadata_items: Sequence["ExecutionPlanMetadataItemSnap"],
tags: Optional[Mapping[str, str]] = None,
step_handle: Optional[StepHandleUnion] = None,
step_concurrency_key: Optional[str] = None,
):
return super(ExecutionStepSnap, cls).__new__(
cls,
@@ -174,8 +177,22 @@ def __new__(
),
tags=check.opt_nullable_mapping_param(tags, "tags", key_type=str, value_type=str),
step_handle=check.opt_inst_param(step_handle, "step_handle", StepHandleTypes),
# store the concurrency key arg separately from the concurrency_key property, since the
# snapshot may have been generated before concurrency_key was added as a separate
# argument
step_concurrency_key=check.opt_str_param(step_concurrency_key, "step_concurrency_key"),
)

@property
def concurrency_key(self):
# Fall back to the tags in case the snapshot was created before concurrency_key was added
# as a separate argument from tags
if self.step_concurrency_key:
return self.step_concurrency_key
if not self.tags:
return None
return self.tags.get(GLOBAL_CONCURRENCY_TAG)


@whitelist_for_serdes
class ExecutionStepInputSnap(
@@ -308,6 +325,7 @@ def _snapshot_from_execution_step(execution_step: IExecutionStep) -> ExecutionSt
),
tags=execution_step.tags,
step_handle=execution_step.handle,
step_concurrency_key=execution_step.concurrency_key,
)


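The snapshot changes keep older plan snapshots readable: the key is persisted in a dedicated step_concurrency_key field, and the concurrency_key property falls back to the GLOBAL_CONCURRENCY_TAG entry in tags for snapshots written before this commit. A small standalone illustration of that fallback (a plain dataclass stand-in rather than the real ExecutionStepSnap, with GLOBAL_CONCURRENCY_TAG assumed to be the "dagster/concurrency_key" tag):

from dataclasses import dataclass
from typing import Mapping, Optional

GLOBAL_CONCURRENCY_TAG = "dagster/concurrency_key"  # assumed value of the real constant

@dataclass
class SnapStandIn:
    # only the fields relevant to the fallback
    tags: Optional[Mapping[str, str]] = None
    step_concurrency_key: Optional[str] = None

    @property
    def concurrency_key(self) -> Optional[str]:
        # prefer the dedicated field; fall back to the tag for older snapshots
        if self.step_concurrency_key:
            return self.step_concurrency_key
        if not self.tags:
            return None
        return self.tags.get(GLOBAL_CONCURRENCY_TAG)

old_snap = SnapStandIn(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
new_snap = SnapStandIn(step_concurrency_key="foo")
assert old_snap.concurrency_key == new_snap.concurrency_key == "foo"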
Changes in another file:
@@ -161,12 +161,19 @@ def test_recover_in_between_steps():
)


def define_concurrency_job():
@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
def define_concurrency_job(use_tags):
if use_tags:
tags = {GLOBAL_CONCURRENCY_TAG: "foo"}
concurrency_kwarg = None
else:
tags = None
concurrency_kwarg = "foo"

@op(tags=tags, concurrency_key=concurrency_kwarg)
def foo_op():
pass

@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(tags=tags, concurrency_key=concurrency_kwarg)
def bar_op():
pass

@@ -178,8 +185,9 @@ def foo_job():
return foo_job


def test_active_concurrency():
foo_job = define_concurrency_job()
@pytest.mark.parametrize("use_tags", [True, False])
def test_active_concurrency(use_tags):
foo_job = define_concurrency_job(use_tags)
run_id = make_new_run_id()

with tempfile.TemporaryDirectory() as temp_dir:
@@ -258,8 +266,15 @@ def free_step(self, step_key) -> None:
pass


def define_concurrency_retry_job():
@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
def define_concurrency_retry_job(use_tags):
if use_tags:
tags = {GLOBAL_CONCURRENCY_TAG: "foo"}
concurrency_kwarg = None
else:
tags = None
concurrency_kwarg = "foo"

@op(tags=tags, concurrency_key=concurrency_kwarg)
def foo_op():
pass

@@ -275,9 +290,10 @@ def foo_job():
return foo_job


def test_active_concurrency_sleep():
@pytest.mark.parametrize("use_tags", [True, False])
def test_active_concurrency_sleep(use_tags):
instance_concurrency_context = MockInstanceConcurrencyContext(2.0)
foo_job = define_concurrency_retry_job()
foo_job = define_concurrency_retry_job(use_tags)
with pytest.raises(DagsterExecutionInterruptedError):
with create_execution_plan(foo_job).start(
RetryMode.ENABLED,
Changes in another file:
@@ -10,7 +10,6 @@
from dagster._core.execution.api import execute_job, execute_run_iterator
from dagster._core.instance import DagsterInstance
from dagster._core.storage.dagster_run import DagsterRunStatus
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
from dagster._core.test_utils import poll_for_finished_run
from dagster._core.workspace.context import WorkspaceRequestContext

@@ -19,12 +18,12 @@
)


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(concurrency_key="foo")
def should_never_execute(_x):
assert False # this should never execute


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(concurrency_key="foo")
def throw_error():
raise Exception("bad programmer")

@@ -34,14 +33,14 @@ def error_graph():
should_never_execute(throw_error())


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(concurrency_key="foo")
def simple_op(context):
time.sleep(0.1)
foo_info = context.instance.event_log_storage.get_concurrency_info("foo")
return {"active": foo_info.active_slot_count, "pending": foo_info.pending_step_count}


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(concurrency_key="foo")
def second_op(context, _):
time.sleep(0.1)
foo_info = context.instance.event_log_storage.get_concurrency_info("foo")
@@ -67,7 +66,7 @@ def two_tier_graph():
second_op(simple_op())


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"}, retry_policy=RetryPolicy(max_retries=1))
@op(concurrency_key="foo", retry_policy=RetryPolicy(max_retries=1))
def retry_op():
raise Failure("I fail")

Changes in another file:
@@ -32,7 +32,8 @@
StepHandler,
)
from dagster._core.instance import DagsterInstance
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG

# from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
from dagster._core.test_utils import environ, instance_for_test
from dagster._utils.merger import merge_dicts
from dagster._utils.test.definitions import lazy_definitions, scoped_definitions_load_context
@@ -496,7 +497,7 @@ def test_dynamic_failure_retry(job_fn, config_fn):
assert_expected_failure_behavior(job_fn, config_fn)


@op(tags={GLOBAL_CONCURRENCY_TAG: "foo"})
@op(concurrency_key="foo")
def simple_op(context):
time.sleep(0.1)
foo_info = context.instance.event_log_storage.get_concurrency_info("foo")
