Skip to content

Commit

Permalink
pipes-glue-interruption-tests-wip (#23502)
Browse files Browse the repository at this point in the history
Exploring if it's possible to rewrite the fake testing Glue client to
run execute scripts in a background process so we can test interruption
  • Loading branch information
danielgafni committed Aug 8, 2024
1 parent 0c830dd commit 7f6049c
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 71 deletions.
28 changes: 22 additions & 6 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dagster import PipesClient
from dagster._annotations import experimental
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.errors import DagsterExecutionInterruptedError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
Expand Down Expand Up @@ -351,7 +351,7 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
Args:
context_injector (Optional[PipesContextInjector]): A context injector to use to inject
context into the Glue job, for example, :py:class:`PipesGlueContextInjector`.
message_reader (Optional[PipesGlueMessageReader]): A message reader to use to read messages
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesGlueLogsMessageReader`.
client (Optional[boto3.client]): The boto Glue client used to launch the Glue job
forward_termination (bool): Whether to cancel the Glue job run when the Dagster process receives a termination signal.
Expand All @@ -362,13 +362,21 @@ def __init__(
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
forward_termination: bool = False,
forward_termination: bool = True,
):
self._client = client or boto3.client("glue")
self._context_injector = context_injector
self._message_reader = message_reader or PipesCloudWatchMessageReader()
self.forward_termination = check.bool_param(forward_termination, "forward_termination")

self._last_job_run_id: Optional[str] = None

@property
def last_job_run_id(self) -> str:
if not self._last_job_run_id:
raise DagsterInvariantViolationError("No AWS Glue jobs were started using this client")
return self._last_job_run_id

@classmethod
def _is_dagster_maintained(cls) -> bool:
return True
Expand Down Expand Up @@ -448,6 +456,7 @@ def run(

try:
run_id = self._client.start_job_run(**params)["JobRunId"]
self.__last_job_run_id = run_id
except ClientError as err:
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
Expand All @@ -468,9 +477,9 @@ def run(
self._terminate_job_run(context=context, job_name=job_name, run_id=run_id)
raise

if response["JobRun"]["JobRunState"] == "FAILED":
if status := response["JobRun"]["JobRunState"] != "SUCCEEDED":
raise RuntimeError(
f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
f"Glue job {job_name} run {run_id} completed with status {status} :\n{response['JobRun'].get('ErrorMessage')}"
)
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
Expand All @@ -488,7 +497,14 @@ def run(
def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
while True:
response = self._client.get_job_run(JobName=job_name, RunId=run_id)
if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]:
# https://docs.aws.amazon.com/glue/latest/dg/job-run-statuses.html
if response["JobRun"]["JobRunState"] in [
"FAILED",
"SUCCEEDED",
"STOPPED",
"TIMEOUT",
"ERROR",
]:
return response
time.sleep(5)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import subprocess
import sys
import tempfile
import time
from typing import Dict, Literal, Optional
import warnings
from dataclasses import dataclass
from subprocess import PIPE, Popen
from typing import Dict, List, Literal, Optional

import boto3


@dataclass
class SimulatedJobRun:
popen: Popen
job_run_id: str
log_group: str
local_script: tempfile._TemporaryFileWrapper
stopped: bool = False


class LocalGlueMockClient:
def __init__(
self,
Expand All @@ -21,15 +32,49 @@ def __init__(
to receive any Dagster messages from it.
If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
as if this has been done by Glue.
Once the job is submitted, it is being executed in a separate process to mimic Glue behavior.
Once the job status is requested, the process is checked for its status and the result is returned.
"""
self.aws_endpoint_url = aws_endpoint_url
self.s3_client = s3_client
self.glue_client = glue_client
self.pipes_messages_backend = pipes_messages_backend
self.cloudwatch_client = cloudwatch_client

def get_job_run(self, *args, **kwargs):
return self.glue_client.get_job_run(*args, **kwargs)
self.process = None # jobs will be executed in a separate process

self._job_runs: Dict[str, SimulatedJobRun] = {} # mapping of JobRunId to SimulatedJobRun

def get_job_run(self, JobName: str, RunId: str):
# get original response
response = self.glue_client.get_job_run(JobName=JobName, RunId=RunId)

# check if status override is set
simulated_job_run = self._job_runs[RunId]

if simulated_job_run.stopped:
response["JobRun"]["JobRunState"] = "STOPPED"
return response

# check if popen has completed
if simulated_job_run.popen.poll() is not None:
simulated_job_run.popen.wait()
# check status code
if simulated_job_run.popen.returncode == 0:
response["JobRun"]["JobRunState"] = "SUCCEEDED"
else:
response["JobRun"]["JobRunState"] = "FAILED"
_, stderr = simulated_job_run.popen.communicate()
response["JobRun"]["ErrorMessage"] = stderr.decode()

# upload logs to cloudwatch
if self.pipes_messages_backend == "cloudwatch":
self._upload_logs_to_cloudwatch(RunId)
else:
response["JobRun"]["JobRunState"] = "RUNNING"

return response

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
Expand All @@ -45,67 +90,97 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwa
bucket = script_s3_path.split("/")[2]
key = "/".join(script_s3_path.split("/")[3:])

# load the script and execute it locally
with tempfile.NamedTemporaryFile() as f:
self.s3_client.download_file(bucket, key, f.name)

args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)

result = subprocess.run(
[sys.executable, f.name, *args],
check=False,
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
capture_output=True,
)

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
log_group = job_run_response["JobRun"]["LogGroupName"]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f = tempfile.NamedTemporaryFile(
delete=False
) # we will close this file later during garbage collection
# load the S3 script to a local file
self.s3_client.download_file(bucket, key, f.name)

# execute the script in a separate process
args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)
popen = Popen(
[sys.executable, f.name, *args],
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
stdout=PIPE,
stderr=PIPE,
)

# record execution metadata for later use
self._job_runs[job_run_id] = SimulatedJobRun(
popen=popen,
job_run_id=job_run_id,
log_group=self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)["JobRun"][
"LogGroupName"
],
local_script=f,
)

return response

def batch_stop_job_run(self, JobName: str, JobRunIds: List[str]):
for job_run_id in JobRunIds:
if simulated_job_run := self._job_runs.get(job_run_id):
simulated_job_run.popen.terminate()
simulated_job_run.stopped = True
self._upload_logs_to_cloudwatch(job_run_id)

def _upload_logs_to_cloudwatch(self, job_run_id: str):
log_group = self._job_runs[job_run_id].log_group
stdout, stderr = self._job_runs[job_run_id].popen.communicate()

if self.pipes_messages_backend == "cloudwatch":
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)

self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)

for line in result.stderr.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output", # yes, Glue routes stderr to /output
logStreamName=job_run_id,
logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

# replace run state with actual results
response["JobRun"] = {}

response["JobRun"]["JobRunState"] = "SUCCEEDED" if result.returncode == 0 else "FAILED"

# add error message if failed
if result.returncode != 0:
# this actually has to be just the Python exception, but this is good enough for now
response["JobRun"]["ErrorMessage"] = result.stderr
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

return response
try:
self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

try:
self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

for out in [stderr, stdout]: # Glue routes both stderr and stdout to /output
for line in out.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
logEvents=[
{"timestamp": int(time.time() * 1000), "message": str(line)}
],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

def __del__(self):
# cleanup local script paths
for job_run in self._job_runs.values():
job_run.local_script.close()
Loading

0 comments on commit 7f6049c

Please sign in to comment.