From 5ea7fb2e8d9cc88279cc3a1a025a34bef761d29c Mon Sep 17 00:00:00 2001 From: jerryzhuang Date: Mon, 30 Sep 2024 21:06:47 +1000 Subject: [PATCH] fix: ignore instanceType when selecting preferred nodes Signed-off-by: jerryzhuang --- api/v1alpha1/workspace_types.go | 3 +- pkg/controllers/workspace_controller.go | 33 +++--- pkg/controllers/workspace_controller_test.go | 110 ++++++++++++++++--- pkg/utils/test/testUtils.go | 23 ++++ 4 files changed, 138 insertions(+), 31 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 64cb9967a..35069423f 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -34,8 +34,7 @@ type ResourceSpec struct { LabelSelector *metav1.LabelSelector `json:"labelSelector"` // PreferredNodes is an optional node list specified by the user. - // If a node in the list does not have the required labels or - // the required instanceType, it will be ignored. + // If a node in the list does not have the required labels, it will be ignored. // +optional PreferredNodes []string `json:"preferredNodes,omitempty"` } diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index d193a7d8c..25ef22faa 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -40,6 +40,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" "k8s.io/utils/clock" @@ -385,7 +386,7 @@ func (c *WorkspaceReconciler) applyWorkspaceResource(ctx context.Context, wObj * } } - // Find all nodes that match the labelSelector and instanceType, they are not necessarily created by machines/nodeClaims. + // Find all nodes that meet the requirements, they are not necessarily created by machines/nodeClaims. validNodes, err := c.getAllQualifiedNodes(ctx, wObj) if err != nil { return err @@ -474,7 +475,6 @@ func (c *WorkspaceReconciler) applyWorkspaceResource(ctx context.Context, wObj * return nil } -// getAllQualifiedNodes returns all nodes that match the labelSelector and instanceType. func (c *WorkspaceReconciler) getAllQualifiedNodes(ctx context.Context, wObj *kaitov1alpha1.Workspace) ([]*corev1.Node, error) { var qualifiedNodes []*corev1.Node @@ -488,33 +488,36 @@ func (c *WorkspaceReconciler) getAllQualifiedNodes(ctx context.Context, wObj *ka return nil, nil } + preferredNodeSet := sets.New(wObj.Resource.PreferredNodes...) + for index := range nodeList.Items { nodeObj := nodeList.Items[index] - // Skip nodes that are being deleted + // skip nodes that are being deleted if nodeObj.DeletionTimestamp != nil { continue } - foundInstanceType := c.validateNodeInstanceType(ctx, wObj, lo.ToPtr(nodeObj)) + + // skip nodes that are not ready _, statusRunning := lo.Find(nodeObj.Status.Conditions, func(condition corev1.NodeCondition) bool { return condition.Type == corev1.NodeReady && condition.Status == corev1.ConditionTrue }) + if !statusRunning { + continue + } - if foundInstanceType && statusRunning { + // match the preferred node + if preferredNodeSet.Has(nodeObj.Name) { qualifiedNodes = append(qualifiedNodes, lo.ToPtr(nodeObj)) + continue } - } - return qualifiedNodes, nil -} - -// check if node has the required instanceType -func (c *WorkspaceReconciler) validateNodeInstanceType(ctx context.Context, wObj *kaitov1alpha1.Workspace, nodeObj *corev1.Node) bool { - if instanceTypeLabel, found := nodeObj.Labels[corev1.LabelInstanceTypeStable]; found { - if instanceTypeLabel != wObj.Resource.InstanceType { - return false + // match the instanceType + if nodeObj.Labels[corev1.LabelInstanceTypeStable] == wObj.Resource.InstanceType { + qualifiedNodes = append(qualifiedNodes, lo.ToPtr(nodeObj)) } } - return true + + return qualifiedNodes, nil } // createAndValidateNode creates a new node and validates status. diff --git a/pkg/controllers/workspace_controller_test.go b/pkg/controllers/workspace_controller_test.go index f71493523..ba473bc07 100644 --- a/pkg/controllers/workspace_controller_test.go +++ b/pkg/controllers/workspace_controller_test.go @@ -751,29 +751,108 @@ func TestApplyInferenceWithTemplate(t *testing.T) { } func TestGetAllQualifiedNodes(t *testing.T) { + deletedNode := corev1.Node{ + ObjectMeta: v1.ObjectMeta{ + Name: "node4", + Labels: map[string]string{ + corev1.LabelInstanceTypeStable: "Standard_NC12s_v3", + }, + DeletionTimestamp: &v1.Time{Time: time.Now()}, + }, + } + testcases := map[string]struct { callMocks func(c *test.MockClient) + workspace *v1alpha1.Workspace expectedError error + expectedNodes []string }{ "Fails to get qualified nodes because can't list nodes": { callMocks: func(c *test.MockClient) { c.On("List", mock.IsType(context.Background()), mock.IsType(&corev1.NodeList{}), mock.Anything).Return(errors.New("Failed to list nodes")) }, + workspace: test.MockWorkspaceDistributedModel, expectedError: errors.New("Failed to list nodes"), + expectedNodes: nil, }, "Gets all qualified nodes": { callMocks: func(c *test.MockClient) { nodeList := test.MockNodeList - deletedNode := corev1.Node{ - ObjectMeta: v1.ObjectMeta{ - Name: "node4", - Labels: map[string]string{ - corev1.LabelInstanceTypeStable: "Standard_NC12s_v3", + + nodeList.Items = append(nodeList.Items, deletedNode) + + relevantMap := c.CreateMapWithType(nodeList) + //insert node objects into the map + for _, obj := range test.MockNodeList.Items { + n := obj + objKey := client.ObjectKeyFromObject(&n) + + relevantMap[objKey] = &n + } + + c.On("List", mock.IsType(context.Background()), mock.IsType(&corev1.NodeList{}), mock.Anything).Return(nil) + }, + workspace: test.MockWorkspaceDistributedModel, + expectedError: nil, + expectedNodes: []string{"node1"}, + }, + "Gets all qualified nodes with preferred": { + callMocks: func(c *test.MockClient) { + nodeList := test.MockNodeList + + nodeList.Items = append(nodeList.Items, deletedNode) + + nodesFromOtherVendor := []corev1.Node{ + { + ObjectMeta: v1.ObjectMeta{ + Name: "node-p1", + Labels: map[string]string{ + corev1.LabelInstanceTypeStable: "vendor1", + }, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeReady, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + { + ObjectMeta: v1.ObjectMeta{ + Name: "node-p2", + Labels: map[string]string{ + corev1.LabelInstanceTypeStable: "vendor2", + }, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeReady, + Status: corev1.ConditionFalse, + }, + }, + }, + }, + { + ObjectMeta: v1.ObjectMeta{ + Name: "node-p3", + Labels: map[string]string{ + corev1.LabelInstanceTypeStable: "vendor1", + }, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeReady, + Status: corev1.ConditionTrue, + }, + }, }, - DeletionTimestamp: &v1.Time{Time: time.Now()}, }, } - nodeList.Items = append(nodeList.Items, deletedNode) + nodeList.Items = append(nodeList.Items, nodesFromOtherVendor...) relevantMap := c.CreateMapWithType(nodeList) //insert node objects into the map @@ -786,14 +865,15 @@ func TestGetAllQualifiedNodes(t *testing.T) { c.On("List", mock.IsType(context.Background()), mock.IsType(&corev1.NodeList{}), mock.Anything).Return(nil) }, + workspace: test.MockWorkspaceWithPreferredNodes, expectedError: nil, + expectedNodes: []string{"node1", "node-p1"}, }, } for k, tc := range testcases { t.Run(k, func(t *testing.T) { mockClient := test.NewClient() - mockWorkspace := test.MockWorkspaceDistributedModel reconciler := &WorkspaceReconciler{ Client: mockClient, Scheme: test.NewTestScheme(), @@ -802,15 +882,17 @@ func TestGetAllQualifiedNodes(t *testing.T) { tc.callMocks(mockClient) - nodes, err := reconciler.getAllQualifiedNodes(ctx, mockWorkspace) - if tc.expectedError == nil { - assert.Check(t, err == nil, "Not expected to return error") - assert.Check(t, nodes != nil, "Response node array should not be nil") - assert.Check(t, len(nodes) == 1, "One out of three nodes should be qualified") - } else { + nodes, err := reconciler.getAllQualifiedNodes(ctx, tc.workspace) + + if tc.expectedError != nil { assert.Equal(t, tc.expectedError.Error(), err.Error()) assert.Check(t, nodes == nil, "Response node array should be nil") + return } + + assert.Check(t, err == nil, "Not expected to return error") + assert.Check(t, nodes != nil, "Response node array should not be nil") + assert.Check(t, len(nodes) == len(tc.expectedNodes), "Unexpected qualified nodes") }) } } diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go index d952852d9..a3eee57b4 100644 --- a/pkg/utils/test/testUtils.go +++ b/pkg/utils/test/testUtils.go @@ -47,6 +47,29 @@ var ( }, }, } + MockWorkspaceWithPreferredNodes = &v1alpha1.Workspace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testWorkspace", + Namespace: "kaito", + }, + Resource: v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "apps": "test", + }, + }, + PreferredNodes: []string{"node-p1", "node-p2"}, + }, + Inference: &v1alpha1.InferenceSpec{ + Preset: &v1alpha1.PresetSpec{ + PresetMeta: v1alpha1.PresetMeta{ + Name: "test-distributed-model", + }, + }, + }, + } ) var (