Skip to content

Commit

Permalink
[prototype] lambda pipes client
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Nov 10, 2023
1 parent 525c3cc commit 57fa0ab
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 14 deletions.
31 changes: 20 additions & 11 deletions python_modules/dagster-pipes/dagster_pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import warnings
import zlib
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, abstractproperty
from contextlib import ExitStack, contextmanager
from io import StringIO
from queue import Queue
Expand Down Expand Up @@ -354,11 +354,6 @@ def _normalize_param_metadata(
return new_metadata


def _param_from_env_var(env_var: str) -> Any:
raw_value = os.environ.get(env_var)
return decode_env_var(raw_value) if raw_value is not None else None


def encode_env_var(value: Any) -> str:
"""Encode value by serializing to JSON, compressing with zlib, and finally encoding with base64.
`base64_encode(compress(to_json(value)))` in function notation.
Expand Down Expand Up @@ -726,18 +721,32 @@ def write_message(self, message: PipesMessage) -> None:
DAGSTER_PIPES_MESSAGES_ENV_VAR = "DAGSTER_PIPES_MESSAGES"


class PipesEnvVarParamsLoader(PipesParamsLoader):
"""Params loader that extracts params from environment variables."""
class PipesSourceParamsLoader(PipesParamsLoader):
    """Abstract params loader that extracts params from a Mapping source object.

    Subclasses only need to provide :py:attr:`source`; the lookup and decoding of the
    bootstrap params is implemented here on top of it.
    """

    # `abc.abstractproperty` is deprecated; stacking `property` on `abstractmethod`
    # is the supported way to declare an abstract read-only property.
    @property
    @abstractmethod
    def source(self) -> Mapping[str, str]:
        """The mapping from which the encoded params are read."""
        ...

    def is_dagster_pipes_process(self) -> bool:
        """Return True if the source carries the pipes context param."""
        # use the presence of DAGSTER_PIPES_CONTEXT to discern if we are in a pipes process
        return DAGSTER_PIPES_CONTEXT_ENV_VAR in self.source

    def load_context_params(self) -> PipesParams:
        """Decode and return the context params stored in the source.

        Raises whatever the source raises for a missing key (``KeyError`` for a
        plain mapping) when the context entry is absent.
        """
        raw_value = self.source[DAGSTER_PIPES_CONTEXT_ENV_VAR]
        return decode_env_var(raw_value)

    def load_messages_params(self) -> PipesParams:
        """Decode and return the messages params stored in the source.

        Raises whatever the source raises for a missing key (``KeyError`` for a
        plain mapping) when the messages entry is absent.
        """
        raw_value = self.source[DAGSTER_PIPES_MESSAGES_ENV_VAR]
        return decode_env_var(raw_value)


class PipesEnvVarParamsLoader(PipesSourceParamsLoader):
    """Params loader that extracts params from environment variables."""

    # Must derive from PipesSourceParamsLoader: that base implements
    # is_dagster_pipes_process / load_context_params / load_messages_params on top
    # of `source`. Deriving from PipesParamsLoader directly would leave `source`
    # unused and the loader methods unimplemented.
    @property
    def source(self) -> Mapping[str, str]:
        """The process environment (``os.environ`` itself, not a copy)."""
        return os.environ


# ########################
Expand Down
129 changes: 127 additions & 2 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
import base64
import json
import os
import random
import string
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional, Sequence
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence

import boto3
import dagster._check as check
from botocore.exceptions import ClientError
from dagster import PipesClient, ResourceParam
from dagster._annotations import experimental
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
PipesContextInjector,
PipesMessageReader,
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesLogReader
from dagster._core.pipes.context import PipesMessageHandler
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader,
PipesEnvContextInjector,
PipesLogReader,
extract_message_or_forward_to_stdout,
open_pipes_session,
)
from dagster_pipes import PipesDefaultMessageWriter

if TYPE_CHECKING:
from dagster_pipes import PipesContextData
Expand Down Expand Up @@ -102,3 +116,114 @@ def no_messages_debug_text(self) -> str:
" PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external"
" process."
)


@experimental
class PipesLambdaLogsMessageReader(PipesMessageReader):
    """Message reader that extracts pipes messages from the log tail of a lambda invocation.

    The external process is told (via the params yielded from :py:meth:`read_messages`)
    to write its messages to stderr; the invoking client then hands the invoke response
    to :py:meth:`consume_lambda_logs`, which parses the base64-encoded ``LogResult``.
    """

    # Only set while inside `read_messages`. The class-level default ensures that
    # calling `consume_lambda_logs` on an instance that never entered the context
    # manager fails with the intended check error rather than an AttributeError.
    _handler: Optional[PipesMessageHandler] = None

    @contextmanager
    def read_messages(
        self,
        handler: PipesMessageHandler,
    ) -> Iterator[PipesParams]:
        """Hold on to ``handler`` for the duration of the context.

        Yields:
            The params instructing the default message writer to use stderr.
        """
        self._handler = handler
        try:
            yield {PipesDefaultMessageWriter.STDIO_KEY: PipesDefaultMessageWriter.STDERR}
        finally:
            self._handler = None

    def consume_lambda_logs(self, response: Mapping[str, Any]) -> None:
        """Process every line of the log tail contained in a lambda invoke response.

        Args:
            response: The invoke response; its ``"LogResult"`` entry must hold the
                base64-encoded log tail.

        Raises:
            A check error if called outside the ``read_messages`` context.
        """
        handler = check.not_none(
            self._handler, "Can only consume logs within context manager scope."
        )

        # The tail is a byte-limited window over the logs, so it may begin in the
        # middle of a multi-byte UTF-8 sequence; replace undecodable bytes instead of
        # failing the whole invocation on a partial first character.
        log_result = base64.b64decode(response["LogResult"]).decode("utf-8", errors="replace")

        for log_line in log_result.splitlines():
            extract_message_or_forward_to_stdout(handler, log_line)

    def no_messages_debug_text(self) -> str:
        """Describe how this reader attempted to obtain messages."""
        return (
            "Attempted to read messages by extracting them from the tail of lambda logs directly."
        )


@experimental
class PipesLambdaEventContextInjector(PipesEnvContextInjector):
    """Env-style context injector used when the bootstrap params travel in the lambda event."""

    def no_messages_debug_text(self) -> str:
        """Describe how this injector attempted to deliver the context."""
        debug_text = "Attempted to inject context via the lambda event input."
        return debug_text


@experimental
class _PipesLambdaClient(PipesClient):
    """A pipes client for invoking AWS lambda.

    By default context is injected via the lambda input event and messages are parsed out
    of the 4k tail of logs returned with the invocation response. Alternative injectors
    and readers can be supplied through the constructor.

    Args:
        client (boto3.client): The boto lambda client used to call invoke.
        context_injector (Optional[PipesContextInjector]): A context injector to use to inject
            context into the lambda function. Defaults to :py:class:`PipesLambdaEventContextInjector`.
        message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
            from the lambda function. Defaults to :py:class:`PipesLambdaLogsMessageReader`.
    """

    def __init__(
        self,
        client: boto3.client,
        context_injector: Optional[PipesContextInjector] = None,
        message_reader: Optional[PipesMessageReader] = None,
    ):
        self._client = client
        # fall back to the lambda-specific defaults when nothing is supplied
        self._message_reader = message_reader or PipesLambdaLogsMessageReader()
        self._context_injector = context_injector or PipesLambdaEventContextInjector()

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    def run(
        self,
        *,
        function_name: str,
        event: Mapping[str, Any],
        context: OpExecutionContext,
    ):
        """Synchronously invoke a lambda function, enriched with the pipes protocol.

        Args:
            function_name (str): The name of the function to use.
            event (Mapping[str, Any]): A JSON serializable object to pass as input to the lambda.
            context (OpExecutionContext): The context of the currently executing Dagster op or asset.

        Returns:
            PipesClientCompletedInvocation: Wrapper around the results reported by the session.

        Raises:
            Exception: If the invoke response contains a ``FunctionError`` entry.
        """
        with open_pipes_session(
            context=context,
            message_reader=self._message_reader,
            context_injector=self._context_injector,
        ) as session:
            # the bootstrap params are merged on top of the user event, so they win on key clashes
            payload = json.dumps({**event, **session.get_bootstrap_env_vars()})
            response = self._client.invoke(
                FunctionName=function_name,
                InvocationType="RequestResponse",
                Payload=payload,
                LogType="Tail",
            )

            reader = self._message_reader
            if isinstance(reader, PipesLambdaLogsMessageReader):
                # the default reader is passive: it only sees messages once handed the response
                reader.consume_lambda_logs(response)

            if "FunctionError" in response:
                err_payload = json.loads(response["Payload"].read().decode("utf-8"))

                raise Exception(
                    f"Lambda Function Error ({response['FunctionError']}):\n{json.dumps(err_payload, indent=2)}"
                )

        # TODO: is there a way to also return the lambda's response payload?
        return PipesClientCompletedInvocation(tuple(session.get_results()))


PipesLambdaClient = ResourceParam[_PipesLambdaClient]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import base64
import io
import json
import os
import subprocess
import sys
import tempfile
import traceback
from typing import Any, Dict

from dagster_pipes import PipesSourceParamsLoader, open_dagster_pipes


class LambdaFunctions:
    """Namespace of fake lambda handlers, looked up by name by the script entry point.

    Every handler takes the ``(event, context)`` pair of a lambda handler.
    """

    @staticmethod
    def trunc_logs(event, context):
        """Write 3 KiB to stdout and 3 KiB to stderr (no trailing newline on either)."""
        for stream, fill in ((sys.stdout, "O"), (sys.stderr, "E")):
            stream.write(fill * 1024 * 3)

    @staticmethod
    def small_logs(event, context):
        """Print a single line made of ``event["size"]`` copies of ``"S"``."""
        line = "S" * event["size"]
        print(line)

    @staticmethod
    def pipes_basic(event, _lambda_context):
        """Open a pipes session bootstrapped from the event and report one materialization."""

        class EventParamsLoader(PipesSourceParamsLoader):
            # exposes the lambda event as the mapping the params are loaded from
            def __init__(self, source_event):
                self._source_event = source_event

            @property
            def source(self):
                return self._source_event

        loader = EventParamsLoader(event)
        with open_dagster_pipes(params_loader=loader) as dagster_context:
            dagster_context.report_asset_materialization(metadata={"meta": "data"})

    @staticmethod
    def error(event, _lambda_context):
        """Always fail, to exercise the function-error path."""
        raise Exception("boom")


class FakeLambdaContext:
    """Empty placeholder passed to handlers as their second (context) argument."""


LOG_TAIL_LIMIT = 4096


class FakeLambdaClient:
    """Stand-in for the boto lambda client that runs handlers in a subprocess."""

    def invoke(self, **kwargs):
        """Run the named handler via ``fake_lambda.py`` and build an invoke-style response.

        Emulates lambda constraints with a subprocess invocation:
          * the handler result is exchanged as a JSON-serialized "Payload"
          * at most ``LOG_TAIL_LIMIT`` bytes of log output are returned, base64
            encoded, as "LogResult" (only when ``LogType="Tail"`` is requested)

        Reads the ``Payload``, ``FunctionName`` and ``LogType`` keyword arguments.
        Exit code 42 from the subprocess is reported as an unhandled function error;
        any other non-zero exit code prints the logs and raises
        ``subprocess.CalledProcessError``.
        """
        with tempfile.TemporaryDirectory() as workdir:
            event_file = os.path.join(workdir, "in.json")
            result_file = os.path.join(workdir, "out.json")
            logs_file = os.path.join(workdir, "logs")

            with open(event_file, "w") as fh:
                fh.write(kwargs["Payload"])

            # stdout and stderr share one file so the tail covers both streams
            with open(logs_file, "w") as log_sink:
                proc = subprocess.run(
                    [
                        sys.executable,
                        os.path.join(os.path.dirname(__file__), "fake_lambda.py"),
                        kwargs["FunctionName"],
                        event_file,
                        result_file,
                    ],
                    check=False,
                    env={},  # env vars part of lambda fn definition, can't vary at runtime
                    stdout=log_sink,
                    stderr=log_sink,
                )

            response: Dict[str, Any] = {}

            exit_code = proc.returncode
            if exit_code == 42:
                # the entry point reserves 42 for "handler raised"
                response["FunctionError"] = "Unhandled"
            elif exit_code != 0:
                # unexpected failure of the harness itself: surface the logs, then raise
                with open(logs_file, "r") as fh:
                    print(fh.read())
                proc.check_returncode()

            with open(result_file, "rb") as fh:
                response["Payload"] = io.BytesIO(fh.read())

            if kwargs.get("LogType") == "Tail":
                total_log_bytes = os.path.getsize(logs_file)
                with open(logs_file, "rb") as fh:
                    if total_log_bytes > LOG_TAIL_LIMIT:
                        # keep only the last LOG_TAIL_LIMIT bytes
                        fh.seek(-LOG_TAIL_LIMIT, os.SEEK_END)
                    tail = fh.read()

                response["LogResult"] = base64.encodebytes(tail)

            return response


if __name__ == "__main__":
    # Entry point used by FakeLambdaClient: run one handler against the event read from
    # <in_path>, write its JSON result (or an error description) to <out_path>, and exit
    # with 0 on success or 42 when the handler raised.
    assert len(sys.argv) == 4, "python fake_lambda.py <fn_name> <in_path> <out_path>"
    _, fn_name, in_path, out_path = sys.argv

    # use a context manager so the input file handle is closed rather than leaked
    with open(in_path) as f:
        event = json.load(f)
    fn = getattr(LambdaFunctions, fn_name)

    val = None
    return_code = 0
    try:
        val = fn(event, FakeLambdaContext())
    except Exception as e:
        tb = traceback.TracebackException.from_exception(e)
        val = {
            "errorMessage": str(tb),
            # type(e) is the same class as the deprecated TracebackException.exc_type
            "errorType": type(e).__name__,
            "stackTrace": tb.stack.format(),
            "requestId": "fake-request-id",
        }
        # reserved exit code that the fake client maps to a function error
        return_code = 42

    with open(out_path, "w") as f:
        json.dump(val, f)

    sys.exit(return_code)
Loading

0 comments on commit 57fa0ab

Please sign in to comment.