diff --git a/docs/content/api/modules.json.gz b/docs/content/api/modules.json.gz index 075cbe9945024..0261c0e14a5ee 100644 Binary files a/docs/content/api/modules.json.gz and b/docs/content/api/modules.json.gz differ diff --git a/docs/content/api/searchindex.json.gz b/docs/content/api/searchindex.json.gz index 88c41e193bebb..981bd500dfda9 100644 Binary files a/docs/content/api/searchindex.json.gz and b/docs/content/api/searchindex.json.gz differ diff --git a/docs/content/api/sections.json.gz b/docs/content/api/sections.json.gz index 47161918e3282..95fa14769061e 100644 Binary files a/docs/content/api/sections.json.gz and b/docs/content/api/sections.json.gz differ diff --git a/pyright/alt-1/requirements-pinned.txt b/pyright/alt-1/requirements-pinned.txt index 94228fcfbeb80..ffaddc6610597 100644 --- a/pyright/alt-1/requirements-pinned.txt +++ b/pyright/alt-1/requirements-pinned.txt @@ -1,5 +1,5 @@ agate==1.9.1 -aiobotocore==2.13.2 +aiobotocore==2.13.3 aiofile==3.8.8 aiohappyeyeballs==2.4.0 aiohttp==3.10.5 @@ -9,7 +9,6 @@ alembic==1.13.2 aniso8601==9.0.1 annotated-types==0.7.0 anyio==4.4.0 -appnope==0.1.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 @@ -18,16 +17,17 @@ asn1crypto==1.5.1 astroid==3.2.4 asttokens==2.4.1 async-lru==2.0.4 -async-timeout==4.0.3 attrs==24.2.0 babel==2.16.0 backoff==2.2.1 backports-tarfile==1.2.0 beautifulsoup4==4.12.3 bleach==6.1.0 -boto3==1.34.131 -botocore==1.34.131 -buildkite-test-collector==0.1.8 +boto3==1.34.162 +boto3-stubs-lite==1.35.6 +botocore==1.34.162 +botocore-stubs==1.35.6 +buildkite-test-collector==0.1.9 cachetools==5.5.0 caio==0.9.17 certifi==2024.7.4 @@ -83,7 +83,6 @@ distlib==0.3.8 docker==7.1.0 docstring-parser==0.16 duckdb==1.0.0 -exceptiongroup==1.2.2 execnet==2.1.1 executing==2.0.1 fastjsonschema==2.20.0 @@ -103,11 +102,12 @@ google-cloud-core==2.4.1 google-cloud-storage==2.18.2 google-crc32c==1.5.0 google-resumable-media==2.7.2 -googleapis-common-protos==1.63.2 +googleapis-common-protos==1.64.0 gql==3.5.0 graphene==3.3 graphql-core==3.2.3 graphql-relay==3.2.0 +greenlet==3.0.3 grpcio==1.66.0 grpcio-health-checking==1.62.3 grpcio-status==1.62.3 @@ -118,8 +118,8 @@ httplib2==0.22.0 httptools==0.6.1 httpx==0.27.0 humanfriendly==10.0 -hypothesis==6.111.1 -idna==3.7 +hypothesis==6.111.2 +idna==3.8 importlib-metadata==6.11.0 iniconfig==2.0.0 ipykernel==6.29.5 @@ -131,6 +131,7 @@ jaraco-classes==3.4.0 jaraco-context==6.0.1 jaraco-functools==4.0.2 jedi==0.19.1 +jeepney==0.8.0 jinja2==3.1.4 jmespath==1.0.1 joblib==1.4.2 @@ -144,7 +145,7 @@ jupyter-events==0.10.0 jupyter-lsp==2.2.5 jupyter-server==2.14.2 jupyter-server-terminals==0.5.3 -jupyterlab==4.2.4 +jupyterlab==4.2.5 jupyterlab-pygments==0.3.0 jupyterlab-server==2.27.3 keyring==25.3.0 @@ -168,7 +169,8 @@ morefs==0.2.2 msgpack==1.0.8 multidict==6.0.5 multimethod==1.10 -mypy==1.11.1 +mypy==1.11.2 +mypy-boto3-ecs==1.35.2 mypy-extensions==1.0.0 mypy-protobuf==3.6.0 nbclient==0.10.0 @@ -219,7 +221,7 @@ pygments==2.18.0 pyjwt==2.9.0 pylint==3.2.6 pyopenssl==24.2.1 -pyparsing==3.1.2 +pyparsing==3.1.4 pyproject-api==1.7.1 pyright==1.1.370 pyspark==3.5.2 @@ -245,7 +247,7 @@ requests-toolbelt==1.0.0 responses==0.23.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rich==13.7.1 +rich==13.8.0 rpds-py==0.20.0 rsa==4.9 s3fs==2024.3.1 @@ -253,6 +255,7 @@ s3transfer==0.10.2 scikit-learn==1.5.1 scipy==1.14.1 seaborn==0.13.2 +secretstorage==3.3.3 send2trash==1.8.3 setuptools==73.0.1 shellingham==1.5.4 @@ -264,13 +267,13 @@ snowflake-sqlalchemy==1.5.1 sortedcontainers==2.4.0 soupsieve==2.6 sqlalchemy==1.4.53 -sqlglot==25.16.0 +sqlglot==25.17.0 sqlglotrs==0.2.9 sqlparse==0.5.1 stack-data==0.6.3 starlette==0.38.2 structlog==24.4.0 -syrupy==4.6.4 +syrupy==4.7.1 tabulate==0.9.0 terminado==0.18.1 text-unidecode==1.3 @@ -284,7 +287,8 @@ tox==4.18.0 tqdm==4.66.5 traitlets==5.14.3 typeguard==4.3.0 -typer==0.12.4 +typer==0.12.5 +types-awscrt==0.21.2 types-backports==0.1.3 types-certifi==2021.10.8.3 types-cffi==1.16.0.20240331 @@ -299,6 +303,7 @@ types-python-dateutil==2.9.0.20240821 types-pytz==2024.1.0.20240417 types-pyyaml==6.0.12.20240808 types-requests==2.32.0.20240712 +types-s3transfer==0.10.1 types-setuptools==73.0.0.20240822 types-simplejson==3.19.0.20240801 types-six==1.16.21.20240513 @@ -316,7 +321,7 @@ urllib3==2.2.2 uvicorn==0.30.6 uvloop==0.20.0 virtualenv==20.26.3 -watchdog==4.0.2 +watchdog==5.0.0 watchfiles==0.23.0 wcwidth==0.2.13 webcolors==24.8.0 @@ -326,4 +331,4 @@ websockets==13.0 wheel==0.44.0 wrapt==1.16.0 yarl==1.9.4 -zipp==3.20.0 +zipp==3.20.1 diff --git a/pyright/master/requirements-pinned.txt b/pyright/master/requirements-pinned.txt index 74154502ac45c..3a6f8587f48b9 100644 --- a/pyright/master/requirements-pinned.txt +++ b/pyright/master/requirements-pinned.txt @@ -1,4 +1,4 @@ -acryl-datahub==0.14.0.2 +acryl-datahub==0.14.0.3 agate==1.9.1 aiofile==3.8.8 aiofiles==24.1.0 @@ -25,7 +25,6 @@ apache-airflow-providers-sqlite==3.8.2 apeye==1.4.1 apeye-core==1.1.5 apispec==6.6.1 -appnope==0.1.4 argcomplete==3.5.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 @@ -37,14 +36,13 @@ asn1crypto==1.5.1 asttokens==2.4.1 astunparse==1.6.3 async-lru==2.0.4 -async-timeout==4.0.3 attrs==24.2.0 autodocsumm==0.2.13 autoflake==2.3.1 -e python_modules/automation avro==1.11.3 avro-gen3==0.7.13 -aws-sam-translator==1.89.0 +aws-sam-translator==1.91.0 aws-xray-sdk==2.14.0 azure-core==1.30.2 azure-identity==1.17.1 @@ -58,10 +56,12 @@ billiard==4.2.0 bitmath==1.3.3.1 bleach==6.1.0 blinker==1.8.2 -bokeh==3.5.1 -boto3==1.35.4 -botocore==1.35.4 -buildkite-test-collector==0.1.8 +bokeh==3.5.2 +boto3==1.35.6 +boto3-stubs-lite==1.35.6 +botocore==1.35.6 +botocore-stubs==1.35.6 +buildkite-test-collector==0.1.9 cachecontrol==0.14.0 cached-property==1.5.2 cachelib==0.9.0 @@ -92,7 +92,7 @@ confluent-kafka==2.5.0 connexion==2.14.2 contourpy==1.2.1 coverage==7.6.1 -cron-descriptor==1.4.3 +cron-descriptor==1.4.5 croniter==3.0.3 cryptography==43.0.0 cssutils==2.11.1 @@ -209,7 +209,6 @@ duckdb==1.0.0 ecdsa==0.19.0 email-validator==1.3.1 entrypoints==0.4 -exceptiongroup==1.2.2 execnet==2.1.1 executing==2.0.1 expandvars==0.12.0 @@ -233,7 +232,7 @@ flatbuffers==24.3.25 fonttools==4.53.1 fqdn==1.5.1 frozenlist==1.4.1 -fsspec==2024.3.1 +fsspec==2024.3.0 gitdb==4.0.11 gitpython==3.1.43 giturlparse==0.12.0 @@ -248,13 +247,14 @@ google-cloud-storage==2.18.2 google-crc32c==1.5.0 google-re2==1.1.20240702 google-resumable-media==2.7.2 -googleapis-common-protos==1.63.2 +googleapis-common-protos==1.64.0 gql==3.5.0 graphene==3.3 graphql-core==3.2.3 graphql-relay==3.2.0 graphviz==0.20.3 great-expectations==0.17.11 +greenlet==3.0.3 grpcio==1.66.0 grpcio-health-checking==1.62.3 grpcio-status==1.62.3 @@ -269,8 +269,8 @@ httptools==0.6.1 httpx==0.27.0 humanfriendly==10.0 humanize==4.10.0 -hypothesis==6.111.1 -idna==3.7 +hypothesis==6.111.2 +idna==3.8 ijson==3.3.0 imagesize==1.4.1 importlib-metadata==6.11.0 @@ -305,7 +305,7 @@ jupyter-events==0.10.0 jupyter-lsp==2.2.5 jupyter-server==2.14.2 jupyter-server-terminals==0.5.3 -jupyterlab==4.2.4 +jupyterlab==4.2.5 jupyterlab-pygments==0.3.0 jupyterlab-server==2.27.3 jupyterlab-widgets==3.0.13 @@ -317,10 +317,10 @@ kubernetes==30.1.0 kubernetes-asyncio==30.1.0 langchain==0.2.9 langchain-community==0.2.9 -langchain-core==0.2.34 +langchain-core==0.2.35 langchain-openai==0.1.14 langchain-text-splitters==0.2.2 -langsmith==0.1.102 +langsmith==0.1.104 lazy-object-proxy==1.10.0 leather==0.4.0 limits==3.13.0 @@ -357,6 +357,7 @@ msal-extensions==1.2.0 msgpack==1.0.8 multidict==6.0.5 multimethod==1.10 +mypy-boto3-ecs==1.35.2 mypy-extensions==1.0.0 mypy-protobuf==3.6.0 mysql-connector-python==9.0.0 @@ -372,6 +373,18 @@ noteable-origami==1.1.5 notebook==7.2.1 notebook-shim==0.2.4 numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.20 +nvidia-nvtx-cu12==12.1.105 oauth2client==4.1.3 oauthlib==3.2.2 objgraph==3.6.1 @@ -408,7 +421,7 @@ partd==1.4.2 path==16.16.0 pathable==0.4.3 pathspec==0.12.1 -pathvalidate==3.2.0 +pathvalidate==3.2.1 pendulum==2.1.2 pexpect==4.9.0 pillow==10.4.0 @@ -438,7 +451,7 @@ pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1-modules==0.4.0 pycparser==2.22 -pydantic==1.10.17 +pydantic==1.10.18 pydata-google-auth==1.8.2 pyflakes==3.2.0 pygments==2.18.0 @@ -446,7 +459,7 @@ pyjwt==2.9.0 pymdown-extensions==10.9 pynacl==1.5.0 pyopenssl==24.2.1 -pyparsing==3.1.2 +pyparsing==3.1.4 pypd==1.1.0 pyproject-api==1.7.1 pyright==1.1.370 @@ -488,7 +501,7 @@ requirements-parser==0.11.0 responses==0.23.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rich==13.7.1 +rich==13.8.0 rich-argparse==1.5.2 rpds-py==0.20.0 rsa==4.9 @@ -498,7 +511,7 @@ s3transfer==0.10.2 scikit-learn==1.5.1 scipy==1.14.1 scrapbook==0.5.0 -sdf-cli==0.3.21 +sdf-cli==0.3.23 seaborn==0.13.2 selenium==4.23.1 semver==3.0.2 @@ -515,7 +528,7 @@ skein==0.8.2 skl2onnx==1.17.0 slack-sdk==3.31.0 sling==1.2.15 -sling-mac-arm64==1.2.15 +sling-linux-amd64==1.2.15 smmap==5.0.1 sniffio==1.3.1 snowballstemmer==2.2.0 @@ -538,7 +551,7 @@ sphinxcontrib-serializinghtml==2.0.0 sqlalchemy==1.4.53 sqlalchemy-jsonfield==1.0.2 sqlalchemy-utils==0.41.2 -sqlglot==25.16.0 +sqlglot==25.17.0 sqlglotrs==0.2.9 sqlparse==0.5.1 sshpubkeys==3.3.1 @@ -547,7 +560,7 @@ stack-data==0.6.3 starlette==0.38.2 structlog==24.4.0 sympy==1.13.2 -syrupy==4.6.4 +syrupy==4.7.1 tabledata==1.3.3 tabulate==0.9.0 tblib==3.0.0 @@ -571,13 +584,15 @@ tqdm==4.66.5 traitlets==5.14.3 trio==0.26.2 trio-websocket==0.11.1 +triton==3.0.0 -e examples/experimental/dagster-airlift/examples/tutorial-example -e examples/tutorial_notebook_assets -twilio==9.2.3 +twilio==9.2.4 twine==1.15.0 typeguard==4.3.0 typepy==1.3.2 -typer==0.12.4 +typer==0.12.5 +types-awscrt==0.21.2 types-backports==0.1.3 types-certifi==2021.10.8.3 types-cffi==1.16.0.20240331 @@ -592,6 +607,7 @@ types-python-dateutil==2.9.0.20240821 types-pytz==2024.1.0.20240417 types-pyyaml==6.0.12.20240808 types-requests==2.31.0.6 +types-s3transfer==0.10.1 types-setuptools==73.0.0.20240822 types-simplejson==3.19.0.20240801 types-six==1.16.21.20240513 @@ -606,7 +622,7 @@ tzdata==2024.1 tzlocal==5.2 uc-micro-py==1.0.3 unicodecsv==0.14.1 -universal-pathlib==0.2.2 +universal-pathlib==0.2.3 uri-template==1.3.0 uritemplate==4.1.1 urllib3==1.26.19 @@ -616,7 +632,7 @@ uvloop==0.20.0 vine==5.1.0 virtualenv==20.25.0 wandb==0.17.7 -watchdog==4.0.2 +watchdog==5.0.0 watchfiles==0.23.0 wcwidth==0.2.13 webcolors==24.8.0 @@ -640,4 +656,4 @@ xmltodict==0.12.0 xyzservices==2024.6.0 yarl==1.9.4 zict==3.0.0 -zipp==3.20.0 +zipp==3.20.1 diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py index 1da771cf018c6..e513f5cc16adf 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py @@ -1,4 +1,4 @@ -from dagster_aws.pipes.clients import PipesGlueClient, PipesLambdaClient +from dagster_aws.pipes.clients import PipesECSClient, PipesGlueClient, PipesLambdaClient from dagster_aws.pipes.context_injectors import ( PipesLambdaEventContextInjector, PipesS3ContextInjector, @@ -12,6 +12,7 @@ __all__ = [ "PipesGlueClient", "PipesLambdaClient", + "PipesECSClient", "PipesS3ContextInjector", "PipesLambdaEventContextInjector", "PipesS3MessageReader", diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py index 3495d649d9390..b7625af2e2cfb 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py @@ -1,4 +1,5 @@ +from dagster_aws.pipes.clients.ecs import PipesECSClient from dagster_aws.pipes.clients.glue import PipesGlueClient from dagster_aws.pipes.clients.lambda_ import PipesLambdaClient -__all__ = ["PipesGlueClient", "PipesLambdaClient"] +__all__ = ["PipesGlueClient", "PipesLambdaClient", "PipesECSClient"] diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py new file mode 100644 index 0000000000000..667f03a47d77d --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/ecs.py @@ -0,0 +1,221 @@ +from pprint import pformat +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast + +import boto3 +import botocore +import dagster._check as check +from dagster import PipesClient +from dagster._annotations import experimental +from dagster._core.definitions.resource_annotation import TreatAsResourceParam +from dagster._core.errors import DagsterExecutionInterruptedError +from dagster._core.execution.context.compute import OpExecutionContext +from dagster._core.pipes.client import ( + PipesClientCompletedInvocation, + PipesContextInjector, + PipesMessageReader, +) +from dagster._core.pipes.utils import PipesEnvContextInjector, open_pipes_session + +from dagster_aws.pipes.message_readers import PipesCloudWatchMessageReader + +if TYPE_CHECKING: + from mypy_boto3_ecs.client import ECSClient + from mypy_boto3_ecs.type_defs import RunTaskRequestRequestTypeDef + + +@experimental +class PipesECSClient(PipesClient, TreatAsResourceParam): + """A pipes client for running AWS ECS tasks. + + Args: + client (Optional[boto3.client]): The boto ECS client used to launch the ECS task + context_injector (Optional[PipesContextInjector]): A context injector to use to inject + context into the ECS task. Defaults to :py:class:`PipesEnvContextInjector`. + message_reader (Optional[PipesMessageReader]): A message reader to use to read messages + from the ECS task. Defaults to :py:class:`PipesCloudWatchMessageReader`. + forward_termination (bool): Whether to cancel the ECS task when the Dagster process receives a termination signal. + """ + + def __init__( + self, + client: Optional[boto3.client] = None, + context_injector: Optional[PipesContextInjector] = None, + message_reader: Optional[PipesMessageReader] = None, + forward_termination: bool = True, + ): + self._client: "ECSClient" = client or boto3.client("ecs") + self._context_injector = context_injector or PipesEnvContextInjector() + self._message_reader = message_reader or PipesCloudWatchMessageReader() + self.forward_termination = check.bool_param(forward_termination, "forward_termination") + + @classmethod + def _is_dagster_maintained(cls) -> bool: + return True + + def run( + self, + *, + context: OpExecutionContext, + run_task_params: "RunTaskRequestRequestTypeDef", + extras: Optional[Dict[str, Any]] = None, + ) -> PipesClientCompletedInvocation: + """Run ECS tasks, enriched with the pipes protocol. + + Args: + context (OpExecutionContext): The context of the currently executing Dagster op or asset. + run_task_params (dict): Parameters for the ``run_task`` boto3 ECS client call. + Must contain ``taskDefinition`` key. + See `Boto3 API Documentation `_ + extras (Optional[Dict[str, Any]]): Additional information to pass to the pipes session. + + Returns: + PipesClientCompletedInvocation: Wrapper containing results reported by the external + process. + """ + with open_pipes_session( + context=context, + message_reader=self._message_reader, + context_injector=self._context_injector, + extras=extras, + ) as session: + params = run_task_params + + task_definition = params["taskDefinition"] + cluster = params.get("cluster") + + overrides = cast(dict, params.get("overrides") or {}) + overrides["containerOverrides"] = overrides.get("containerOverrides", []) + + # get all containers from task definition + task_definition_response = self._client.describe_task_definition( + taskDefinition=task_definition + ) + + log_configurations = { + container["name"]: container.get("logConfiguration") + for container in task_definition_response["taskDefinition"]["containerDefinitions"] + } + + all_container_names = { + container["name"] + for container in task_definition_response["taskDefinition"]["containerDefinitions"] + } + + container_names_with_overrides = { + container_override["name"] for container_override in overrides["containerOverrides"] + } + + pipes_args = session.get_bootstrap_env_vars() + + # set env variables for every container in the taskDefinition + # respecting current overrides provided by the user + + environment_overrides = [ + { + "name": k, + "value": v, + } + for k, v in pipes_args.items() + ] + + # set environment variables for existing overrides + + for container_override in overrides["containerOverrides"]: + container_override["environment"] = container_override.get("environment", []) + container_override["environment"].extend(environment_overrides) + + # set environment variables for containers that are not in the overrides + for container_name in all_container_names - container_names_with_overrides: + overrides["containerOverrides"].append( + { + "name": container_name, + "environment": environment_overrides, + } + ) + + params["overrides"] = ( + overrides # assign in case overrides was created here as an empty dict + ) + + response = self._client.run_task(**params) + + tasks: List[str] = [task["taskArn"] for task in response["tasks"]] + + try: + response = self._wait_for_tasks_completion(tasks=tasks, cluster=cluster) + + # collect logs from all containers + for task in response["tasks"]: + task_id = task["taskArn"].split("/")[-1] + + for container in task["containers"]: + if log_config := log_configurations.get(container["name"]): + if log_config["logDriver"] == "awslogs": + log_group = log_config["options"]["awslogs-group"] + + # stream name is combined from: prefix, container name, task id + log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}" + + if isinstance(self._message_reader, PipesCloudWatchMessageReader): + self._message_reader.consume_cloudwatch_logs( + log_group, + log_stream, + start_time=int(task["createdAt"].timestamp() * 1000), + ) + else: + context.log.warning( + f"[pipes] Unsupported log driver {log_config['logDriver']} for container {container['name']} in task {task['taskArn']}. Dagster Pipes won't be able to receive messages from this container." + ) + + # check for failed containers + failed_containers = {} + + for task in response["tasks"]: + for container in task["containers"]: + if container.get("exitCode") not in (0, None): + failed_containers[container["runtimeId"]] = container.get("exitCode") + + if failed_containers: + raise RuntimeError( + f"Some ECS containers finished with non-zero exit code:\n{pformat(list(failed_containers.keys()))}" + ) + + except DagsterExecutionInterruptedError: + if self.forward_termination: + context.log.warning( + "[pipes] Dagster process interrupted, terminating ECS tasks" + ) + self._terminate_tasks(context=context, tasks=tasks, cluster=cluster) + raise + + context.log.info(f"[pipes] ECS tasks {tasks} completed") + + return PipesClientCompletedInvocation(session) + + def _wait_for_tasks_completion( + self, tasks: List[str], cluster: Optional[str] = None + ) -> Dict[str, Any]: + waiter = self._client.get_waiter("tasks_stopped") + + params: Dict[str, Any] = {"tasks": tasks} + + if cluster: + params["cluster"] = cluster + + waiter.wait(**params) + return self._client.describe_tasks(**params) + + def _terminate_tasks( + self, context: OpExecutionContext, tasks: List[str], cluster: Optional[str] = None + ): + for task in tasks: + try: + self._client.stop_task( + cluster=cluster, + task=task, + reason="Dagster process was interrupted", + ) + except botocore.exceptions.ClientError as e: + context.log.warning( + f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}" + ) diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_ecs.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_ecs.py new file mode 100644 index 0000000000000..fb2fa09a3caae --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/fake_ecs.py @@ -0,0 +1,236 @@ +import sys +import time +import uuid +from dataclasses import dataclass +from datetime import datetime +from subprocess import PIPE, Popen +from typing import Dict, List, Optional, cast + +import boto3 + + +@dataclass +class SimulatedTaskRun: + popen: Popen + cluster: str + task_arn: str + log_group: str + log_stream: str + created_at: datetime + runtime_id: str + stopped_reason: Optional[str] = None + stopped: bool = False + logs_uploaded: bool = False + + +class LocalECSMockClient: + def __init__(self, ecs_client: boto3.client, cloudwatch_client: boto3.client): + self.ecs_client = ecs_client + self.cloudwatch_client = cloudwatch_client + + self._task_runs: Dict[ + str, SimulatedTaskRun + ] = {} # mapping of TaskDefinitionArn to TaskDefinition + + def get_waiter(self, waiter_name: str): + return WaiterMock(self, waiter_name) + + def register_task_definition(self, **kwargs): + return self.ecs_client.register_task_definition(**kwargs) + + def describe_task_definition(self, **kwargs): + response = self.ecs_client.describe_task_definition(**kwargs) + assert ( + len(response["taskDefinition"]["containerDefinitions"]) == 1 + ), "Only 1 container is supported in tests" + # unlike real ECS, moto doesn't use cloudwatch logging by default + # so let's add it here + response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"] = ( + response["taskDefinition"]["containerDefinitions"][0].get("logConfiguration") + or { + "logDriver": "awslogs", + "options": { + "awslogs-group": f"{response['taskDefinition']['taskDefinitionArn']}", # this value doesn't really matter + "awslogs-stream-prefix": "ecs", + }, + } + ) + return response + + def run_task(self, **kwargs): + response = self.ecs_client.run_task(**kwargs) + + task_arn = response["tasks"][0]["taskArn"] + task_definition_arn = response["tasks"][0]["taskDefinitionArn"] + + task_definition = self.describe_task_definition(taskDefinition=task_definition_arn)[ + "taskDefinition" + ] + + assert ( + len(task_definition["containerDefinitions"]) == 1 + ), "Only 1 container is supported in tests" + + # execute in a separate process + command = task_definition["containerDefinitions"][0]["command"] + + assert ( + command[0] == sys.executable + ), "Only the current Python interpreter is supported in tests" + + created_at = datetime.now() + + popen = Popen( + command, + stdout=PIPE, + stderr=PIPE, + # get env from container overrides + env={ + env["name"]: env["value"] + for env in kwargs["overrides"]["containerOverrides"][0].get("environment", []) + }, + ) + + log_group = task_definition["containerDefinitions"][0]["logConfiguration"]["options"][ + "awslogs-group" + ] + stream_prefix = task_definition["containerDefinitions"][0]["logConfiguration"]["options"][ + "awslogs-stream-prefix" + ] + container_name = task_definition["containerDefinitions"][0]["name"] + log_stream = f"{stream_prefix}/{container_name}/{task_arn.split('/')[-1]}" + + self._task_runs[task_arn] = SimulatedTaskRun( + popen=popen, + cluster=kwargs.get("cluster", "default"), + task_arn=task_arn, + log_group=log_group, + log_stream=log_stream, + created_at=created_at, + runtime_id=str(uuid.uuid4()), + ) + + return response + + def describe_tasks(self, cluster: str, tasks: List[str]): + assert len(tasks) == 1, "Only 1 task is supported in tests" + + simulated_task = cast(SimulatedTaskRun, self._task_runs[tasks[0]]) + + response = self.ecs_client.describe_tasks(cluster=cluster, tasks=tasks) + + assert len(response["tasks"]) == 1, "Only 1 task is supported in tests" + + task_definition = self.describe_task_definition( + taskDefinition=response["tasks"][0]["taskDefinitionArn"] + )["taskDefinition"] + + assert ( + len(task_definition["containerDefinitions"]) == 1 + ), "Only 1 container is supported in tests" + + # need to inject container name since moto doesn't return it + + response["tasks"][0]["containers"].append( + { + "name": task_definition["containerDefinitions"][0]["name"], + "runtimeId": simulated_task.runtime_id, + } + ) + + response["tasks"][0]["createdAt"] = simulated_task.created_at + + # check if any failed + for task in response["tasks"]: + if task["taskArn"] in self._task_runs: + simulated_task = self._task_runs[task["taskArn"]] + + if simulated_task.stopped: + task["lastStatus"] = "STOPPED" + task["stoppedReason"] = simulated_task.stopped_reason + task["containers"][0]["exitCode"] = 1 + self._upload_logs_to_cloudwatch(task["taskArn"]) + return response + + if simulated_task.popen.poll() is not None: + simulated_task.popen.wait() + # check status code + if simulated_task.popen.returncode == 0: + task["lastStatus"] = "STOPPED" + task["containers"][0]["exitCode"] = 0 + else: + task["lastStatus"] = "STOPPED" + # _, stderr = simulated_task.popen.communicate() + task["containers"][0]["exitCode"] = 1 + + self._upload_logs_to_cloudwatch(task["taskArn"]) + + else: + task["lastStatus"] = "RUNNING" + + return response + + def stop_task(self, cluster: str, task: str, reason: Optional[str] = None): + if simulated_task := self._task_runs.get(task): + simulated_task.popen.terminate() + simulated_task.stopped = True + simulated_task.stopped_reason = reason + self._upload_logs_to_cloudwatch(task) + else: + raise RuntimeError(f"Task {task} was not found") + + def _upload_logs_to_cloudwatch(self, task: str): + simulated_task = self._task_runs[task] + + if simulated_task.logs_uploaded: + return + + log_group = simulated_task.log_group + log_stream = simulated_task.log_stream + + stdout, stderr = self._task_runs[task].popen.communicate() + + try: + self.cloudwatch_client.create_log_group( + logGroupName=f"{log_group}", + ) + except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException: + pass + + try: + self.cloudwatch_client.create_log_stream( + logGroupName=f"{log_group}", + logStreamName=log_stream, + ) + except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException: + pass + + for out in [stderr, stdout]: + for line in out.decode().split("\n"): + if line: + self.cloudwatch_client.put_log_events( + logGroupName=f"{log_group}", + logStreamName=log_stream, + logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}], + ) + + time.sleep(0.01) + + simulated_task.logs_uploaded = True + + +class WaiterMock: + def __init__(self, client: LocalECSMockClient, waiter_name: str): + self.client = client + self.waiter_name = waiter_name + + def wait(self, **kwargs): + if self.waiter_name == "tasks_stopped": + while True: + response = self.client.describe_tasks(**kwargs) + if all(task["lastStatus"] == "STOPPED" for task in response["tasks"]): + return + time.sleep(0.1) + + else: + raise NotImplementedError(f"Waiter {self.waiter_name} is not implemented") diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py index 36bd3680bc11a..0fda9190bba9d 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py @@ -5,6 +5,7 @@ import os import re import shutil +import sys import textwrap import time from contextlib import contextmanager @@ -28,6 +29,7 @@ from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus from dagster_aws.pipes import ( PipesCloudWatchMessageReader, + PipesECSClient, PipesGlueClient, PipesLambdaClient, PipesLambdaLogsMessageReader, @@ -35,7 +37,9 @@ PipesS3MessageReader, ) from moto.server import ThreadedMotoServer # type: ignore # (pyright bug) +from mypy_boto3_ecs import ECSClient +from dagster_aws_tests.pipes_tests.fake_ecs import LocalECSMockClient from dagster_aws_tests.pipes_tests.fake_glue import LocalGlueMockClient from dagster_aws_tests.pipes_tests.fake_lambda import ( LOG_TAIL_LIMIT, @@ -61,6 +65,36 @@ def temp_script(script_fn: Callable[[], Any]) -> Iterator[str]: _MOTO_SERVER_URL = f"http://localhost:{_MOTO_SERVER_PORT}" +@pytest.fixture +def external_script_default_components() -> Iterator[str]: + # This is called in an external process and so cannot access outer scope + def script_fn(): + import os + import time + + from dagster_pipes import open_dagster_pipes + + with open_dagster_pipes() as context: + context.log.info("hello world") + context.report_asset_materialization( + metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}}, + data_version="alpha", + ) + context.report_asset_check( + "foo_check", + passed=True, + severity="WARN", + metadata={ + "meta_1": 1, + "meta_2": {"raw_value": "foo", "type": "text"}, + }, + ) + time.sleep(float(os.getenv("SLEEP_SECONDS", "0.1"))) + + with temp_script(script_fn) as script_path: + yield script_path + + @pytest.fixture def external_script() -> Iterator[str]: # This is called in an external process and so cannot access outer scope @@ -98,7 +132,7 @@ def script_fn(): @pytest.fixture -def moto_server() -> Iterator[boto3.client]: +def moto_server() -> Iterator[ThreadedMotoServer]: # We need to use the moto server for cross-process communication server = ThreadedMotoServer(port=_MOTO_SERVER_PORT) # on localhost:5000 by default server.start() @@ -107,7 +141,7 @@ def moto_server() -> Iterator[boto3.client]: @pytest.fixture -def s3_client(moto_server) -> boto3.client: +def s3_client(moto_server): client = boto3.client("s3", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL) client.create_bucket(Bucket=_S3_TEST_BUCKET) return client @@ -340,7 +374,7 @@ def script_fn(): @pytest.fixture -def glue_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client: +def glue_client(moto_server, external_s3_glue_script, s3_client): client = boto3.client("glue", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL) client.create_job( Name=GLUE_JOB_NAME, @@ -357,7 +391,7 @@ def glue_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client @pytest.fixture -def cloudwatch_client(moto_server, external_s3_glue_script, s3_client) -> boto3.client: +def cloudwatch_client(moto_server, external_s3_glue_script, s3_client): return boto3.client("logs", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL) @@ -460,7 +494,7 @@ def script_fn(): @pytest.fixture -def foo_asset(long_glue_job: str) -> AssetsDefinition: +def glue_asset(long_glue_job: str) -> AssetsDefinition: @asset def foo(context: AssetExecutionContext, pipes_glue_client: PipesGlueClient): results = pipes_glue_client.run( @@ -492,32 +526,28 @@ def pipes_glue_client(local_glue_mock_client, s3_client, cloudwatch_client) -> P ) -def test_glue_pipes_interruption_forwarding_asset_is_valid( - foo_asset, pipes_glue_client, local_glue_mock_client -): +def test_glue_pipes_interruption_forwarding_asset_is_valid(glue_asset, pipes_glue_client): # make sure this runs without multiprocessing first with instance_for_test() as instance: materialize( - [foo_asset], instance=instance, resources={"pipes_glue_client": pipes_glue_client} + [glue_asset], instance=instance, resources={"pipes_glue_client": pipes_glue_client} ) -def test_glue_pipes_interruption_forwarding( - long_glue_job, foo_asset, pipes_glue_client, local_glue_mock_client -): +def test_glue_pipes_interruption_forwarding(long_glue_job, glue_asset, pipes_glue_client): def materialize_asset(env, return_dict): os.environ.update(env) try: with instance_for_test() as instance: materialize( # this will be interrupted and raise an exception - [foo_asset], + [glue_asset], instance=instance, resources={"pipes_glue_client": pipes_glue_client}, ) finally: - job_run_id = next(iter(local_glue_mock_client._job_runs.keys())) # noqa - return_dict[0] = local_glue_mock_client.get_job_run(long_glue_job, job_run_id) + job_run_id = next(iter(pipes_glue_client._client._job_runs.keys())) # noqa + return_dict[0] = pipes_glue_client._client.get_job_run(long_glue_job, job_run_id) # noqa with multiprocessing.Manager() as manager: return_dict = manager.dict() @@ -540,3 +570,148 @@ def materialize_asset(env, return_dict): p.join() assert not p.is_alive() assert return_dict[0]["JobRun"]["JobRunState"] == "STOPPED" + + +@pytest.fixture +def ecs_client(moto_server, external_s3_glue_script, s3_client) -> ECSClient: + return boto3.client("ecs", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL) + + +@pytest.fixture +def ecs_cluster(ecs_client) -> str: + cluster_name = "test-cluster" + ecs_client.create_cluster(clusterName=cluster_name) + return cluster_name + + +@pytest.fixture +def ecs_task_definition(ecs_client, external_script_default_components) -> str: + task_definition = "test-task" + ecs_client.register_task_definition( + family=task_definition, + containerDefinitions=[ + { + "name": "test-container", + "image": "test-image", + "command": [sys.executable, external_script_default_components], + "memory": 512, + } + ], + ) + return task_definition + + +@pytest.fixture +def local_ecs_mock_client( + ecs_client, cloudwatch_client, ecs_cluster, ecs_task_definition +) -> LocalECSMockClient: + return LocalECSMockClient(ecs_client=ecs_client, cloudwatch_client=cloudwatch_client) + + +@pytest.fixture +def pipes_ecs_client(local_ecs_mock_client, s3_client, cloudwatch_client) -> PipesECSClient: + return PipesECSClient( + client=local_ecs_mock_client, + message_reader=PipesCloudWatchMessageReader( + client=cloudwatch_client, + ), + ) + + +@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["ecs_asset"]))]) +def ecs_asset(context: AssetExecutionContext, pipes_ecs_client: PipesECSClient): + return pipes_ecs_client.run( + context=context, + extras={"bar": "baz"}, + run_task_params={ + "cluster": "test-cluster", + "count": 1, + "taskDefinition": "test-task", + "launchType": "FARGATE", + "networkConfiguration": {"awsvpcConfiguration": {"subnets": ["subnet-12345678"]}}, + "overrides": { + "containerOverrides": [ + { + "name": "test-container", + "environment": [ + { + "name": "SLEEP_SECONDS", + "value": os.getenv( + "SLEEP_SECONDS", "0.1" + ), # this can be increased to test interruption + } + ], + } + ] + }, + }, + ).get_results() + + +def test_ecs_pipes( + capsys, + pipes_ecs_client: PipesECSClient, +): + with instance_for_test() as instance: + materialize( + [ecs_asset], instance=instance, resources={"pipes_ecs_client": pipes_ecs_client} + ) + mat = instance.get_latest_materialization_event(ecs_asset.key) + assert mat and mat.asset_materialization + assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue) + assert mat.asset_materialization.metadata["bar"].value == "baz" + assert mat.asset_materialization.tags + assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG] + + captured = capsys.readouterr() + assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE) + + asset_check_executions = instance.event_log_storage.get_asset_check_execution_history( + check_key=AssetCheckKey(ecs_asset.key, name="foo_check"), + limit=1, + ) + assert len(asset_check_executions) == 1 + assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED + + +def test_ecs_pipes_interruption_forwarding(pipes_ecs_client: PipesECSClient): + def materialize_asset(env, return_dict): + os.environ.update(env) + try: + with instance_for_test() as instance: + materialize( # this will be interrupted and raise an exception + [ecs_asset], + instance=instance, + resources={"pipes_ecs_client": pipes_ecs_client}, + ) + finally: + assert len(pipes_ecs_client._client._task_runs) > 0 # noqa + task_arn = next(iter(pipes_ecs_client._client._task_runs.keys())) # noqa + return_dict[0] = pipes_ecs_client._client.describe_tasks( # noqa + cluster="test-cluster", tasks=[task_arn] + ) + + with multiprocessing.Manager() as manager: + return_dict = manager.dict() + + p = multiprocessing.Process( + target=materialize_asset, + args=( + {"SLEEP_SECONDS": "10"}, + return_dict, + ), + ) + p.start() + + while p.is_alive(): + # we started executing the run + # time to interrupt it! + time.sleep(4) + p.terminate() + + p.join() + assert not p.is_alive() + # breakpoint() + assert return_dict[0]["tasks"][0]["containers"][0]["exitCode"] == 1 + assert return_dict[0]["tasks"][0]["stoppedReason"] == "Dagster process was interrupted" diff --git a/python_modules/libraries/dagster-aws/setup.py b/python_modules/libraries/dagster-aws/setup.py index a48aebfe3f0d4..c00f3fadb582d 100644 --- a/python_modules/libraries/dagster-aws/setup.py +++ b/python_modules/libraries/dagster-aws/setup.py @@ -37,6 +37,7 @@ def get_version() -> str: python_requires=">=3.8,<3.13", install_requires=[ "boto3", + "boto3-stubs-lite[ecs]", f"dagster{pin}", "packaging", "requests",