Skip to content

Commit

Permalink
[prototype] lambda pipes client
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Nov 29, 2023
1 parent ef4166e commit 3746da4
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 22 deletions.
73 changes: 58 additions & 15 deletions python_modules/dagster-pipes/dagster_pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,6 @@ def _normalize_param_metadata(
return new_metadata


def _param_from_env_var(env_var: str) -> Any:
raw_value = os.environ.get(env_var)
return decode_env_var(raw_value) if raw_value is not None else None


def encode_env_var(value: Any) -> str:
"""Encode value by serializing to JSON, compressing with zlib, and finally encoding with base64.
`base64_encode(compress(to_json(value)))` in function notation.
Expand Down Expand Up @@ -675,6 +670,7 @@ class PipesDefaultMessageWriter(PipesMessageWriter):

FILE_PATH_KEY = "path"
STDIO_KEY = "stdio"
BUFFERED_STDIO_KEY = "buffered_stdio"
STDERR = "stderr"
STDOUT = "stdout"

Expand All @@ -683,17 +679,34 @@ def open(self, params: PipesParams) -> Iterator[PipesMessageWriterChannel]:
if self.FILE_PATH_KEY in params:
path = _assert_env_param_type(params, self.FILE_PATH_KEY, str, self.__class__)
yield PipesFileMessageWriterChannel(path)

elif self.STDIO_KEY in params:
stream = _assert_env_param_type(params, self.STDIO_KEY, str, self.__class__)
if stream == self.STDERR:
yield PipesStreamMessageWriterChannel(sys.stderr)
elif stream == self.STDOUT:
yield PipesStreamMessageWriterChannel(sys.stdout)
else:
if stream not in (self.STDERR, self.STDOUT):
raise DagsterPipesError(
f'Invalid value for key "std", expected "{self.STDERR}" or "{self.STDOUT}" but'
f" received {stream}"
)

target = sys.stderr if stream == self.STDERR else sys.stdout

yield PipesStreamMessageWriterChannel(target)

elif self.BUFFERED_STDIO_KEY in params:
stream = _assert_env_param_type(params, self.BUFFERED_STDIO_KEY, str, self.__class__)
if stream not in (self.STDERR, self.STDOUT):
raise DagsterPipesError(
f'Invalid value for key "std", expected "{self.STDERR}" or "{self.STDOUT}" but'
f" received {stream}"
)

target = sys.stderr if stream == self.STDERR else sys.stdout
channel = PipesBufferedStreamMessageWriterChannel(target)
try:
yield channel
finally:
channel.flush()

else:
raise DagsterPipesError(
f'Invalid params for {self.__class__.__name__}, expected key "path" or "std",'
Expand Down Expand Up @@ -722,22 +735,52 @@ def write_message(self, message: PipesMessage) -> None:
self._stream.writelines((json.dumps(message), "\n"))


class PipesBufferedStreamMessageWriterChannel(PipesMessageWriterChannel):
"""Message writer channel that buffers messages and then writes them all out to a
`TextIO` stream on close.
"""

def __init__(self, stream: TextIO):
self._buffer = []
self._stream = stream

def write_message(self, message: PipesMessage) -> None:
self._buffer.append(message)

def flush(self):
for message in self._buffer:
self._stream.writelines((json.dumps(message), "\n"))
self._buffer = []


DAGSTER_PIPES_CONTEXT_ENV_VAR = "DAGSTER_PIPES_CONTEXT"
DAGSTER_PIPES_MESSAGES_ENV_VAR = "DAGSTER_PIPES_MESSAGES"


class PipesEnvVarParamsLoader(PipesParamsLoader):
"""Params loader that extracts params from environment variables."""
class PipesMappingParamsLoader(PipesParamsLoader):
"""Params loader that extracts params from a Mapping provided at init time."""

def __init__(self, mapping: Mapping[str, str]):
self._mapping = mapping

def is_dagster_pipes_process(self) -> bool:
# use the presence of DAGSTER_PIPES_CONTEXT to discern if we are in a pipes process
return DAGSTER_PIPES_CONTEXT_ENV_VAR in os.environ
return DAGSTER_PIPES_CONTEXT_ENV_VAR in self._mapping

def load_context_params(self) -> PipesParams:
return _param_from_env_var(DAGSTER_PIPES_CONTEXT_ENV_VAR)
raw_value = self._mapping[DAGSTER_PIPES_CONTEXT_ENV_VAR]
return decode_env_var(raw_value)

def load_messages_params(self) -> PipesParams:
return _param_from_env_var(DAGSTER_PIPES_MESSAGES_ENV_VAR)
raw_value = self._mapping[DAGSTER_PIPES_MESSAGES_ENV_VAR]
return decode_env_var(raw_value)


class PipesEnvVarParamsLoader(PipesMappingParamsLoader):
"""Params loader that extracts params from environment variables."""

def __init__(self):
super().__init__(mapping=os.environ)


# ########################
Expand Down
144 changes: 142 additions & 2 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
import base64
import json
import os
import random
import string
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional, Sequence
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence

import boto3
import dagster._check as check
from botocore.exceptions import ClientError
from dagster import PipesClient, ResourceParam
from dagster._annotations import experimental
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
PipesContextInjector,
PipesMessageReader,
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesLogReader
from dagster._core.pipes.context import PipesMessageHandler
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader,
PipesEnvContextInjector,
PipesLogReader,
extract_message_or_forward_to_stdout,
open_pipes_session,
)
from dagster_pipes import PipesDefaultMessageWriter

if TYPE_CHECKING:
from dagster_pipes import PipesContextData
Expand Down Expand Up @@ -102,3 +116,129 @@ def no_messages_debug_text(self) -> str:
" PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external"
" process."
)


@experimental
class PipesLambdaLogsMessageReader(PipesMessageReader):
"""Message reader that consumes buffered pipes messages that were flushed on exit from the
final 4k of logs that are returned from issuing a sync lambda invocation.
Limitations: If the volume of pipes messages exceeds 4k, messages will be lost and it is
recommended to switch to PipesS3MessageWriter & PipesS3MessageReader.
"""

@contextmanager
def read_messages(
self,
handler: PipesMessageHandler,
) -> Iterator[PipesParams]:
self._handler = handler
try:
# use buffered stdio to shift the pipes messages to the tail of logs
yield {PipesDefaultMessageWriter.BUFFERED_STDIO_KEY: PipesDefaultMessageWriter.STDERR}
finally:
self._handler = None

def consume_lambda_logs(self, response) -> None:
handler = check.not_none(
self._handler, "Can only consume logs within context manager scope."
)

log_result = base64.b64decode(response["LogResult"]).decode("utf-8")

for log_line in log_result.splitlines():
extract_message_or_forward_to_stdout(handler, log_line)

def no_messages_debug_text(self) -> str:
return (
"Attempted to read messages by extracting them from the tail of lambda logs directly."
)


@experimental
class PipesLambdaEventContextInjector(PipesEnvContextInjector):
def no_messages_debug_text(self) -> str:
return "Attempted to inject context via the lambda event input."


@experimental
class _PipesLambdaClient(PipesClient):
"""A pipes client for invoking AWS lambda.
By default context is injected via the lambda input event and messages are parsed out of the
4k tail of logs. S3
Args:
client (boto3.client): The boto lambda client used to call invoke.
context_injector (Optional[PipesContextInjector]): A context injector to use to inject
context into the lambda function. Defaults to :py:class:`PipesLambdaEventContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the lambda function. Defaults to :py:class:`PipesLambdaLogsMessageReader`.
"""

def __init__(
self,
client: boto3.client,
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
):
self._client = client
self._message_reader = message_reader or PipesLambdaLogsMessageReader()
self._context_injector = context_injector or PipesLambdaEventContextInjector()

@classmethod
def _is_dagster_maintained(cls) -> bool:
return True

def run(
self,
*,
function_name: str,
event: Mapping[str, Any],
context: OpExecutionContext,
):
"""Synchronously invoke a lambda function, enriched with the pipes protocol.
Args:
function_name (str): The name of the function to use.
event (Mapping[str, Any]): A JSON serializable object to pass as input to the lambda.
context (OpExecutionContext): The context of the currently executing Dagster op or asset.
"""
with open_pipes_session(
context=context,
message_reader=self._message_reader,
context_injector=self._context_injector,
) as session:
other_kwargs = {}
if isinstance(self._message_reader, PipesLambdaLogsMessageReader):
other_kwargs["LogType"] = "Tail"

if isinstance(self._context_injector, PipesLambdaEventContextInjector):
payload_data = {
**event,
**session.get_bootstrap_env_vars(),
}
else:
payload_data = event

response = self._client.invoke(
FunctionName=function_name,
InvocationType="RequestResponse",
Payload=json.dumps(payload_data),
**other_kwargs,
)
if isinstance(self._message_reader, PipesLambdaLogsMessageReader):
self._message_reader.consume_lambda_logs(response)

if "FunctionError" in response:
err_payload = json.loads(response["Payload"].read().decode("utf-8"))

raise Exception(
f"Lambda Function Error ({response['FunctionError']}):\n{json.dumps(err_payload, indent=2)}"
)

# should probably have a way to return the lambda result payload
return PipesClientCompletedInvocation(tuple(session.get_results()))


PipesLambdaClient = ResourceParam[_PipesLambdaClient]
Empty file.
Loading

0 comments on commit 3746da4

Please sign in to comment.