diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index d216efa3..d862fddd 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -70,9 +70,9 @@ from pydantic import VERSION as PYDANTIC_VERSION if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field, root_validator + from pydantic.v1 import BaseModel, Field, root_validator else: - from pydantic import Field, root_validator + from pydantic import Field, root_validator, BaseModel from slugify import slugify from tenacity import retry, stop_after_attempt, wait_fixed, wait_random @@ -126,6 +126,7 @@ taskRoleArn: "{{ task_role_arn }}" tags: "{{ labels }}" taskDefinition: "{{ task_definition_arn }}" +capacityProviderStrategy: "{{ capacity_provider_strategy }}" """ # Create task run retry settings @@ -245,6 +246,16 @@ def mask_api_key(task_run_request): ) +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacityProvider: str + weight: int + base: int + + class ECSJobConfiguration(BaseJobConfiguration): """ Job configuration for an ECS worker. @@ -425,6 +436,14 @@ class ECSVariables(BaseVariables): ), ) ) + capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( + default_factory=list, + description=( + "The capacity provider strategy to use when running the task. " + "If a capacity provider strategy is specified, the selected launch" + " type will be ignored." + ), + ) image: Optional[str] = Field( default=None, description=( @@ -1449,17 +1468,24 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn + capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") - if task_run_request.get("launchType") == "FARGATE_SPOT": + if capacityProviderStrategy: + # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa + self._logger.warning( + "Found capacityProviderStrategy. " + "Removing launchType from task run request." + ) + task_run_request.pop("launchType", None) + + elif task_run_request.get("launchType") == "FARGATE_SPOT": # Should not be provided at all for FARGATE SPOT task_run_request.pop("launchType", None) # A capacity provider strategy is required for FARGATE SPOT - task_run_request.setdefault( - "capacityProviderStrategy", - [{"capacityProvider": "FARGATE_SPOT", "weight": 1}], - ) - + task_run_request["capacityProviderStrategy"] = [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] overrides = task_run_request.get("overrides", {}) container_overrides = overrides.get("containerOverrides", []) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index dc329311..c5c78179 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -506,6 +506,7 @@ async def test_launch_types( # Instead, it requires a capacity provider strategy but this is not supported # by moto and is not present on the task even when provided so we assert on the # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ {"capacityProvider": "FARGATE_SPOT", "weight": 1} ] @@ -2050,6 +2051,41 @@ async def test_user_defined_environment_variables_in_task_definition_template( ] +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_capacity_provider_strategy( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + capacity_provider_strategy=[ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} + ], + ) + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track + # 'capacityProviderStrategy' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert not task.get("launchType") + # Instead, it requires a capacity provider strategy but this is not supported + # by moto and is not present on the task even when provided so we assert on the + # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, + ] + + @pytest.mark.usefixtures("ecs_mocks") async def test_user_defined_environment_variables_in_task_run_request_template( aws_credentials: AwsCredentials, flow_run: FlowRun