diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py index b78e69dbdbac2..296a5442dbe6c 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py @@ -25,6 +25,7 @@ from dagster import PipesClient from dagster._annotations import experimental from dagster._core.definitions.resource_annotation import TreatAsResourceParam +from dagster._core.errors import DagsterExecutionInterruptedError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.pipes.client import ( PipesClientCompletedInvocation, @@ -344,6 +345,7 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam): message_reader (Optional[PipesMessageReader]): A message reader to use to read messages from the glue job run. Defaults to :py:class:`PipesCloudWatchsMessageReader`. client (Optional[boto3.client]): The boto Glue client used to launch the Glue job + forward_termination (bool): Whether to cancel the Glue job run when the Dagster process receives a termination signal. """ def __init__( @@ -351,10 +353,12 @@ def __init__( context_injector: PipesContextInjector, message_reader: Optional[PipesMessageReader] = None, client: Optional[boto3.client] = None, + forward_termination: bool = False, ): self._client = client or boto3.client("glue") self._context_injector = context_injector self._message_reader = message_reader or PipesCloudWatchMessageReader() + self.forward_termination = check.bool_param(forward_termination, "forward_termination") @classmethod def _is_dagster_maintained(cls) -> bool: @@ -448,7 +452,12 @@ def run( log_group = response["JobRun"]["LogGroupName"] context.log.info(f"Started AWS Glue job {job_name} run: {run_id}") - response = self._wait_for_job_run_completion(job_name, run_id) + try: + response = self._wait_for_job_run_completion(job_name, run_id) + except DagsterExecutionInterruptedError: + if self.forward_termination: + self._terminate_job_run(context=context, job_name=job_name, run_id=run_id) + raise if response["JobRun"]["JobRunState"] == "FAILED": raise RuntimeError( @@ -473,3 +482,17 @@ def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]: return response time.sleep(5) + + def _terminate_job_run(self, context: OpExecutionContext, job_name: str, run_id: str): + """Creates a handler which will gracefully stop the Run in case of external termination. + It will stop the Glue job before doing so. + """ + context.log.warning(f"[pipes] execution interrupted, stopping Glue job run {run_id}...") + response = self._client.batch_stop_job_run(JobName=job_name, JobRunIds=[run_id]) + runs = response["SuccessfulSubmissions"] + if len(runs) > 0: + context.log.warning(f"Successfully stopped Glue job run {run_id}.") + else: + context.log.warning( + f"Something went wrong during Glue job run termination: {response['errors']}" + )