diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
index 296a5442dbe6c..fb3f7b230dfed 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -353,7 +353,7 @@ def __init__(
         context_injector: PipesContextInjector,
         message_reader: Optional[PipesMessageReader] = None,
         client: Optional[boto3.client] = None,
-        forward_termination: bool = False,
+        forward_termination: bool = True,
     ):
         self._client = client or boto3.client("glue")
         self._context_injector = context_injector
@@ -439,6 +439,7 @@ def run(
 
         try:
             run_id = self._client.start_job_run(**params)["JobRunId"]
+
         except ClientError as err:
             context.log.error(
                 "Couldn't create job %s. Here's why: %s: %s",
@@ -459,9 +460,9 @@ def run(
                 self._terminate_job_run(context=context, job_name=job_name, run_id=run_id)
             raise
 
-        if response["JobRun"]["JobRunState"] == "FAILED":
+        if (status := response["JobRun"]["JobRunState"]) != "SUCCEEDED":
             raise RuntimeError(
-                f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
+                f"Glue job {job_name} run {run_id} completed with status {status}:\n{response['JobRun'].get('ErrorMessage')}"
             )
         else:
             context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
@@ -479,7 +480,14 @@ def run(
     def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
         while True:
             response = self._client.get_job_run(JobName=job_name, RunId=run_id)
-            if response["JobRun"]["JobRunState"] in ["FAILED", "SUCCEEDED"]:
+            # https://docs.aws.amazon.com/glue/latest/dg/job-run-statuses.html
+            if response["JobRun"]["JobRunState"] in [
+                "FAILED",
+                "SUCCEEDED",
+                "STOPPED",
+                "TIMEOUT",
+                "ERROR",
+            ]:
                 return response
             time.sleep(5)
 
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
index 744b30e549353..d6346a47c0248 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_glue.py
@@ -1,12 +1,23 @@
-import subprocess
 import sys
 import tempfile
 import time
-from typing import Dict, Literal, Optional
+import warnings
+from dataclasses import dataclass
+from subprocess import PIPE, Popen
+from typing import Dict, List, Literal, Optional
 
 import boto3
 
 
+@dataclass
+class SimulatedJobRun:
+    popen: Popen
+    job_run_id: str
+    log_group: str
+    local_script: tempfile._TemporaryFileWrapper
+    stopped: bool = False
+
+
 class LocalGlueMockClient:
     def __init__(
         self,
@@ -21,6 +32,9 @@ def __init__(
         to receive any Dagster messages from it.
         If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs
         to CloudWatch as if this has been done by Glue.
+
+        Once a job is submitted, it is executed in a separate process to mimic Glue behavior.
+        When the job status is requested, the subprocess is polled and the resulting state is returned.
""" self.aws_endpoint_url = aws_endpoint_url self.s3_client = s3_client @@ -28,8 +42,39 @@ def __init__( self.pipes_messages_backend = pipes_messages_backend self.cloudwatch_client = cloudwatch_client - def get_job_run(self, *args, **kwargs): - return self.glue_client.get_job_run(*args, **kwargs) + self.process = None # jobs will be executed in a separate process + + self._job_runs: Dict[str, SimulatedJobRun] = {} # mapping of JobRunId to SimulatedJobRun + + def get_job_run(self, JobName: str, RunId: str): + # get original response + response = self.glue_client.get_job_run(JobName=JobName, RunId=RunId) + + # check if status override is set + simulated_job_run = self._job_runs[RunId] + + if simulated_job_run.stopped: + response["JobRun"]["JobRunState"] = "STOPPED" + return response + + # check if popen has completed + if simulated_job_run.popen.poll() is not None: + simulated_job_run.popen.wait() + # check status code + if simulated_job_run.popen.returncode == 0: + response["JobRun"]["JobRunState"] = "SUCCEEDED" + else: + response["JobRun"]["JobRunState"] = "FAILED" + _, stderr = simulated_job_run.popen.communicate() + response["JobRun"]["ErrorMessage"] = stderr.decode() + + # upload logs to cloudwatch + if self.pipes_messages_backend == "cloudwatch": + self._upload_logs_to_cloudwatch(RunId) + else: + response["JobRun"]["JobRunState"] = "RUNNING" + + return response def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs): params = { @@ -45,67 +90,97 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwa bucket = script_s3_path.split("/")[2] key = "/".join(script_s3_path.split("/")[3:]) - # load the script and execute it locally - with tempfile.NamedTemporaryFile() as f: - self.s3_client.download_file(bucket, key, f.name) - - args = [] - for key, val in (Arguments or {}).items(): - args.append(key) - args.append(val) - - result = subprocess.run( - [sys.executable, f.name, *args], - check=False, - env={ - "AWS_ENDPOINT_URL": self.aws_endpoint_url, - "TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend, - }, - capture_output=True, - ) - # mock the job run with moto response = self.glue_client.start_job_run(**params) job_run_id = response["JobRunId"] - job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id) - log_group = job_run_response["JobRun"]["LogGroupName"] + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + f = tempfile.NamedTemporaryFile( + delete=False + ) # we will close this file later during garbage collection + # load the S3 script to a local file + self.s3_client.download_file(bucket, key, f.name) + + # execute the script in a separate process + args = [] + for key, val in (Arguments or {}).items(): + args.append(key) + args.append(val) + popen = Popen( + [sys.executable, f.name, *args], + env={ + "AWS_ENDPOINT_URL": self.aws_endpoint_url, + "TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend, + }, + stdout=PIPE, + stderr=PIPE, + ) + + # record execution metadata for later use + self._job_runs[job_run_id] = SimulatedJobRun( + popen=popen, + job_run_id=job_run_id, + log_group=self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)["JobRun"][ + "LogGroupName" + ], + local_script=f, + ) + + return response + + def batch_stop_job_run(self, JobName: str, JobRunIds: List[str]): + for job_run_id in JobRunIds: + if simulated_job_run := self._job_runs.get(job_run_id): + simulated_job_run.popen.terminate() + simulated_job_run.stopped = True + 
+                self._upload_logs_to_cloudwatch(job_run_id)
+
+    def _upload_logs_to_cloudwatch(self, job_run_id: str):
+        log_group = self._job_runs[job_run_id].log_group
+        stdout, stderr = self._job_runs[job_run_id].popen.communicate()
 
         if self.pipes_messages_backend == "cloudwatch":
             assert (
                 self.cloudwatch_client is not None
             ), "cloudwatch_client has to be provided with cloudwatch messages backend"
 
-            self.cloudwatch_client.create_log_group(
-                logGroupName=f"{log_group}/output",
-            )
-
-            self.cloudwatch_client.create_log_stream(
-                logGroupName=f"{log_group}/output",
-                logStreamName=job_run_id,
-            )
-
-            for line in result.stderr.decode().split(
-                "\n"
-            ):  # uploading log lines one by one is good enough for tests
-                if line:
-                    self.cloudwatch_client.put_log_events(
-                        logGroupName=f"{log_group}/output",  # yes, Glue routes stderr to /output
-                        logStreamName=job_run_id,
-                        logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
-                    )
-                    time.sleep(
-                        0.01
-                    )  # make sure the logs will be properly filtered by ms timestamp when accessed next time
-
-        # replace run state with actual results
-        response["JobRun"] = {}
-
-        response["JobRun"]["JobRunState"] = "SUCCEEDED" if result.returncode == 0 else "FAILED"
-
-        # add error message if failed
-        if result.returncode != 0:
-            # this actually has to be just the Python exception, but this is good enough for now
-            response["JobRun"]["ErrorMessage"] = result.stderr
-
-        return response
+            try:
+                self.cloudwatch_client.create_log_group(
+                    logGroupName=f"{log_group}/output",
+                )
+            except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
+                pass
+
+            try:
+                self.cloudwatch_client.create_log_stream(
+                    logGroupName=f"{log_group}/output",
+                    logStreamName=job_run_id,
+                )
+            except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
+                pass
+
+            for out in [stderr, stdout]:  # Glue routes both stderr and stdout to /output
+                for line in out.decode().split(
+                    "\n"
+                ):  # uploading log lines one by one is good enough for tests
+                    if line:
+                        self.cloudwatch_client.put_log_events(
+                            logGroupName=f"{log_group}/output",
+                            logStreamName=job_run_id,
+                            logEvents=[
+                                {"timestamp": int(time.time() * 1000), "message": str(line)}
+                            ],
+                        )
+                        time.sleep(
+                            0.01
+                        )  # make sure the logs will be properly filtered by ms timestamp when accessed next time
+
+    def __del__(self):
+        # clean up local script files
+        for job_run in self._job_runs.values():
+            job_run.local_script.close()
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
index d516cee6270ea..cae96b52c717d 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py
@@ -1,16 +1,19 @@
 import base64
 import inspect
 import json
+import multiprocessing
+import os
 import re
 import shutil
 import textwrap
+import time
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
 from typing import Any, Callable, Iterator, Literal
 
 import boto3
 import pytest
-from dagster import asset, materialize, open_pipes_session
+from dagster import AssetsDefinition, asset, materialize, open_pipes_session
 from dagster._core.definitions.asset_check_spec import AssetCheckKey, AssetCheckSpec
 from dagster._core.definitions.data_version import (
     DATA_VERSION_IS_USER_PROVIDED_TAG,
@@ -362,13 +365,6 @@ def test_glue_pipes(
     cloudwatch_client,
     pipes_messages_backend: Literal["s3", "cloudwatch"],
 ):
-    context_injector = PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client)
-    message_reader = (
-        PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
-        if pipes_messages_backend == "s3"
-        else PipesCloudWatchMessageReader(client=cloudwatch_client)
-    )
-
     @asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
     def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
         results = pipes_glue_client.run(
@@ -378,6 +374,13 @@ def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
         ).get_results()
         return results
 
+    context_injector = PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client)
+    message_reader = (
+        PipesS3MessageReader(bucket=_S3_TEST_BUCKET, client=s3_client, interval=0.001)
+        if pipes_messages_backend == "s3"
+        else PipesCloudWatchMessageReader(client=cloudwatch_client)
+    )
+
     pipes_glue_client = PipesGlueClient(
         client=LocalGlueMockClient(
             aws_endpoint_url=_MOTO_SERVER_URL,
@@ -409,3 +412,127 @@ def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
     )
     assert len(asset_check_executions) == 1
     assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED
+
+
+@pytest.fixture
+def long_glue_job(s3_client, glue_client) -> Iterator[str]:
+    job_name = "Very Long Job"
+
+    def script_fn():
+        import os
+        import time
+
+        import boto3
+        from dagster_pipes import PipesCliArgsParamsLoader, PipesS3ContextLoader, open_dagster_pipes
+
+        s3_client = boto3.client(
+            "s3", region_name="us-east-1", endpoint_url="http://localhost:5193"
+        )
+
+        with open_dagster_pipes(
+            context_loader=PipesS3ContextLoader(client=s3_client),
+            params_loader=PipesCliArgsParamsLoader(),
+        ) as context:
+            context.log.info("Glue job sleeping...")
+            time.sleep(int(os.getenv("SLEEP_SECONDS", "1")))
+
+    with temp_script(script_fn) as script_path:
+        s3_key = "long_glue_script.py"
+        s3_client.upload_file(script_path, _S3_TEST_BUCKET, s3_key)
+
+        glue_client.create_job(
+            Name=job_name,
+            Description="Test job",
+            Command={
+                "Name": "glueetl",  # Spark job type
+                "ScriptLocation": f"s3://{_S3_TEST_BUCKET}/{s3_key}",
+                "PythonVersion": "3.10",
+            },
+            GlueVersion="4.0",
+            Role="arn:aws:iam::012345678901:role/service-role/AWSGlueServiceRole-test",
+        )
+
+        yield job_name
+
+
+@pytest.fixture
+def foo_asset(long_glue_job: str) -> AssetsDefinition:
+    @asset
+    def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
+        results = pipes_glue_client.run(
+            context=context,
+            job_name=long_glue_job,
+        ).get_results()
+        return results
+
+    return foo
+
+
+@pytest.fixture
+def local_glue_mock_client(glue_client, s3_client, cloudwatch_client) -> LocalGlueMockClient:
+    return LocalGlueMockClient(
+        aws_endpoint_url=_MOTO_SERVER_URL,
+        glue_client=glue_client,
+        s3_client=s3_client,
+        cloudwatch_client=cloudwatch_client,
+        pipes_messages_backend="cloudwatch",
+    )
+
+
+@pytest.fixture
+def pipes_glue_client(local_glue_mock_client, s3_client, cloudwatch_client) -> PipesGlueClient:
+    return PipesGlueClient(
+        client=local_glue_mock_client,
+        context_injector=PipesS3ContextInjector(bucket=_S3_TEST_BUCKET, client=s3_client),
+        message_reader=PipesCloudWatchMessageReader(client=cloudwatch_client),
+    )
+
+
+def test_glue_pipes_interruption_forwarding_asset_is_valid(
+    foo_asset, pipes_glue_client, local_glue_mock_client
+):
+    # make sure this runs without multiprocessing first
+
+    with instance_for_test() as instance:
+        materialize(
+            [foo_asset], instance=instance, resources={"pipes_glue_client": pipes_glue_client}
+        )
+
+
+def test_glue_pipes_interruption_forwarding(
+    long_glue_job, foo_asset, pipes_glue_client, local_glue_mock_client
+):
+    def materialize_asset(env, return_dict):
+        os.environ.update(env)
+        try:
+            with instance_for_test() as instance:
+                materialize(  # this will be interrupted and raise an exception
+                    [foo_asset],
+                    instance=instance,
+                    resources={"pipes_glue_client": pipes_glue_client},
+                )
+        finally:
+            job_run_id = next(iter(local_glue_mock_client._job_runs.keys()))  # noqa
+            return_dict[0] = local_glue_mock_client.get_job_run(long_glue_job, job_run_id)
+
+    with multiprocessing.Manager() as manager:
+        return_dict = manager.dict()
+
+        p = multiprocessing.Process(
+            target=materialize_asset,
+            args=(
+                {"SLEEP_SECONDS": "10"},
+                return_dict,
+            ),
+        )
+        p.start()
+
+        while p.is_alive():
+            # the run has started executing;
+            # give it a few seconds, then interrupt it
+            time.sleep(3)
+            p.terminate()
+
+        p.join()
+        assert not p.is_alive()
+        assert return_dict[0]["JobRun"]["JobRunState"] == "STOPPED"
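
Usage note (not part of the patch): with forward_termination now defaulting to True, interrupting a Dagster run stops the in-flight Glue job run (via batch_stop_job_run), and any terminal state other than SUCCEEDED raises a RuntimeError. A minimal caller-side sketch of the new behavior; the bucket and job names below are illustrative, not from this PR:

    import boto3
    from dagster import AssetExecutionContext, Definitions, asset
    from dagster_aws.pipes import (
        PipesGlueClient,
        PipesS3ContextInjector,
        PipesS3MessageReader,
    )

    s3_client = boto3.client("s3")

    pipes_glue_client = PipesGlueClient(
        context_injector=PipesS3ContextInjector(bucket="my-pipes-bucket", client=s3_client),
        message_reader=PipesS3MessageReader(bucket="my-pipes-bucket", client=s3_client),
        # forward_termination=True is now the default: on interruption the client
        # stops the remote Glue job run before re-raising; pass False to opt out.
    )

    @asset
    def my_glue_asset(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient):
        # raises RuntimeError if the run ends FAILED, STOPPED, TIMEOUT, or ERROR
        return pipes_glue_client.run(context=context, job_name="my-glue-job").get_results()

    defs = Definitions(
        assets=[my_glue_asset],
        resources={"pipes_glue_client": pipes_glue_client},
    )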