Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Oct 6, 2023
1 parent 71c5bc1 commit 942fee6
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 54 deletions.
29 changes: 14 additions & 15 deletions python_modules/dagster-pipes/dagster_pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from io import StringIO
from threading import Event, Lock, Thread
from queue import Queue
from threading import Event, Thread
from typing import (
IO,
TYPE_CHECKING,
Expand Down Expand Up @@ -81,7 +82,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMes


class PipesMessage(TypedDict):
"""A message sent from the orchestration process to the external process."""
"""A message sent from the external process to the orchestration process."""

__dagster_pipes_version: str
method: str
Expand Down Expand Up @@ -525,19 +526,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel):

def __init__(self, *, interval: float = 10):
self._interval = interval
self._lock = Lock()
self._buffer = []
self._buffer: Queue[PipesMessage] = Queue()
self._counter = 1

def write_message(self, message: PipesMessage) -> None:
with self._lock:
self._buffer.append(message)
self._buffer.put(message)

def flush_messages(self) -> Sequence[PipesMessage]:
with self._lock:
messages = list(self._buffer)
self._buffer.clear()
return messages
items = []
while not self._buffer.empty():
items.append(self._buffer.get())
return items

@abstractmethod
def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ...
Expand All @@ -558,15 +557,15 @@ def buffered_upload_loop(self) -> Iterator[None]:
def _upload_loop(self, is_task_complete: Event) -> None:
start_or_last_upload = datetime.datetime.now()
while True:
num_pending = len(self._buffer)
now = datetime.datetime.now()
if num_pending == 0 and is_task_complete.is_set():
if self._buffer.empty() and is_task_complete.is_set():
break
elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval:
payload = "\n".join([json.dumps(message) for message in self.flush_messages()])
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
if len(payload) > 0:
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
time.sleep(1)


Expand Down
100 changes: 81 additions & 19 deletions python_modules/dagster/dagster/_core/pipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def read_messages(
self,
handler: "PipesMessageHandler",
) -> Iterator[PipesParams]:
"""Set up a thread to read streaming messages from teh external process by tailing the
"""Set up a thread to read streaming messages from the external process by tailing the
target file.
Args:
Expand Down Expand Up @@ -223,14 +223,22 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
Args:
interval (float): interval in seconds between attempts to download a chunk
forward_stdout (bool): whether to forward stdout from the pipes process to Dagster.
forward_stderr (bool): whether to forward stderr from the pipes process to Dagster.
"""

interval: float
counter: int
forward_stdout: bool
forward_stderr: bool

def __init__(self, interval: float = 10):
def __init__(
self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False
):
self.interval = interval
self.counter = 1
self.forward_stdout = forward_stdout
self.forward_stderr = forward_stderr

@contextmanager
def read_messages(
Expand All @@ -249,23 +257,33 @@ def read_messages(
"""
with self.get_params() as params:
is_task_complete = Event()
thread = None
messages_thread = None
stdout_thread = None
stderr_thread = None
try:
thread = Thread(
target=self._reader_thread,
args=(
handler,
params,
is_task_complete,
),
daemon=True,
messages_thread = Thread(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
thread.start()
messages_thread.start()
if self.forward_stdout:
stdout_thread = Thread(
target=self._stdout_thread, args=(params, is_task_complete)
)
stdout_thread.start()
if self.forward_stderr:
stderr_thread = Thread(
target=self._stderr_thread, args=(params, is_task_complete)
)
stderr_thread.start()
yield params
finally:
is_task_complete.set()
if thread:
thread.join()
if messages_thread:
messages_thread.join()
if stdout_thread:
stdout_thread.join()
if stderr_thread:
stderr_thread.join()

@abstractmethod
@contextmanager
Expand All @@ -280,31 +298,75 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def _reader_thread(
self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event
def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def _messages_thread(
self,
handler: "PipesMessageHandler",
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
chunk = self.download_messages_chunk(self.counter, params)
start_or_last_download = now
chunk = self.download_messages_chunk(self.counter, params)
if chunk:
for line in chunk.split("\n"):
message = json.loads(line)
handler.handle_message(message)
handler.handle_message(json.loads(line))
self.counter += 1
elif is_task_complete.is_set():
break
time.sleep(1)

def _stdout_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stdout_chunk(params)
if chunk:
sys.stdout.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(1)

def _stderr_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stderr_chunk(params)
if chunk:
sys.stderr.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(1)


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
# exceptions as control flow, you love to see it
try:
message = json.loads(log_line)
if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
handler.handle_message(message)
else:
sys.stdout.writelines((log_line, "\n"))
except Exception:
# move non-message logs in to stdout for compute log capture
sys.stdout.writelines((log_line, "\n"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional
from typing import Iterator, Literal, Mapping, Optional

import dagster._check as check
from dagster._annotations import experimental
Expand All @@ -23,6 +23,7 @@
open_pipes_session,
)
from dagster_pipes import (
DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES,
PipesContextData,
PipesExtras,
PipesParams,
Expand Down Expand Up @@ -109,6 +110,12 @@ def run(
**(self.env or {}),
**pipes_session.get_bootstrap_env_vars(),
}
cluster_log_root = pipes_session.get_bootstrap_params()[
DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"]
]["cluster_log_root"]
submit_task_dict["new_cluster"]["cluster_log_conf"] = {
"dbfs": {"destination": f"dbfs:{cluster_log_root}"}
}
task = jobs.SubmitTask.from_dict(submit_task_dict)
run_id = self.client.jobs.submit(
tasks=[task],
Expand All @@ -135,6 +142,7 @@ def run(
f"Error running Databricks job: {run.state.state_message}"
)
time.sleep(5)
time.sleep(30) # 30 seconds to make sure logs are flushed
return PipesClientCompletedInvocation(tuple(pipes_session.get_results()))


Expand Down Expand Up @@ -200,22 +208,35 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader):
Args:
interval (float): interval in seconds between attempts to download a chunk
client (WorkspaceClient): A databricks `WorkspaceClient` object.
cluster_log_root (Optional[str]): The root path on DBFS where the cluster logs are written.
If set, this will be used to read stderr/stdout logs.
"""

def __init__(self, *, interval: int = 10, client: WorkspaceClient):
super().__init__(interval=interval)
def __init__(
self,
*,
interval: int = 10,
client: WorkspaceClient,
forward_stdout: bool = False,
forward_stderr: bool = False,
):
super().__init__(
interval=interval, forward_stdout=forward_stdout, forward_stderr=forward_stderr
)
self.dbfs_client = files.DbfsAPI(client.api_client)
self.stdio_position = {"stdout": 0, "stderr": 0}

@contextmanager
def get_params(self) -> Iterator[PipesParams]:
with dbfs_tempdir(self.dbfs_client) as tempdir:
yield {"path": tempdir}
with dbfs_tempdir(self.dbfs_client) as messages_tempdir, dbfs_tempdir(
self.dbfs_client
) as logs_tempdir:
yield {"path": messages_tempdir, "cluster_log_root": logs_tempdir}

def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
message_path = os.path.join(params["path"], f"{index}.json")
try:
raw_message = self.dbfs_client.read(message_path)

# Files written to dbfs using the Python IO interface used in PipesDbfsMessageWriter are
# base64-encoded.
return base64.b64decode(raw_message.data).decode("utf-8")
Expand All @@ -225,6 +246,33 @@ def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[s
except IOError:
return None

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stdout")

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stderr")

def _download_stdio_chunk(
self, params: PipesParams, stream: Literal["stdout", "stderr"]
) -> Optional[str]:
log_root_path = os.path.join(params["cluster_log_root"])
child_dirs = list(self.dbfs_client.list(log_root_path))
# The directory containing logs will not exist until either 5 minutes have elapsed or the
# job has finished.
if len(child_dirs) == 0:
return None
else:
log_path = f"dbfs:{child_dirs[0].path}/driver/stdout"
try:
read_response = self.dbfs_client.read(log_path)
assert read_response.data
content = base64.b64decode(read_response.data).decode("utf-8")
chunk = content[self.stdio_position[stream] :]
self.stdio_position[stream] = len(content)
return chunk
except IOError:
return None

def no_messages_debug_text(self) -> str:
return (
"Attempted to read messages from a temporary file in dbfs. Expected"
Expand Down
Loading

0 comments on commit 942fee6

Please sign in to comment.