diff --git a/python_modules/dagster-pipes/dagster_pipes/__init__.py b/python_modules/dagster-pipes/dagster_pipes/__init__.py index a6e651c47b10a..83927d254db55 100644 --- a/python_modules/dagster-pipes/dagster_pipes/__init__.py +++ b/python_modules/dagster-pipes/dagster_pipes/__init__.py @@ -7,7 +7,7 @@ import time import warnings import zlib -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from contextlib import ExitStack, contextmanager from io import StringIO from queue import Queue @@ -354,11 +354,6 @@ def _normalize_param_metadata( return new_metadata -def _param_from_env_var(env_var: str) -> Any: - raw_value = os.environ.get(env_var) - return decode_env_var(raw_value) if raw_value is not None else None - - def encode_env_var(value: Any) -> str: """Encode value by serializing to JSON, compressing with zlib, and finally encoding with base64. `base64_encode(compress(to_json(value)))` in function notation. @@ -726,18 +721,32 @@ def write_message(self, message: PipesMessage) -> None: DAGSTER_PIPES_MESSAGES_ENV_VAR = "DAGSTER_PIPES_MESSAGES" -class PipesEnvVarParamsLoader(PipesParamsLoader): - """Params loader that extracts params from environment variables.""" +class PipesSourceParamsLoader(PipesParamsLoader): + """Abstract params loader that extracts params from a Mapping source object.""" + + @abstractproperty + def source(self) -> Mapping[str, str]: + ... def is_dagster_pipes_process(self) -> bool: # use the presence of DAGSTER_PIPES_CONTEXT to discern if we are in a pipes process - return DAGSTER_PIPES_CONTEXT_ENV_VAR in os.environ + return DAGSTER_PIPES_CONTEXT_ENV_VAR in self.source def load_context_params(self) -> PipesParams: - return _param_from_env_var(DAGSTER_PIPES_CONTEXT_ENV_VAR) + raw_value = self.source[DAGSTER_PIPES_CONTEXT_ENV_VAR] + return decode_env_var(raw_value) def load_messages_params(self) -> PipesParams: - return _param_from_env_var(DAGSTER_PIPES_MESSAGES_ENV_VAR) + raw_value = self.source[DAGSTER_PIPES_MESSAGES_ENV_VAR] + return decode_env_var(raw_value) + + +class PipesEnvVarParamsLoader(PipesParamsLoader): + """Params loader that extracts params from environment variables.""" + + @property + def source(self): + return os.environ # ######################## diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py index 2e6ea6c8fc68a..1531e01b1f153 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py @@ -1,18 +1,32 @@ +import base64 import json import os import random import string from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional, Sequence +from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence import boto3 import dagster._check as check from botocore.exceptions import ClientError +from dagster import PipesClient, ResourceParam +from dagster._annotations import experimental +from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.pipes.client import ( + PipesClientCompletedInvocation, PipesContextInjector, + PipesMessageReader, PipesParams, ) -from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesLogReader +from dagster._core.pipes.context import PipesMessageHandler +from dagster._core.pipes.utils import ( + PipesBlobStoreMessageReader, + PipesEnvContextInjector, + PipesLogReader, + extract_message_or_forward_to_stdout, + open_pipes_session, +) +from dagster_pipes import PipesDefaultMessageWriter if TYPE_CHECKING: from dagster_pipes import PipesContextData @@ -102,3 +116,114 @@ def no_messages_debug_text(self) -> str: " PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external" " process." ) + + +@experimental +class PipesLambdaLogsMessageReader(PipesMessageReader): + @contextmanager + def read_messages( + self, + handler: PipesMessageHandler, + ) -> Iterator[PipesParams]: + self._handler = handler + try: + yield {PipesDefaultMessageWriter.STDIO_KEY: PipesDefaultMessageWriter.STDERR} + finally: + self._handler = None + + def consume_lambda_logs(self, response) -> None: + handler = check.not_none( + self._handler, "Can only consume logs within context manager scope." + ) + + log_result = base64.b64decode(response["LogResult"]).decode("utf-8") + + for log_line in log_result.splitlines(): + extract_message_or_forward_to_stdout(handler, log_line) + + def no_messages_debug_text(self) -> str: + return ( + "Attempted to read messages by extracting them from the tail of lambda logs directly." + ) + + +@experimental +class PipesLambdaEventContextInjector(PipesEnvContextInjector): + def no_messages_debug_text(self) -> str: + return "Attempted to inject context via the lambda event input." + + +@experimental +class _PipesLambdaClient(PipesClient): + """A pipes client for invoking AWS lambda. + + By default context is injected via the lambda input event and messages are parsed out of the + 4k tail of logs. S3 + + Args: + client (boto3.client): The boto lambda client used to call invoke. + context_injector (Optional[PipesContextInjector]): A context injector to use to inject + context into the lambda function. Defaults to :py:class:`PipesLambdaEventContextInjector`. + message_reader (Optional[PipesMessageReader]): A message reader to use to read messages + from the lambda function. Defaults to :py:class:`PipesLambdaLogsMessageReader`. + """ + + def __init__( + self, + client: boto3.client, + context_injector: Optional[PipesContextInjector] = None, + message_reader: Optional[PipesMessageReader] = None, + ): + self._client = client + self._message_reader = message_reader or PipesLambdaLogsMessageReader() + self._context_injector = context_injector or PipesLambdaEventContextInjector() + + @classmethod + def _is_dagster_maintained(cls) -> bool: + return True + + def run( + self, + *, + function_name: str, + event: Mapping[str, Any], + context: OpExecutionContext, + ): + """Synchronously invoke a lambda function, enriched with the pipes protocol. + + Args: + function_name (str): The name of the function to use. + event (Mapping[str, Any]): A JSON serializable object to pass as input to the lambda. + context (OpExecutionContext): The context of the currently executing Dagster op or asset. + """ + with open_pipes_session( + context=context, + message_reader=self._message_reader, + context_injector=self._context_injector, + ) as session: + response = self._client.invoke( + FunctionName=function_name, + InvocationType="RequestResponse", + Payload=json.dumps( + { + **event, + **session.get_bootstrap_env_vars(), + } + ), + LogType="Tail", + ) + if isinstance(self._message_reader, PipesLambdaLogsMessageReader): + self._message_reader.consume_lambda_logs(response) + + if "FunctionError" in response: + err_payload = json.loads(response["Payload"].read().decode("utf-8")) + + raise Exception( + f"Lambda Function Error ({response['FunctionError']}):\n{json.dumps(err_payload, indent=2)}" + ) + + # way to return the payload? + return PipesClientCompletedInvocation(tuple(session.get_results())) + + +PipesLambdaClient = ResourceParam[_PipesLambdaClient] diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_lambda.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_lambda.py new file mode 100644 index 0000000000000..45c4effe7a330 --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_lambda.py @@ -0,0 +1,132 @@ +import base64 +import io +import json +import os +import subprocess +import sys +import tempfile +import traceback +from typing import Any, Dict + +from dagster_pipes import PipesSourceParamsLoader, open_dagster_pipes + + +class LambdaFunctions: + @staticmethod + def trunc_logs(event, context): + sys.stdout.write("O" * 1024 * 3) + sys.stderr.write("E" * 1024 * 3) + + @staticmethod + def small_logs(event, context): + print("S" * event["size"]) + + @staticmethod + def pipes_basic(event, _lambda_context): + class LambdaEventLoader(PipesSourceParamsLoader): + def __init__(self, event): + self._event = event + + @property + def source(self): + return self._event + + with open_dagster_pipes(params_loader=LambdaEventLoader(event)) as dagster_context: + dagster_context.report_asset_materialization(metadata={"meta": "data"}) + + @staticmethod + def error(event, _lambda_context): + raise Exception("boom") + + +class FakeLambdaContext: + pass + + +LOG_TAIL_LIMIT = 4096 + + +class FakeLambdaClient: + def invoke(self, **kwargs): + # emulate lambda constraints with a subprocess invocation + # * json serialized "Payload" result + # * 4k log output as base64 "LogResult" + + with tempfile.TemporaryDirectory() as tempdir: + in_path = os.path.join(tempdir, "in.json") + out_path = os.path.join(tempdir, "out.json") + log_path = os.path.join(tempdir, "logs") + + with open(in_path, "w") as f: + f.write(kwargs["Payload"]) + + with open(log_path, "w") as log_file: + result = subprocess.run( + [ + sys.executable, + os.path.join(os.path.dirname(__file__), "fake_lambda.py"), + kwargs["FunctionName"], + in_path, + out_path, + ], + check=False, + env={}, # env vars part of lambda fn definition, can't vary at runtime + stdout=log_file, + stderr=log_file, + ) + + response: Dict[str, Any] = {} + + if result.returncode == 42: + response["FunctionError"] = "Unhandled" + + elif result.returncode != 0: + with open(log_path, "r") as f: + print(f.read()) + result.check_returncode() + + with open(out_path, "rb") as f: + payload = io.BytesIO(f.read()) + + response["Payload"] = payload + + if kwargs.get("LogType") == "Tail": + logs_len = os.path.getsize(log_path) + with open(log_path, "rb") as log_file: + if logs_len > LOG_TAIL_LIMIT: + log_file.seek(-LOG_TAIL_LIMIT, os.SEEK_END) + + outro = log_file.read() + + log_result = base64.encodebytes(outro) + + response["LogResult"] = log_result + + return response + + +if __name__ == "__main__": + assert len(sys.argv) == 4, "python fake_lambda.py " + _, fn_name, in_path, out_path = sys.argv + + event = json.load(open(in_path)) + fn = getattr(LambdaFunctions, fn_name) + + val = None + return_code = 0 + try: + val = fn(event, FakeLambdaContext()) + except Exception as e: + tb = traceback.TracebackException.from_exception(e) + val = { + "errorMessage": str(tb), + "errorType": tb.exc_type.__name__, + "stackTrace": tb.stack.format(), + "requestId": "fake-request-id", + } + return_code = 42 + + with open(out_path, "w") as f: + json.dump(val, f) + + sys.exit(return_code) diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py new file mode 100644 index 0000000000000..b79590fe3af15 --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py @@ -0,0 +1,103 @@ +import base64 +import json + +import pytest +from dagster import asset, materialize, open_pipes_session +from dagster._core.pipes.utils import PipesEnvContextInjector +from dagster_aws.pipes import PipesLambdaClient, PipesLambdaLogsMessageReader + +from .fake_lambda import LOG_TAIL_LIMIT, FakeLambdaClient, LambdaFunctions + + +def test_fake_lambda_logs(): + event = {} + response = FakeLambdaClient().invoke( + FunctionName=LambdaFunctions.trunc_logs.__name__, + InvocationType="RequestResponse", + Payload=json.dumps(event), + LogType="Tail", + ) + + log_result = base64.b64decode(response["LogResult"]) + assert len(log_result) == LOG_TAIL_LIMIT + + small_size = 512 + response = FakeLambdaClient().invoke( + FunctionName=LambdaFunctions.small_logs.__name__, + InvocationType="RequestResponse", + Payload=json.dumps({"size": small_size}), + LogType="Tail", + ) + + log_result = base64.b64decode(response["LogResult"]) + assert len(log_result) == small_size + 1 # size + \n + + +def test_manual_fake_lambda_pipes(): + @asset + def fake_lambda_asset(context): + context_injector = PipesEnvContextInjector() + message_reader = PipesLambdaLogsMessageReader() + + with open_pipes_session( + context=context, + message_reader=message_reader, + context_injector=context_injector, + ) as session: + user_event = {} + response = FakeLambdaClient().invoke( + FunctionName=LambdaFunctions.pipes_basic.__name__, + InvocationType="RequestResponse", + Payload=json.dumps( + { + **user_event, + **session.get_bootstrap_env_vars(), + } + ), + LogType="Tail", + ) + message_reader.consume_lambda_logs(response) + yield from session.get_results() + + result = materialize([fake_lambda_asset]) + assert result.success + mat_evts = result.get_asset_materialization_events() + assert len(mat_evts) == 1 + assert mat_evts[0].materialization.metadata["meta"].value == "data" + + +def test_fake_client_lambda_pipes(): + @asset + def fake_lambda_asset(context): + return ( + PipesLambdaClient(FakeLambdaClient()) + .run( + context=context, + function_name=LambdaFunctions.pipes_basic.__name__, + event={}, + ) + .get_materialize_result() + ) + + result = materialize([fake_lambda_asset]) + assert result.success + mat_evts = result.get_asset_materialization_events() + assert len(mat_evts) == 1 + assert mat_evts[0].materialization.metadata["meta"].value == "data" + + +def test_fake_client_lambda_error(): + @asset + def fake_lambda_asset(context): + yield from ( + PipesLambdaClient(FakeLambdaClient()) + .run( + context=context, + function_name=LambdaFunctions.error.__name__, + event={}, + ) + .get_results() + ) + + with pytest.raises(Exception, match="Lambda Function Error"): + materialize([fake_lambda_asset]) diff --git a/python_modules/libraries/dagster-docker/dagster_docker/pipes.py b/python_modules/libraries/dagster-docker/dagster_docker/pipes.py index 694daf1842ad8..e8d7df2ca8d9b 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/pipes.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/pipes.py @@ -75,7 +75,7 @@ class _PipesDockerClient(PipesClient): the docker client. context_injector (Optional[PipesContextInjector]): A context injector to use to inject context into the docker container process. Defaults to :py:class:`PipesEnvContextInjector`. - message_reader (Optional[PipesContextInjector]): A message reader to use to read messages + message_reader (Optional[PipesMessageReader]): A message reader to use to read messages from the docker container process. Defaults to :py:class:`DockerLogsMessageReader`. """