diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
index 012ea9d98acab..6bd7b35bcf6f3 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -5,7 +5,19 @@
 import string
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Dict, Iterator, Literal, Mapping, Optional, Sequence
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    TypedDict,
+)
 
 import boto3
 import dagster._check as check
@@ -157,10 +169,22 @@ def no_messages_debug_text(self) -> str:
         )
 
 
+class CloudWatchEvent(TypedDict):
+    timestamp: int
+    message: str
+    ingestionTime: int
+
+
 @experimental
 class PipesCloudWatchMessageReader(PipesMessageReader):
     """Message reader that consumes AWS CloudWatch logs to read pipes messages."""
 
+    def __init__(self, client: Optional[boto3.client] = None):
+        """Args:
+            client (Optional[boto3.client]): boto3 CloudWatch logs client. Defaults to
+                ``boto3.client("logs")``.
+        """
+        self.client = client or boto3.client("logs")
+
     @contextmanager
     def read_messages(
         self,
@@ -174,13 +198,53 @@ def read_messages(
         self._handler = None
 
     def consume_cloudwatch_logs(
-        self, client: boto3.client, log_group: str, log_stream: str
+        self,
+        log_group: str,
+        log_stream: str,
+        start_time: Optional[int] = None,
+        end_time: Optional[int] = None,
     ) -> None:
-        raise NotImplementedError("CloudWatch logs are not yet supported in the pipes protocol.")
+        handler = check.not_none(
+            self._handler, "Can only consume logs within context manager scope."
+        )
+
+        for events_batch in self._get_all_cloudwatch_events(
+            log_group=log_group, log_stream=log_stream, start_time=start_time, end_time=end_time
+        ):
+            for event in events_batch:
+                for log_line in event["message"].splitlines():
+                    extract_message_or_forward_to_stdout(handler, log_line)
 
     def no_messages_debug_text(self) -> str:
         return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly."
 
+    def _get_all_cloudwatch_events(
+        self,
+        log_group: str,
+        log_stream: str,
+        start_time: Optional[int] = None,
+        end_time: Optional[int] = None,
+    ) -> Generator[List[CloudWatchEvent], None, None]:
+        """Yields batches of CloudWatch events until the stream is exhausted or end_time is reached."""
+        params: Dict[str, Any] = {
+            "logGroupName": log_group,
+            "logStreamName": log_stream,
+        }
+
+        if start_time is not None:
+            params["startTime"] = start_time
+        if end_time is not None:
+            params["endTime"] = end_time
+
+        response = self.client.get_log_events(**params)
+
+        while events := response.get("events"):
+            yield events
+
+            params["nextToken"] = response["nextForwardToken"]
+
+            response = self.client.get_log_events(**params)
+
 
 class PipesLambdaEventContextInjector(PipesEnvContextInjector):
     def no_messages_debug_text(self) -> str:
@@ -203,11 +267,11 @@ class PipesLambdaClient(PipesClient, TreatAsResourceParam):
 
     def __init__(
         self,
-        client: boto3.client,
+        client: Optional[boto3.client] = None,
         context_injector: Optional[PipesContextInjector] = None,
         message_reader: Optional[PipesMessageReader] = None,
     ):
-        self._client = client
+        self._client = client or boto3.client("lambda")
         self._message_reader = message_reader or PipesLambdaLogsMessageReader()
         self._context_injector = context_injector or PipesLambdaEventContextInjector()
 
@@ -272,12 +336,11 @@ def run(
 
 class PipesGlueContextInjector(PipesS3ContextInjector):
     def no_messages_debug_text(self) -> str:
-        return "Attempted to inject context via Glue job arguments."
+        return "Attempted to inject context via Glue job Arguments."
 
 
 class PipesGlueLogsMessageReader(PipesCloudWatchMessageReader):
-    def no_messages_debug_text(self) -> str:
-        return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly."
+    pass
 
 
 @experimental
@@ -285,22 +348,22 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
     """A pipes client for invoking AWS Glue jobs.
 
     Args:
-        client (boto3.client): The boto Glue client used to call invoke.
         context_injector (Optional[PipesContextInjector]): A context injector to use to inject
             context into the Glue job, for example, :py:class:`PipesGlueContextInjector`.
         message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
            from the glue job run. Defaults to :py:class:`PipesGlueLogsMessageReader`.
+        client (Optional[boto3.client]): The boto Glue client used to launch the Glue job.
     """
 
     def __init__(
         self,
-        client: boto3.client,
         context_injector: PipesContextInjector,
         message_reader: Optional[PipesMessageReader] = None,
+        client: Optional[boto3.client] = None,
     ):
-        self._client = client
+        self._client = client or boto3.client("glue")
         self._context_injector = context_injector
-        self._message_reader = message_reader or PipesCloudWatchMessageReader()
+        self._message_reader = message_reader or PipesGlueLogsMessageReader()
 
     @classmethod
     def _is_dagster_maintained(cls) -> bool:
@@ -377,19 +440,10 @@ def run(
         # so we need to filter them out
         params = {k: v for k, v in params.items() if v is not None}
 
-        try:
-            response = self._client.start_job_run(**params)
-            run_id = response["JobRunId"]
-            context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")
-            response = self._wait_for_job_run_completion(job_name, run_id)
-
-            if response["JobRun"]["JobRunState"] == "FAILED":
-                raise RuntimeError(
-                    f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
-                )
-            else:
-                context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
+        start_timestamp = time.time() * 1000  # Unix time in ms
+        try:
+            run_id = self._client.start_job_run(**params)["JobRunId"]
         except ClientError as err:
             context.log.error(
                 "Couldn't create job %s. Here's why: %s: %s",
@@ -399,11 +453,27 @@ def run(
             )
             raise
 
-        # TODO: get logs from CloudWatch. there are 2 separate streams for stdout and driver stderr to read from
-        # the log group can be found in the response from start_job_run, and the log stream is the job run id
-        # worker logs have log streams like: _ but we probably don't need to read those
+        response = self._client.get_job_run(JobName=job_name, RunId=run_id)
+        log_group = response["JobRun"]["LogGroupName"]
+        context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")
+
+        response = self._wait_for_job_run_completion(job_name, run_id)
+
+        if response["JobRun"]["JobRunState"] == "FAILED":
+            raise RuntimeError(
+                f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
+            )
+        else:
+            context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
+
+        if isinstance(self._message_reader, PipesCloudWatchMessageReader):
+            # TODO: consume messages in real-time via a background thread
+            # so we don't have to wait for the job run to complete
+            # before receiving any logs
+            self._message_reader.consume_cloudwatch_logs(
+                f"{log_group}/output", run_id, start_time=int(start_timestamp)
+            )
 
-        # should probably have a way to return the lambda result payload
         return PipesClientCompletedInvocation(session)
 
     def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
index 77812e618336e..744b30e549353 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
@@ -1,24 +1,37 @@
 import subprocess
 import sys
 import tempfile
-from typing import Dict, Optional
+import time
+from typing import Dict, Literal, Optional
 
 import boto3
 
 
 class LocalGlueMockClient:
-    def __init__(self, s3_client: boto3.client, glue_client: boto3.client):
+    def __init__(
+        self,
+        aws_endpoint_url: str,  # usually received from moto
+        s3_client: boto3.client,
+        glue_client: boto3.client,
+        pipes_messages_backend: Literal["s3", "cloudwatch"],
+        cloudwatch_client: Optional[boto3.client] = None,
+    ):
         """This class wraps moto3 clients for S3 and Glue, and provides a way to "run" Glue jobs locally.
         This is necessary because moto3 does not actually run anything when you start a Glue job, so we won't be able
         to receive any Dagster messages from it.
+        If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
+        as if it had been done by Glue.
         """
+        self.aws_endpoint_url = aws_endpoint_url
         self.s3_client = s3_client
         self.glue_client = glue_client
+        self.pipes_messages_backend = pipes_messages_backend
+        self.cloudwatch_client = cloudwatch_client
 
     def get_job_run(self, *args, **kwargs):
         return self.glue_client.get_job_run(*args, **kwargs)
 
-    def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args, **kwargs):
+    def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
         params = {
             "JobName": JobName,
         }
@@ -44,12 +57,45 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], *args, **kwargs):
         result = subprocess.run(
             [sys.executable, f.name, *args],
             check=False,
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
+            env={
+                "AWS_ENDPOINT_URL": self.aws_endpoint_url,
+                "TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
+            },
+            capture_output=True,
         )
 
         # mock the job run with moto
         response = self.glue_client.start_job_run(**params)
+        job_run_id = response["JobRunId"]
+
+        job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
+        log_group = job_run_response["JobRun"]["LogGroupName"]
+
+        if self.pipes_messages_backend == "cloudwatch":
+            assert (
+                self.cloudwatch_client is not None
+            ), "cloudwatch_client must be provided when using the cloudwatch messages backend"
+
+            self.cloudwatch_client.create_log_group(
+                logGroupName=f"{log_group}/output",
+            )
+
+            self.cloudwatch_client.create_log_stream(
+                logGroupName=f"{log_group}/output",
+                logStreamName=job_run_id,
+            )
+
+            # uploading log lines one by one is good enough for tests
+            for line in result.stderr.decode().splitlines():
+                if line:
+                    self.cloudwatch_client.put_log_events(
+                        logGroupName=f"{log_group}/output",  # yes, Glue routes stderr to /output
+                        logStreamName=job_run_id,
+                        logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
+                    )
+                    # sleep so consecutive events get distinct ms timestamps and are
+                    # filtered correctly by start_time on the next read
+                    time.sleep(0.01)
 
         # replace run state with actual results
         response["JobRun"] = {}
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
index 0bcc0f7a26f19..d516cee6270ea 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
@@ -6,7 +6,7 @@
 import textwrap
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Literal
 
 import boto3
 import pytest
@@ -24,6 +24,7 @@
 from dagster._core.pipes.utils import PipesEnvContextInjector
 from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
 from dagster_aws.pipes import (
+    PipesCloudWatchMessageReader,
     PipesGlueClient,
     PipesLambdaClient,
     PipesLambdaLogsMessageReader,
@@ -283,6 +284,7 @@ def fake_lambda_asset(context):
 def external_s3_glue_script(s3_client) -> Iterator[str]:
     # This is called in an external process and so cannot access outer scope
     def script_fn():
+        import os
         import time
 
         import boto3
@@ -293,12 +295,18 @@ def script_fn():
             open_dagster_pipes,
         )
 
-        client = boto3.client("s3", region_name="us-east-1", endpoint_url="http://localhost:5193")
-        context_loader = PipesS3ContextLoader(client=client)
-        message_writer = PipesS3MessageWriter(client, interval=0.001)
+        s3_client = boto3.client(
+            "s3", region_name="us-east-1", endpoint_url="http://localhost:5193"
+        )
+
+        messages_backend = os.environ["TESTING_PIPES_MESSAGES_BACKEND"]
+        if messages_backend == "s3":
+            message_writer = PipesS3MessageWriter(s3_client, interval=0.001)
+        else:
+            message_writer = None
 
         with open_dagster_pipes(
-            context_loader=context_loader,
+            context_loader=PipesS3ContextLoader(client=s3_client),
             message_writer=message_writer,
             params_loader=PipesCliArgsParamsLoader(),
         ) as context:
@@ -341,9 +349,25 @@ def glue_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client
     return client
 
 
-def test_glue_s3_pipes(capsys, s3_client, glue_client):
+@pytest.fixture
+def cloudwatch_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client:
+    return boto3.client("logs", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL)
+
+
+@pytest.mark.parametrize("pipes_messages_backend", ["s3", "cloudwatch"])
+def test_glue_pipes(
+    capsys,
+    s3_client,
+    glue_client,
+    cloudwatch_client,
+    pipes_messages_backend: Literal["s3", "cloudwatch"],
+):
     context_injector = PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client)
-    message_reader = PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
+    message_reader = (
+        PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
+        if pipes_messages_backend == "s3"
+        else PipesCloudWatchMessageReader(client=cloudwatch_client)
+    )
 
     @asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
     def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
@@ -355,7 +379,13 @@ def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
         return results
 
     pipes_glue_client = PipesGlueClient(
-        client=LocalGlueMockClient(glue_client=glue_client, s3_client=s3_client),
+        client=LocalGlueMockClient(
+            aws_endpoint_url=_MOTO_SERVER_URL,
+            glue_client=glue_client,
+            s3_client=s3_client,
+            cloudwatch_client=cloudwatch_client,
+            pipes_messages_backend=pipes_messages_backend,
+        ),
        context_injector=context_injector,
        message_reader=message_reader,
    )
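
A minimal usage sketch of how the pieces added above fit together, for context. The job name, bucket name, and asset name are hypothetical, and the run() keyword arguments are assumed to mirror the start_job_run params that PipesGlueClient.run builds internally:

import boto3
from dagster import AssetExecutionContext, Definitions, asset
from dagster_aws.pipes import (
    PipesCloudWatchMessageReader,
    PipesGlueClient,
    PipesS3ContextInjector,
)


@asset
def glue_pipes_asset(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
    # run() starts the Glue job, waits for completion, and then (with a CloudWatch
    # message reader) replays the run's <log_group>/output stream to extract Pipes messages
    return pipes_glue_client.run(
        context=context,
        job_name="my-glue-job",  # hypothetical job name
    ).get_materialize_result()


defs = Definitions(
    assets=[glue_pipes_asset],
    resources={
        "pipes_glue_client": PipesGlueClient(
            context_injector=PipesS3ContextInjector(
                bucket="my-pipes-bucket",  # hypothetical bucket
                client=boto3.client("s3"),
            ),
            # the explicit client is optional now that the reader defaults
            # to boto3.client("logs")
            message_reader=PipesCloudWatchMessageReader(boto3.client("logs")),
        )
    },
)

Because consume_cloudwatch_logs is only invoked after _wait_for_job_run_completion returns, messages arrive in a single batch at the end of the run; the TODO in run() tracks moving this to a background thread for real-time delivery.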