From b635b2ffb0460e5ca05c59e4b19b2f8e0c670a7b Mon Sep 17 00:00:00 2001 From: Bangqi Zhu Date: Mon, 23 Dec 2024 15:40:25 -0800 Subject: [PATCH] RAG ensure service Signed-off-by: Bangqi Zhu --- .../controllers/ragengine_controller.go | 49 +++++++++++ .../controllers/ragengine_controller_test.go | 81 +++++++++++++++++++ pkg/ragengine/manifests/manifests.go | 39 ++++++++- 3 files changed, 168 insertions(+), 1 deletion(-) diff --git a/pkg/ragengine/controllers/ragengine_controller.go b/pkg/ragengine/controllers/ragengine_controller.go index 349b5d1c1..8e6c8aa55 100644 --- a/pkg/ragengine/controllers/ragengine_controller.go +++ b/pkg/ragengine/controllers/ragengine_controller.go @@ -127,6 +127,14 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka } return reconcile.Result{}, err } + if err := c.ensureServices(ctx, ragEngineObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse, + "ragEngineFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update ragEngine status", "ragEngine", klog.KObj(ragEngineObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } if err = c.applyRAG(ctx, ragEngineObj); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse, "ragengineFailed", err.Error()); updateErr != nil { @@ -144,6 +152,47 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka return reconcile.Result{}, nil } +func (c *RAGEngineReconciler) ensureServices(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine) error { + serviceType := corev1.ServiceTypeClusterIP + ragAnnotations := ragObj.GetAnnotations() + + if len(ragAnnotations) != 0 { + val, found := ragAnnotations[kaitov1alpha1.AnnotationEnableLB] + if found && val == "True" { + serviceType = corev1.ServiceTypeLoadBalancer + } + } + + // Ensure Service for index and query + // TODO: ServiceName currently does not accept customization for now + + queryServiceName := ragObj.Name + + if err := c.ensureService(ctx, ragObj, queryServiceName, serviceType); err != nil { + return err + } + + return nil +} + +func (c *RAGEngineReconciler) ensureService(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType) error { + existingSVC := &corev1.Service{} + err := resources.GetResource(ctx, serviceName, ragObj.Namespace, c.Client, existingSVC) + if err != nil { + if !apierrors.IsNotFound(err) { + return err + } + } else { + return nil + } + serviceObj := manifests.GenerateRAGServiceManifest(ctx, ragObj, serviceName, serviceType) + if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil { + return err + } + + return nil +} + func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { var err error func() { diff --git a/pkg/ragengine/controllers/ragengine_controller_test.go b/pkg/ragengine/controllers/ragengine_controller_test.go index 407c4665d..fa6901048 100644 --- a/pkg/ragengine/controllers/ragengine_controller_test.go +++ b/pkg/ragengine/controllers/ragengine_controller_test.go @@ -636,3 +636,84 @@ func TestApplyRAG(t *testing.T) { }) } } + +func TestEnsureService(t *testing.T) { + test.RegisterTestModel() + testcases := map[string]struct { + callMocks func(c *test.MockClient) + expectedError error + ragengine v1alpha1.RAGEngine + verifyCalls func(c *test.MockClient) + }{ + + "Existing service is found for RAGEngine": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil) + }, + expectedError: nil, + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + + "Service creation fails": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError()) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(errors.New("cannot create service")) + }, + expectedError: errors.New("cannot create service"), + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 4) + c.AssertNumberOfCalls(t, "Get", 4) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + + "Successfully creates a new service": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError()) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil) + }, + expectedError: nil, + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Get", 4) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + } + + for k, tc := range testcases { + t.Run(k, func(t *testing.T) { + mockClient := test.NewClient() + tc.callMocks(mockClient) + + reconciler := &RAGEngineReconciler{ + Client: mockClient, + Scheme: test.NewTestScheme(), + } + ctx := context.Background() + + err := reconciler.ensureServices(ctx, &tc.ragengine) + if tc.expectedError == nil { + assert.Check(t, err == nil, "Not expected to return error") + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) + } + if tc.verifyCalls != nil { + tc.verifyCalls(mockClient) + } + }) + } +} diff --git a/pkg/ragengine/manifests/manifests.go b/pkg/ragengine/manifests/manifests.go index 254de2e96..25b832fd0 100644 --- a/pkg/ragengine/manifests/manifests.go +++ b/pkg/ragengine/manifests/manifests.go @@ -151,7 +151,7 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar { envs = append(envs, stoageEnv) inferenceServiceURL := ragEngineObj.Spec.InferenceService.URL inferenceServiceURLEnv := corev1.EnvVar{ - Name: "INFERENCE_URL", + Name: "LLM_INFERENCE_URL", Value: inferenceServiceURL, } envs = append(envs, inferenceServiceURLEnv) @@ -165,3 +165,40 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar { } return envs } + +func GenerateRAGServiceManifest(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType) *corev1.Service { + selector := map[string]string{ + kaitov1alpha1.LabelRAGEngineName: ragObj.Name, + } + + servicePorts := []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt32(5000), + }, + } + + return &corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: ragObj.Namespace, + OwnerReferences: []v1.OwnerReference{ + { + APIVersion: kaitov1alpha1.GroupVersion.String(), + Kind: "RAGEngine", + UID: ragObj.UID, + Name: ragObj.Name, + Controller: &controller, + }, + }, + }, + Spec: corev1.ServiceSpec{ + Type: serviceType, + Ports: servicePorts, + Selector: selector, + PublishNotReadyAddresses: true, + }, + } +}