From 56bd783d4317199b5f97e806638a94888c02c720 Mon Sep 17 00:00:00 2001 From: Daniel Gibson Date: Sat, 28 Sep 2024 22:35:54 -0500 Subject: [PATCH] ECS Executor WIP Summary: Proving out the possibility of building out an ECS executor. It assumes that it is launched via the EcsRunLauncher and can use the same task definition of the task in which it is launched, but still allows customizing of memory,cpu, ephemeral storage, and whatever else you can override via run_task arguments. --- .../step_delegating_executor.py | 53 ++- .../step_delegating/step_handler/base.py | 13 +- .../dagster-aws/dagster_aws/ecs/executor.py | 304 ++++++++++++++++++ .../dagster-aws/dagster_aws/ecs/launcher.py | 54 +--- .../dagster-aws/dagster_aws/ecs/utils.py | 34 ++ .../dagster_docker/docker_executor.py | 10 +- .../dagster-k8s/dagster_k8s/executor.py | 18 +- 7 files changed, 404 insertions(+), 82 deletions(-) create mode 100644 python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py index 263f20eb1cc72..9fad7f5a620f6 100644 --- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py +++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py @@ -178,6 +178,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut instance_concurrency_context=instance_concurrency_context, ) as active_execution: running_steps: Dict[str, ExecutionStep] = {} + step_worker_handles: Dict[str, Optional[str]] = {} if plan_context.resume_from_failure: DagsterEvent.engine_event( @@ -211,7 +212,8 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut try: health_check = self._step_handler.check_step_health( - step_handler_context + step_handler_context, + step_worker_handle=None, ) except 
Exception: # For now we assume that an exception indicates that the step should be resumed. @@ -237,15 +239,14 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut if should_retry_step: # health check failed, launch the step - list( - self._step_handler.launch_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ) + step_worker_handles[step.key] = self._step_handler.launch_step( + self._get_step_handler_context( + plan_context, [step], active_execution ) ) running_steps[step.key] = step + step_worker_handles.setdefault(step.key, None) last_check_step_health_time = get_current_datetime() @@ -262,13 +263,12 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut "Executor received termination signal, forwarding to steps", EngineEventData.interrupted(list(running_steps.keys())), ) - for step in running_steps.values(): - list( - self._step_handler.terminate_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ) - ) + for step_key, step in running_steps.items(): + self._step_handler.terminate_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ), + step_worker_handle=step_worker_handles[step_key], ) else: DagsterEvent.engine_event( @@ -311,6 +311,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut ): assert isinstance(dagster_event.step_key, str) del running_steps[dagster_event.step_key] + del step_worker_handles[dagster_event.step_key] if not dagster_event.is_step_up_for_retry: active_execution.verify_complete( @@ -325,14 +326,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut curr_time - last_check_step_health_time ).total_seconds() >= self._check_step_health_interval_seconds: last_check_step_health_time = curr_time - for step in running_steps.values(): + for step_key, step in running_steps.items(): step_context = plan_context.for_step(step) try: health_check_result = 
self._step_handler.check_step_health( self._get_step_handler_context( plan_context, [step], active_execution - ) + ), + step_worker_handle=step_worker_handles[step_key], ) if not health_check_result.is_healthy: health_check_error = SerializableErrorInfo( @@ -374,11 +376,9 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut for step in active_execution.get_steps_to_execute(max_steps_to_run): running_steps[step.key] = step - list( - self._step_handler.launch_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ) + step_worker_handles[step.key] = self._step_handler.launch_step( + self._get_step_handler_context( + plan_context, [step], active_execution ) ) @@ -398,12 +398,11 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut error=serializable_error, ), ) - for step in running_steps.values(): - list( - self._step_handler.terminate_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ) - ) + for step_key, step in running_steps.items(): + self._step_handler.terminate_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ), + step_worker_handle=step_worker_handles[step_key], ) raise diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py index a40fc76408e25..ed12a386869ea 100644 --- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py +++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod -from typing import Iterator, Mapping, NamedTuple, Optional, Sequence +from typing import Mapping, NamedTuple, Optional, Sequence from dagster import ( DagsterInstance, _check as check, ) -from dagster._core.events import DagsterEvent from dagster._core.execution.context.system import 
IStepContext, PlanOrchestrationContext from dagster._core.execution.plan.step import ExecutionStep from dagster._core.storage.dagster_run import DagsterRun @@ -83,13 +82,17 @@ def name(self) -> str: pass @abstractmethod - def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]: pass @abstractmethod - def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult: + def check_step_health( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> CheckStepHealthResult: pass @abstractmethod - def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def terminate_step( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> None: pass diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py new file mode 100644 index 0000000000000..43e8e9a192b0d --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py @@ -0,0 +1,304 @@ +import json +import os +from typing import Any, List, Mapping, Optional, cast + +import boto3 +from dagster import ( + Field, + IntSource, + Permissive, + _check as check, + executor, +) +from dagster._core.definitions.executor_definition import multiple_process_executor_requirements +from dagster._core.definitions.metadata import MetadataValue +from dagster._core.events import DagsterEvent, EngineEventData +from dagster._core.execution.retries import RetryMode, get_retries_config +from dagster._core.execution.tags import get_tag_concurrency_limits_config +from dagster._core.executor.base import Executor +from dagster._core.executor.init import InitExecutorContext +from dagster._core.executor.step_delegating import ( + CheckStepHealthResult, + StepDelegatingExecutor, + StepHandler, + 
StepHandlerContext, +) +from dagster._utils.backoff import backoff + +from dagster_aws.ecs.container_context import EcsContainerContext +from dagster_aws.ecs.launcher import STOPPED_STATUSES, EcsRunLauncher +from dagster_aws.ecs.tasks import get_current_ecs_task, get_current_ecs_task_metadata +from dagster_aws.ecs.utils import RetryableEcsException, run_ecs_task + +DEFAULT_STEP_TASK_RETRIES = "5" + + +@executor( + name="ecs", + config_schema={ + "run_task_kwargs": Field( + Permissive({}), + is_required=False, + description=( + "Additional arguments to include while running the task. See" + " https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task" + " for the available parameters. The overrides and taskDefinition arguments will" + " always be set by the executor." + ), + ), + "cpu": Field(IntSource, is_required=False), + "memory": Field(IntSource, is_required=False), + "ephemeral_storage": Field(IntSource, is_required=False), + "task_overrides": Field( + Permissive({}), + is_required=False, + ), + "retries": get_retries_config(), + "max_concurrent": Field( + IntSource, + is_required=False, + description=( + "Limit on the number of ECS tasks that will run concurrently within the scope " + "of a Dagster run. Note that this limit is per run, not global." 
+ ), + ), + "tag_concurrency_limits": get_tag_concurrency_limits_config(), + }, + requirements=multiple_process_executor_requirements(), +) +def ecs_executor(init_context: InitExecutorContext) -> Executor: + """Executor which launches steps as ECS tasks.""" + run_launcher = init_context.instance.run_launcher + + check.invariant( + isinstance(run_launcher, EcsRunLauncher), + "Using the ecs_executor currently requires that the run be launched in an ECS task via the EcsRunLauncher.", + ) + + exc_cfg = init_context.executor_config + + return StepDelegatingExecutor( + EcsStepHandler( + run_task_kwargs=exc_cfg.get("run_task_kwargs"), # type: ignore + cpu=exc_cfg.get("cpu"), # type: ignore + memory=exc_cfg.get("memory"), # type: ignore + ephemeral_storage=exc_cfg.get("ephemeral_storage"), # type: ignore + task_overrides=exc_cfg.get("task_overrides"), # type:ignore + ), + retries=RetryMode.from_config(exc_cfg["retries"]), # type: ignore + max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"), + tag_concurrency_limits=check.opt_list_elem(exc_cfg, "tag_concurrency_limits"), + should_verify_step=True, + ) + + +class EcsStepHandler(StepHandler): + @property + def name(self): + return "EcsStepHandler" + + def __init__( + self, + run_task_kwargs: Mapping[str, Any], + cpu: Optional[int], + memory: Optional[int], + ephemeral_storage: Optional[int], + task_overrides: Optional[Mapping[str, Any]], + ): + super().__init__() + + self.ecs = boto3.client("ecs") + + # confusingly, run_task expects cpu and memory value as strings + self._cpu = str(cpu) if cpu else None + self._memory = str(memory) if memory else None + + self._ephemeral_storage = ephemeral_storage + self._task_overrides = check.opt_mapping_param(task_overrides, "task_overrides") + + current_task_metadata = get_current_ecs_task_metadata() + current_task = get_current_ecs_task( + self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster + ) + self._cluster_arn = current_task["clusterArn"] + 
self._task_definition_arn = current_task["taskDefinitionArn"] + self._run_task_kwargs = { + **(run_task_kwargs or {}), + "taskDefinition": current_task["taskDefinitionArn"], + } + + def _get_run_task_kwargs( + self, + run, + args, + step_key: str, + step_tags: Mapping[str, str], + run_launcher: EcsRunLauncher, + container_context: EcsContainerContext, + ): + run_task_kwargs = {**self._run_task_kwargs} # copy so per-step tags don't accumulate + + kwargs_from_tags = step_tags.get("ecs/run_task_kwargs") + if kwargs_from_tags: + run_task_kwargs = {**run_task_kwargs, **json.loads(kwargs_from_tags)} + + run_task_kwargs["tags"] = [ + *run_task_kwargs.get("tags", []), + {"key": "dagster/run-id", "value": run.run_id}, + {"key": "dagster/job-name", "value": run.job_name}, + {"key": "dagster/step-key", "value": step_key}, + ] + + if run.remote_job_origin: + run_task_kwargs["tags"] = [ + *run_task_kwargs["tags"], + { + "key": "dagster/code-location", + "value": run.remote_job_origin.repository_origin.code_location_origin.location_name, + }, + ] + + overrides = { + # container name has to match since we are assuming we are using the same task + # definition as the run + "containerOverrides": [ + {"name": run_launcher.get_container_name(container_context), "command": args} + ], + **self._get_task_overrides(step_tags), + } + + run_task_kwargs["overrides"] = overrides + + return run_task_kwargs + + def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Mapping[str, Any]: + overrides = {**self._task_overrides} # executor-level defaults; step tags win + + cpu = step_tags.get("ecs/cpu", self._cpu) + memory = step_tags.get("ecs/memory", self._memory) + + if cpu: + overrides["cpu"] = cpu + if memory: + overrides["memory"] = memory + + ephemeral_storage = step_tags.get("ecs/ephemeral_storage", self._ephemeral_storage) + + if ephemeral_storage: + overrides["ephemeralStorage"] = {"sizeInGiB": int(ephemeral_storage)} + + tag_overrides = step_tags.get("ecs/task_overrides") + if tag_overrides: + overrides = {**overrides, **json.loads(tag_overrides)} + + return overrides 
+ + def _get_step_key(self, step_handler_context: StepHandlerContext) -> str: + step_keys_to_execute = cast( + List[str], step_handler_context.execute_step_args.step_keys_to_execute + ) + assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported" + return step_keys_to_execute[0] + + def _get_container_context( + self, step_handler_context: StepHandlerContext + ) -> EcsContainerContext: + return EcsContainerContext.create_for_run( + step_handler_context.dagster_run, + cast(EcsRunLauncher, step_handler_context.instance.run_launcher), + ) + + def _run_task(self, **run_task_kwargs): + return run_ecs_task(self.ecs, run_task_kwargs) + + def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]: + step_key = self._get_step_key(step_handler_context) + + step_tags = step_handler_context.step_tags[step_key] + + container_context = self._get_container_context(step_handler_context) + + run = step_handler_context.dagster_run + + args = step_handler_context.execute_step_args.get_command_args( + skip_serialized_namedtuple=True + ) + + run_task_kwargs = self._get_run_task_kwargs( + run, + args, + step_key, + step_tags, + cast(EcsRunLauncher, step_handler_context.instance.run_launcher), + container_context, + ) + + task = backoff( + self._run_task, + retry_on=(RetryableEcsException,), + kwargs=run_task_kwargs, + max_retries=int( + os.getenv("STEP_TASK_RETRIES", DEFAULT_STEP_TASK_RETRIES), + ), + ) + + DagsterEvent.step_worker_starting( + step_handler_context.get_step_context(step_key), + message=f'Executing step "{step_key}" in ECS task.', + metadata={ + "Task ARN": MetadataValue.text(task["taskArn"]), + }, + ) + + return task["taskArn"] + + def check_step_health( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> CheckStepHealthResult: + step_key = self._get_step_key(step_handler_context) + + task_arn = step_worker_handle + cluster_arn = self._cluster_arn + + tasks = 
self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks") + + if not tasks: + return CheckStepHealthResult.unhealthy( + reason=f"Task {task_arn} for step {step_key} could not be found." + ) + + t = tasks[0] + if t.get("lastStatus") in STOPPED_STATUSES: + failed_containers = [] + for c in t.get("containers"): + if c.get("exitCode") != 0: + failed_containers.append(c) + if len(failed_containers) > 0: + cluster_failure_info = ( + f"Task {t.get('taskArn')} failed.\n" + f"Stop code: {t.get('stopCode')}.\n" + f"Stop reason: {t.get('stoppedReason')}.\n" + ) + for c in failed_containers: + exit_code = c.get("exitCode") + exit_code_msg = f" - exit code {exit_code}" if exit_code is not None else "" + cluster_failure_info += f"Container '{c.get('name')}' failed{exit_code_msg}.\n" + + return CheckStepHealthResult.unhealthy(reason=cluster_failure_info) + + return CheckStepHealthResult.healthy() + + def terminate_step( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> None: + task_arn = step_worker_handle + cluster_arn = self._cluster_arn + step_key = self._get_step_key(step_handler_context) + + DagsterEvent.engine_event( + step_handler_context.get_step_context(step_key), + message=f"Deleting task {task_arn} for step", + event_specific_data=EngineEventData(), + ) + + self.ecs.stop_task(task=task_arn, cluster=cluster_arn) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py index d5990e82c8bef..5a78925fa26f4 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py @@ -47,7 +47,13 @@ get_task_definition_dict_from_current_task, get_task_kwargs_from_current_task, ) -from dagster_aws.ecs.utils import get_task_definition_family, get_task_logs, task_definitions_match +from dagster_aws.ecs.utils import ( + RetryableEcsException, + 
get_task_definition_family, + get_task_logs, + run_ecs_task, + task_definitions_match, +) from dagster_aws.secretsmanager import get_secrets_from_arns Tags = namedtuple("Tags", ["arn", "cluster", "cpu", "memory"]) @@ -73,9 +79,6 @@ DEFAULT_RUN_TASK_RETRIES = 5 -class RetryableEcsException(Exception): ... - - class EcsRunLauncher(RunLauncher[T_DagsterInstance], ConfigurableClass): """RunLauncher that starts a task in ECS for each Dagster job run.""" @@ -433,34 +436,7 @@ def _get_image_for_run(self, context: LaunchRunContext) -> Optional[str]: return job_origin.repository_origin.container_image def _run_task(self, **run_task_kwargs): - response = self.ecs.run_task(**run_task_kwargs) - - tasks = response["tasks"] - - if not tasks: - failures = response["failures"] - failure_messages = [] - for failure in failures: - arn = failure.get("arn") - reason = failure.get("reason") - detail = failure.get("detail") - - failure_message = ( - "Task" - + (f" {arn}" if arn else "") - + " failed." - + (f" Failure reason: {reason}" if reason else "") - + (f" Failure details: {detail}" if detail else "") - ) - failure_messages.append(failure_message) - - failure_message = "\n".join(failure_messages) if failure_messages else "Task failed." 
- - if "Capacity is unavailable at this time" in failure_message: - raise RetryableEcsException(failure_message) - - raise Exception(failure_message) - return tasks[0] + return run_ecs_task(self.ecs, run_task_kwargs) def launch_run(self, context: LaunchRunContext) -> None: """Launch a run in an ECS task.""" @@ -500,7 +476,7 @@ def launch_run(self, context: LaunchRunContext) -> None: container_overrides: List[Dict[str, Any]] = [ { - "name": self._get_container_name(container_context), + "name": self.get_container_name(container_context), "command": command, # containerOverrides expects cpu/memory as integers **{k: int(v) for k, v in cpu_and_memory_overrides.items()}, @@ -645,7 +621,7 @@ def _get_current_task(self): def _get_run_task_definition_family(self, run: DagsterRun) -> str: return get_task_definition_family("run", check.not_none(run.remote_job_origin)) - def _get_container_name(self, container_context: EcsContainerContext) -> str: + def get_container_name(self, container_context: EcsContainerContext) -> str: return container_context.container_name or self.container_name def _run_task_kwargs( @@ -676,7 +652,7 @@ def _run_task_kwargs( task_definition_config = DagsterEcsTaskDefinitionConfig( family, image, - self._get_container_name(container_context), + self.get_container_name(container_context), command=None, log_configuration=( { @@ -716,7 +692,7 @@ def _run_task_kwargs( family, self._get_current_task(), image, - self._get_container_name(container_context), + self.get_container_name(container_context), environment=environment, secrets=secrets if secrets else {}, include_sidecars=self.include_sidecars, @@ -734,10 +710,10 @@ def _run_task_kwargs( task_definition_config = DagsterEcsTaskDefinitionConfig.from_task_definition_dict( task_definition_dict, - self._get_container_name(container_context), + self.get_container_name(container_context), ) - container_name = self._get_container_name(container_context) + container_name = 
self.get_container_name(container_context) backoff( self._reuse_or_register_task_definition, @@ -893,7 +869,7 @@ def check_run_worker_health(self, run: DagsterRun): logs_client=self.logs, cluster=tags.cluster, task_arn=tags.arn, - container_name=self._get_container_name(container_context), + container_name=self.get_container_name(container_context), ) except: logging.exception(f"Error trying to get logs for failed task {tags.arn}") diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py index ec3d9edade381..63627be44a450 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py @@ -19,6 +19,40 @@ def _get_family_hash(name): return f"{name[:55]}_{name_hash}" +class RetryableEcsException(Exception): ... + + +def run_ecs_task(ecs, run_task_kwargs) -> Mapping[str, Any]: + response = ecs.run_task(**run_task_kwargs) + + tasks = response["tasks"] + + if not tasks: + failures = response["failures"] + failure_messages = [] + for failure in failures: + arn = failure.get("arn") + reason = failure.get("reason") + detail = failure.get("detail") + + failure_message = ( + "Task" + + (f" {arn}" if arn else "") + + " failed." + + (f" Failure reason: {reason}" if reason else "") + + (f" Failure details: {detail}" if detail else "") + ) + failure_messages.append(failure_message) + + failure_message = "\n".join(failure_messages) if failure_messages else "Task failed." 
+ + if "Capacity is unavailable at this time" in failure_message: + raise RetryableEcsException(failure_message) + + raise Exception(failure_message) + return tasks[0] + + def get_task_definition_family( prefix: str, job_origin: RemoteJobOrigin, diff --git a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py index ab60303b0f8a7..28bcd061b9363 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Optional, cast +from typing import List, Optional, cast import dagster._check as check import docker @@ -222,7 +222,7 @@ def _create_step_container( **container_kwargs, ) - def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def launch_step(self, step_handler_context: StepHandlerContext) -> None: container_context = self._get_docker_container_context(step_handler_context) client = self._get_client(container_context) @@ -251,7 +251,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported" step_key = step_keys_to_execute[0] - yield DagsterEvent.step_worker_starting( + DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message="Launching step in Docker container.", metadata={ @@ -294,7 +294,7 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt reason=f"Container status is {container.status}. Return code is {ret_code}." 
) - def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def terminate_step(self, step_handler_context: StepHandlerContext) -> None: container_context = self._get_docker_container_context(step_handler_context) step_keys_to_execute = check.not_none( @@ -307,7 +307,7 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[D container_name = self._get_container_name(step_handler_context) - yield DagsterEvent.engine_event( + DagsterEvent.engine_event( step_handler_context.get_step_context(step_key), message=f"Stopping Docker container {container_name} for step.", event_specific_data=EngineEventData(), diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index c8b0d7289d4dc..64bc8d44f8d2c 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Optional, cast +from typing import List, Optional, cast import kubernetes.config from dagster import ( @@ -261,7 +261,7 @@ def _get_k8s_step_job_name(self, step_handler_context: StepHandlerContext): return "dagster-step-%s" % (name_key) - def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def launch_step(self, step_handler_context: StepHandlerContext) -> None: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) @@ -313,7 +313,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags ], ) - yield DagsterEvent.step_worker_starting( + DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message=f'Executing step "{step_key}" in Kubernetes job {job_name}.', metadata={ @@ -324,7 +324,11 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags namespace = 
check.not_none(container_context.namespace) self._api_client.create_namespaced_job_with_retries(body=job, namespace=namespace) - def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult: + return None + + def check_step_health( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> CheckStepHealthResult: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) @@ -346,13 +350,15 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt return CheckStepHealthResult.healthy() - def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def terminate_step( + self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] + ) -> None: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) container_context = self._get_container_context(step_handler_context) - yield DagsterEvent.engine_event( + DagsterEvent.engine_event( step_handler_context.get_step_context(step_key), message=f"Deleting Kubernetes job {job_name} for step", event_specific_data=EngineEventData(),