diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/tasks.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/tasks.py index 175562d54da6c..ff6e7cd16a813 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/tasks.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/tasks.py @@ -357,38 +357,38 @@ def get_task_kwargs_from_current_task( cluster, task, ): - enis = [] - subnets = [] - for attachment in task["attachments"]: - if attachment["type"] == "ElasticNetworkInterface": - for detail in attachment["details"]: - if detail["name"] == "subnetId": - subnets.append(detail["value"]) - if detail["name"] == "networkInterfaceId": - enis.append(ec2.NetworkInterface(detail["value"])) - - public_ip = False - security_groups = [] - for eni in enis: - if (eni.association_attribute or {}).get("PublicIp"): - public_ip = True - for group in eni.groups: - security_groups.append(group["GroupId"]) - - run_task_kwargs = { - "cluster": cluster, - "networkConfiguration": { - "awsvpcConfiguration": { - "subnets": subnets, - "assignPublicIp": "ENABLED" if public_ip else "DISABLED", - "securityGroups": security_groups, - }, - }, - } + run_task_kwargs = {"cluster": cluster} if not task.get("capacityProviderStrategy"): run_task_kwargs["launchType"] = task.get("launchType") or "FARGATE" else: run_task_kwargs["capacityProviderStrategy"] = task.get("capacityProviderStrategy") + if run_task_kwargs["launchType"] != "EXTERNAL": + enis = [] + subnets = [] + for attachment in task["attachments"]: + if attachment["type"] == "ElasticNetworkInterface": + for detail in attachment["details"]: + if detail["name"] == "subnetId": + subnets.append(detail["value"]) + if detail["name"] == "networkInterfaceId": + enis.append(ec2.NetworkInterface(detail["value"])) + + public_ip = False + security_groups = [] + + for eni in enis: + if (eni.association_attribute or {}).get("PublicIp"): + public_ip = True + for group in eni.groups: + security_groups.append(group["GroupId"]) + + aws_vpc_config = { + "subnets": subnets, + "assignPublicIp": "ENABLED" if public_ip else "DISABLED", + "securityGroups": security_groups, + } + run_task_kwargs["networkConfiguration"] = {"awsvpcConfiguration": aws_vpc_config} + return run_task_kwargs diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/test_launching.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/test_launching.py index b9bd6ff72a172..351b3ba90ad00 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/test_launching.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/test_launching.py @@ -1221,3 +1221,59 @@ def test_custom_launcher( == WorkerStatus.RUNNING ) ecs.stop_task(task=task_arn) + + +def test_external_launch_type( + ecs, + instance_cm, + workspace, + external_job, + job, +): + container_name = "external" + + task_definition = ecs.register_task_definition( + family="external", + containerDefinitions=[{"name": container_name, "image": "dagster:first"}], + networkMode="bridge", + memory="512", + cpu="256", + )["taskDefinition"] + + assert task_definition["networkMode"] == "bridge" + + task_definition_arn = task_definition["taskDefinitionArn"] + + # You can provide a family or a task definition ARN + with instance_cm( + { + "task_definition": task_definition_arn, + "container_name": container_name, + "run_task_kwargs": { + "launchType": "EXTERNAL", + }, + } + ) as instance: + run = instance.create_run_for_job( + job, + external_job_origin=external_job.get_external_origin(), + job_code_origin=external_job.get_python_origin(), + ) + + initial_task_definitions = ecs.list_task_definitions()["taskDefinitionArns"] + initial_tasks = ecs.list_tasks()["taskArns"] + + instance.launch_run(run.run_id, workspace) + + # A new task definition is not created + assert ecs.list_task_definitions()["taskDefinitionArns"] == initial_task_definitions + + # A new task is launched + tasks = ecs.list_tasks()["taskArns"] + + assert len(tasks) == len(initial_tasks) + 1 + task_arn = next(iter(set(tasks).difference(initial_tasks))) + task = ecs.describe_tasks(tasks=[task_arn])["tasks"][0] + + assert task["taskDefinitionArn"] == task_definition["taskDefinitionArn"] + assert task["launchType"] == "EXTERNAL"