diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py index 6bd7b35bcf6f3..8085df62e6da6 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py @@ -25,6 +25,7 @@ from dagster import PipesClient from dagster._annotations import experimental from dagster._core.definitions.resource_annotation import TreatAsResourceParam +from dagster._core.errors import DagsterExecutionInterruptedError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.pipes.client import ( PipesClientCompletedInvocation, @@ -350,9 +351,10 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam): Args: context_injector (Optional[PipesContextInjector]): A context injector to use to inject context into the Glue job, for example, :py:class:`PipesGlueContextInjector`. - message_reader (Optional[PipesMessageReader]): A message reader to use to read messages + message_reader (Optional[PipesGlueMessageReader]): A message reader to use to read messages from the glue job run. Defaults to :py:class:`PipesGlueLogsMessageReader`. client (Optional[boto3.client]): The boto Glue client used to launch the Glue job + forward_termination (bool): Whether to cancel the Glue job run when the Dagster process receives a termination signal. """ def __init__( @@ -360,10 +362,12 @@ def __init__( context_injector: PipesContextInjector, message_reader: Optional[PipesMessageReader] = None, client: Optional[boto3.client] = None, + forward_termination: bool = False, ): self._client = client or boto3.client("glue") self._context_injector = context_injector - self._message_reader = message_reader or PipesGlueLogsMessageReader() + self._message_reader = message_reader or PipesCloudWatchMessageReader() + self.forward_termination = check.bool_param(forward_termination, "forward_termination") @classmethod def _is_dagster_maintained(cls) -> bool: @@ -457,7 +461,12 @@ def run( log_group = response["JobRun"]["LogGroupName"] context.log.info(f"Started AWS Glue job {job_name} run: {run_id}") - response = self._wait_for_job_run_completion(job_name, run_id) + try: + response = self._wait_for_job_run_completion(job_name, run_id) + except DagsterExecutionInterruptedError: + if self.forward_termination: + self._terminate_job_run(context=context, job_name=job_name, run_id=run_id) + raise if response["JobRun"]["JobRunState"] == "FAILED": raise RuntimeError( @@ -467,9 +476,7 @@ def run( context.log.info(f"Glue job {job_name} run {run_id} completed successfully") if isinstance(self._message_reader, PipesCloudWatchMessageReader): - # TODO: consume messages in real-time via a background thread - # so we don't have to wait for the job run to complete - # before receiving any logs + # TODO: receive messages from a background thread in real-time self._message_reader.consume_cloudwatch_logs( f"{log_group}/output", run_id, start_time=int(start_timestamp) ) @@ -482,3 +489,17 @@ def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]: return response time.sleep(5) + + def _terminate_job_run(self, context: OpExecutionContext, job_name: str, run_id: str): + """Creates a handler which will gracefully stop the Run in case of external termination. + It will stop the Glue job before doing so. + """ + context.log.warning(f"[pipes] execution interrupted, stopping Glue job run {run_id}...") + response = self._client.batch_stop_job_run(JobName=job_name, JobRunIds=[run_id]) + runs = response["SuccessfulSubmissions"] + if len(runs) > 0: + context.log.warning(f"Successfully stopped Glue job run {run_id}.") + else: + context.log.warning( + f"Something went wrong during Glue job run termination: {response['errors']}" + )