Skip to content

Commit

Permalink
add concurrency key to the op definition
Browse files Browse the repository at this point in the history
  • Loading branch information
prha committed Dec 3, 2024
1 parent 7941570 commit 0fc225d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def asset(
check_specs: Optional[Sequence[AssetCheckSpec]] = ...,
owners: Optional[Sequence[str]] = ...,
kinds: Optional[AbstractSet[str]] = ...,
concurrency_key: Optional[str] = ...,
**kwargs,
) -> Callable[[Callable[..., Any]], AssetsDefinition]: ...

Expand Down Expand Up @@ -183,6 +184,7 @@ def asset(
check_specs: Optional[Sequence[AssetCheckSpec]] = None,
owners: Optional[Sequence[str]] = None,
kinds: Optional[AbstractSet[str]] = None,
concurrency_key: Optional[str] = ...,
**kwargs,
) -> Union[AssetsDefinition, Callable[[Callable[..., Any]], AssetsDefinition]]:
"""Create a definition for how to compute an asset.
Expand Down Expand Up @@ -261,6 +263,8 @@ def asset(
non_argument_deps (Optional[Union[Set[AssetKey], Set[str]]]): Deprecated, use deps instead.
Set of asset keys that are upstream dependencies, but do not pass an input to the asset.
Hidden parameter not exposed in the decorator signature, but passed in kwargs.
concurrency_key (Optional[str]): A string that identifies the concurrency limit group that governs
this asset's execution.
Examples:
.. code-block:: python
Expand Down Expand Up @@ -318,6 +322,7 @@ def my_asset(my_upstream_asset: int) -> int:
check_specs=check_specs,
key=key,
owners=owners,
concurrency_key=concurrency_key,
)

if compute_fn is not None:
Expand Down Expand Up @@ -390,6 +395,7 @@ class AssetDecoratorArgs(NamedTuple):
key: Optional[CoercibleToAssetKey]
check_specs: Optional[Sequence[AssetCheckSpec]]
owners: Optional[Sequence[str]]
concurrency_key: Optional[str]


class ResourceRelatedState(NamedTuple):
Expand Down Expand Up @@ -514,6 +520,7 @@ def create_assets_def_from_fn_and_decorator_args(
can_subset=False,
decorator_name="@asset",
execution_type=AssetExecutionType.MATERIALIZATION,
concurrency_key=args.concurrency_key,
)

builder = DecoratorAssetsDefinitionBuilder.from_asset_outs_in_asset_centric_decorator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class DecoratorAssetsDefinitionBuilderArgs(NamedTuple):
specs: Sequence[AssetSpec]
upstream_asset_deps: Optional[Iterable[AssetDep]]
execution_type: Optional[AssetExecutionType]
concurrency_key: Optional[str]

@property
def check_specs(self) -> Sequence[AssetCheckSpec]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
retry_policy: Optional[RetryPolicy] = None,
ins: Optional[Mapping[str, In]] = None,
out: Optional[Union[Out, Mapping[str, Out]]] = None,
concurrency_key: Optional[str] = None,
):
self.name = check.opt_str_param(name, "name")
self.decorator_takes_context = check.bool_param(
Expand All @@ -65,6 +66,7 @@ def __init__(
self.tags = tags
self.code_version = code_version
self.retry_policy = retry_policy
self.concurrency_key = concurrency_key

# config will be checked within OpDefinition
self.config_schema = config_schema
Expand Down Expand Up @@ -132,6 +134,7 @@ def __call__(self, fn: Callable[..., Any]) -> "OpDefinition":
code_version=self.code_version,
retry_policy=self.retry_policy,
version=None, # code_version has replaced version
concurrency_key=self.concurrency_key,
)
update_wrapper(op_def, compute_fn.decorated_fn)
return op_def
Expand All @@ -154,6 +157,7 @@ def op(
version: Optional[str] = ...,
retry_policy: Optional[RetryPolicy] = ...,
code_version: Optional[str] = ...,
concurrency_key: Optional[str] = None,
) -> _Op: ...


Expand All @@ -173,6 +177,7 @@ def op(
version: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None,
code_version: Optional[str] = None,
concurrency_key: Optional[str] = None,
) -> Union["OpDefinition", _Op]:
"""Create an op with the specified parameters from the decorated function.
Expand Down Expand Up @@ -266,6 +271,7 @@ def multi_out() -> Tuple[str, int]:
retry_policy=retry_policy,
ins=ins,
out=out,
concurrency_key=concurrency_key,
)


Expand Down
33 changes: 33 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/op_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class OpDefinition(NodeDefinition, IHasInternalInit):
code_version (Optional[str]): (Experimental) Version of the code encapsulated by the op. If set,
this is used as a default code version for all outputs.
retry_policy (Optional[RetryPolicy]): The retry policy for this op.
concurrency_key (Optional[str]): A string that identifies the concurrency limit group that governs
this op's execution.
Examples:
Expand All @@ -112,6 +114,7 @@ def _add_one(_context, inputs):
_required_resource_keys: AbstractSet[str]
_version: Optional[str]
_retry_policy: Optional[RetryPolicy]
_concurrency_key: Optional[str]

def __init__(
self,
Expand All @@ -126,6 +129,7 @@ def __init__(
version: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None,
code_version: Optional[str] = None,
concurrency_key: Optional[str] = None,
):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
Expand Down Expand Up @@ -170,6 +174,7 @@ def __init__(
check.opt_set_param(required_resource_keys, "required_resource_keys", of_type=str)
)
self._retry_policy = check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy)
self._concurrency_key = _validate_concurrency_key(concurrency_key, tags)

positional_inputs = (
self._compute_fn.positional_inputs()
Expand Down Expand Up @@ -199,6 +204,7 @@ def dagster_internal_init(
version: Optional[str],
retry_policy: Optional[RetryPolicy],
code_version: Optional[str],
concurrency_key: Optional[str],
) -> "OpDefinition":
return OpDefinition(
compute_fn=compute_fn,
Expand All @@ -212,6 +218,7 @@ def dagster_internal_init(
version=version,
retry_policy=retry_policy,
code_version=code_version,
concurrency_key=concurrency_key,
)

@property
Expand Down Expand Up @@ -297,6 +304,11 @@ def with_retry_policy(self, retry_policy: RetryPolicy) -> "PendingNodeInvocation
"""Creates a copy of this op with the given retry policy."""
return super(OpDefinition, self).with_retry_policy(retry_policy)

@property
def concurrency_key(self) -> Optional[str]:
"""Optional[str]: The concurrency key for this op."""
return self._concurrency_key

def is_from_decorator(self) -> bool:
from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction

Expand Down Expand Up @@ -381,6 +393,7 @@ def with_replaced_properties(
code_version=self._version,
retry_policy=self.retry_policy,
version=None, # code_version replaces version
concurrency_key=self.concurrency_key,
)

def copy_for_configured(
Expand Down Expand Up @@ -587,3 +600,23 @@ def _validate_context_type_hint(fn):
def _is_result_object_type(ttype):
# Is this type special result object type
return ttype in (MaterializeResult, ObserveResult, AssetCheckResult)


def _validate_concurrency_key(concurrency_key, tags):
from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG

check.opt_str_param(concurrency_key, "concurrency_key")
tags = check.opt_mapping_param(tags, "tags")
tag_concurrency_key = tags.get(GLOBAL_CONCURRENCY_TAG)
if concurrency_key and tag_concurrency_key and concurrency_key != tag_concurrency_key:
raise DagsterInvalidDefinitionError(
f"Op '{concurrency_key}' has a concurrency key '{concurrency_key}' that conflicts with the concurrency key tag '{tag_concurrency_key}'."
)

if concurrency_key:
return concurrency_key

if tag_concurrency_key:
return tag_concurrency_key

return None

0 comments on commit 0fc225d

Please sign in to comment.