[KubeRay] support suspending worker groups in KubeRay autoscaler
Signed-off-by: Rueian <[email protected]>
rueian committed Jan 11, 2025
1 parent fa119fd commit afed960
Showing 2 changed files with 96 additions and 0 deletions.
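For context, a worker group is suspended by setting the suspend field on its entry under spec.workerGroupSpecs in the RayCluster resource. The following is a minimal, hypothetical sketch of that resource as the Python dict the node provider reads; only groupName and suspend are consulted by the logic added in this commit, and the remaining fields are included purely for orientation.

# Hypothetical RayCluster fragment (illustration only, not taken from the diff).
raycluster = {
    "spec": {
        "workerGroupSpecs": [
            {
                "groupName": "small-group",
                "replicas": 2,
                "minReplicas": 0,
                "maxReplicas": 10,
                # While suspended, the autoscaler leaves this group to the
                # KubeRay operator: no replica patches and no worker deletions.
                "suspend": True,
            },
        ]
    }
}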
python/ray/autoscaler/_private/kuberay/node_provider.py: 17 additions, 0 deletions
@@ -230,6 +230,13 @@ def _worker_group_index(raycluster: Dict[str, Any], group_name: str) -> int:
    return group_names.index(group_name)


def _worker_group_is_suspended(
    raycluster: Dict[str, Any], group_index: int
) -> bool:
    """Extract the suspend field of a worker group, defaulting to False."""
    return raycluster["spec"]["workerGroupSpecs"][group_index].get("suspend", False)


def _worker_group_max_replicas(
    raycluster: Dict[str, Any], group_index: int
) -> Optional[int]:
@@ -486,6 +493,11 @@ def _scale_request_to_patch_payload(
        # Collect patches for replica counts.
        for node_type, target_replicas in scale_request.desired_num_workers.items():
            group_index = _worker_group_index(raycluster, node_type)
            group_is_suspended = _worker_group_is_suspended(raycluster, group_index)
            # The NodeProvider should not patch the replicas field of a suspended
            # worker group; that is the KubeRay operator's responsibility.
            if group_is_suspended:
                continue
            group_max_replicas = _worker_group_max_replicas(raycluster, group_index)
            # Cap the replica count to maxReplicas.
            if group_max_replicas is not None and group_max_replicas < target_replicas:
@@ -510,6 +522,11 @@ def _scale_request_to_patch_payload(

        for node_type, workers_to_delete in deletion_groups.items():
            group_index = _worker_group_index(raycluster, node_type)
            group_is_suspended = _worker_group_is_suspended(raycluster, group_index)
            # The NodeProvider should not delete replicas of a suspended
            # worker group; that is the KubeRay operator's responsibility.
            if group_is_suspended:
                continue
            patch = worker_delete_patch(group_index, workers_to_delete)
            patch_payload.append(patch)

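Taken together, the replica-count hunk above reduces to the condensed sketch below. This is an illustration rather than the actual _scale_request_to_patch_payload method: build_replica_patches is a hypothetical standalone helper, it reuses the _worker_group_index and _worker_group_is_suspended helpers shown earlier, and it omits checks the real method performs (such as capping to maxReplicas).

from typing import Any, Dict, List


def build_replica_patches(
    raycluster: Dict[str, Any], desired_num_workers: Dict[str, int]
) -> List[Dict[str, Any]]:
    """Illustrative only: build replica patches, skipping suspended groups."""
    patches: List[Dict[str, Any]] = []
    for node_type, target_replicas in desired_num_workers.items():
        group_index = _worker_group_index(raycluster, node_type)
        if _worker_group_is_suspended(raycluster, group_index):
            # Suspended groups are left entirely to the KubeRay operator.
            continue
        patches.append(
            {
                "op": "replace",
                "path": f"/spec/workerGroupSpecs/{group_index}/replicas",
                "value": target_replicas,
            }
        )
    return patches

With the RayCluster fragment sketched earlier, build_replica_patches(raycluster, {"small-group": 5}) would return an empty list, since that group is suspended.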
python/ray/tests/kuberay/test_kuberay_node_provider.py: 79 additions, 0 deletions
@@ -291,6 +291,85 @@ def test_submit_scale_request(node_data_dict, scale_request, expected_patch_payload):
    assert patch_payload == expected_patch_payload


@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not relevant on Windows.")
@pytest.mark.parametrize(
    "node_data_dict,scale_request,expected_patch_payload",
    [
        (
            {
                "raycluster-autoscaler-head-8zsc8": NodeData(
                    kind="head",
                    type="headgroup",
                    replica_index=None,
                    ip="10.4.2.6",
                    status="up-to-date",
                ),
                "raycluster-autoscaler-worker-fake-gpu-group-2qnhv": NodeData(
                    kind="worker",
                    type="fake-gpu-group",
                    replica_index=None,
                    ip="10.4.0.6",
                    status="up-to-date",
                ),
                "raycluster-autoscaler-worker-small-group-dkz2r": NodeData(
                    kind="worker",
                    type="small-group",
                    replica_index=None,
                    ip="10.4.1.8",
                    status="up-to-date",
                ),
                "raycluster-autoscaler-worker-small-group-lbfm4": NodeData(
                    kind="worker",
                    type="small-group",
                    replica_index=None,
                    ip="10.4.0.5",
                    status="up-to-date",
                ),
            },
            ScaleRequest(
                desired_num_workers={
                    "small-group": 0,  # Delete 2
                    "gpu-group": 1,  # Don't touch
                    "blah-group": 5,  # Create 5
                },
                workers_to_delete={
                    "raycluster-autoscaler-worker-small-group-lbfm4",
                    "raycluster-autoscaler-worker-small-group-dkz2r",
                },
            ),
            [
                # The small-group is suspended, so no replica or delete patch
                # is generated for it.
                {
                    "op": "replace",
                    "path": "/spec/workerGroupSpecs/3/replicas",
                    "value": 5,
                }
            ],
        ),
    ],
)
def test_submit_scale_request_with_suspended_groups(
    node_data_dict, scale_request, expected_patch_payload
):
    """Test that the KubeRayNodeProvider skips suspended worker groups when
    converting a scale request into a RayCluster patch payload.
    """
    raycluster = get_basic_ray_cr()
    # Suspend the first worker group (small-group).
    raycluster["spec"]["workerGroupSpecs"][0]["suspend"] = True
    # Add another worker group for the sake of this test.
    blah_group = copy.deepcopy(raycluster["spec"]["workerGroupSpecs"][1])
    blah_group["groupName"] = "blah-group"
    raycluster["spec"]["workerGroupSpecs"].append(blah_group)
    with mock.patch.object(KubeRayNodeProvider, "__init__", return_value=None):
        kr_node_provider = KubeRayNodeProvider(provider_config={}, cluster_name="fake")
        kr_node_provider.node_data_dict = node_data_dict
        patch_payload = kr_node_provider._scale_request_to_patch_payload(
            scale_request=scale_request, raycluster=raycluster
        )
    assert patch_payload == expected_patch_payload


@pytest.mark.parametrize("node_set", [{"A", "B", "C", "D", "E"}])
@pytest.mark.parametrize("cpu_workers_to_delete", ["A", "Z"])
@pytest.mark.parametrize("gpu_workers_to_delete", ["B", "Y"])
