diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py index d830df6b96e23..98775dd18c4ca 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py @@ -3,13 +3,22 @@ import os import random import string -import sys import time from contextlib import contextmanager -from threading import Thread -from typing import TYPE_CHECKING, Any, Dict, Iterator, Literal, Mapping, Optional, Sequence, List, Generator -from typing import TypedDict -import signal +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + TypedDict, +) + import boto3 import dagster._check as check from botocore.exceptions import ClientError @@ -171,9 +180,8 @@ class PipesCloudWatchMessageReader(PipesMessageReader): """Message reader that consumes AWS CloudWatch logs to read pipes messages.""" def __init__(self, client: Optional[boto3.client] = None): - """ - Args: - client (boto3.client): boto3 CloudWatch client. + """Args: + client (boto3.client): boto3 CloudWatch client. """ self.client = client or boto3.client("logs") @@ -190,17 +198,18 @@ def read_messages( self._handler = None def consume_cloudwatch_logs( - self, log_group: str, log_stream: str, start_time: Optional[int] = None, end_time: Optional[int] = None, + self, + log_group: str, + log_stream: str, + start_time: Optional[int] = None, + end_time: Optional[int] = None, ) -> None: handler = check.not_none( self._handler, "Can only consume logs within context manager scope." ) for events_batch in self._get_all_cloudwatch_events( - log_group=log_group, - log_stream=log_stream, - start_time=start_time, - end_time=end_time + log_group=log_group, log_stream=log_stream, start_time=start_time, end_time=end_time ): for event in events_batch: for log_line in event["message"].splitlines(): @@ -210,16 +219,14 @@ def no_messages_debug_text(self) -> str: return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly." def _get_all_cloudwatch_events( - self, - log_group: str, - log_stream: str, - start_time: Optional[int] = None, - end_time: Optional[int] = None - ) -> Generator[List[CloudWatchEvent], None, None]: - """ - Returns batches of CloudWatch events until the stream is complete or end_time. - """ - params = { + self, + log_group: str, + log_stream: str, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + ) -> Generator[List[CloudWatchEvent], None, None]: + """Returns batches of CloudWatch events until the stream is complete or end_time.""" + params: Dict[str, Any] = { "logGroupName": log_group, "logStreamName": log_stream, } @@ -229,18 +236,14 @@ def _get_all_cloudwatch_events( if end_time is not None: params["endTime"] = end_time - response = self.client.get_log_events( - **params - ) + response = self.client.get_log_events(**params) while events := response.get("events"): yield events params["nextToken"] = response["nextForwardToken"] - response = self.client.get_log_events( - **params - ) + response = self.client.get_log_events(**params) class PipesLambdaEventContextInjector(PipesEnvContextInjector): @@ -442,7 +445,6 @@ def run( try: response = self._client.start_job_run(**params) - except ClientError as err: context.log.error( "Couldn't create job %s. Here's why: %s: %s", @@ -454,8 +456,9 @@ def run( run_id = response["JobRunId"] - log_group = self._client.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"] - self._register_interruption_handler(context, job_name, run_id) + log_group = self._client.get_job_run(JobName=job_name, RunId=run_id)["JobRun"][ + "LogGroupName" + ] context.log.info(f"Started AWS Glue job {job_name} run: {run_id}") response = self._wait_for_job_run_completion(job_name, run_id) @@ -469,8 +472,9 @@ def run( if isinstance(self._message_reader, PipesGlueLogsMessageReader): # TODO: receive messages from a background thread in real-time - self._message_reader.consume_cloudwatch_logs(f"{log_group}/output", run_id, - start_time=int(start_timestamp)) + self._message_reader.consume_cloudwatch_logs( + f"{log_group}/output", run_id, start_time=int(start_timestamp) + ) return PipesClientCompletedInvocation(session)