diff --git a/pkg/ragengine/controllers/ragengine_controller.go b/pkg/ragengine/controllers/ragengine_controller.go index 349b5d1c1..20b374a8b 100644 --- a/pkg/ragengine/controllers/ragengine_controller.go +++ b/pkg/ragengine/controllers/ragengine_controller.go @@ -127,6 +127,14 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka } return reconcile.Result{}, err } + if err := c.ensureServices(ctx, ragEngineObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse, + "ragEngineFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update ragEngine status", "ragEngine", klog.KObj(ragEngineObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } if err = c.applyRAG(ctx, ragEngineObj); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse, "ragengineFailed", err.Error()); updateErr != nil { @@ -144,6 +152,56 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka return reconcile.Result{}, nil } +func (c *RAGEngineReconciler) ensureServices(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine) error { + serviceType := corev1.ServiceTypeClusterIP + ragAnnotations := ragObj.GetAnnotations() + + if len(ragAnnotations) != 0 { + val, found := ragAnnotations[kaitov1alpha1.AnnotationEnableLB] + if found && val == "True" { + serviceType = corev1.ServiceTypeLoadBalancer + } + } + + // Ensure QueryService + queryServiceName := ragObj.Spec.QueryServiceName + if queryServiceName == "" { + queryServiceName = fmt.Sprintf("%s-query", ragObj.Name) + } + if err := c.ensureService(ctx, ragObj, queryServiceName, serviceType, "query"); err != nil { + return err + } + + // Ensure IndexService + indexServiceName := ragObj.Spec.IndexServiceName + if indexServiceName == "" { + indexServiceName = fmt.Sprintf("%s-index", ragObj.Name) + } + if err := c.ensureService(ctx, ragObj, indexServiceName, serviceType, "index"); err != nil { + return err + } + + return nil +} + +func (c *RAGEngineReconciler) ensureService(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType, serviceRole string) error { + existingSVC := &corev1.Service{} + err := resources.GetResource(ctx, serviceName, ragObj.Namespace, c.Client, existingSVC) + if err != nil { + if !apierrors.IsNotFound(err) { + return err + } + } else { + return nil + } + serviceObj := manifests.GenerateRAGServiceManifest(ctx, ragObj, serviceName, serviceType, serviceRole) + if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil { + return err + } + + return nil +} + func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { var err error func() { diff --git a/pkg/ragengine/controllers/ragengine_controller_test.go b/pkg/ragengine/controllers/ragengine_controller_test.go index 407c4665d..6edad9851 100644 --- a/pkg/ragengine/controllers/ragengine_controller_test.go +++ b/pkg/ragengine/controllers/ragengine_controller_test.go @@ -636,3 +636,84 @@ func TestApplyRAG(t *testing.T) { }) } } + +func TestEnsureService(t *testing.T) { + test.RegisterTestModel() + testcases := map[string]struct { + callMocks func(c *test.MockClient) + expectedError error + ragengine v1alpha1.RAGEngine + verifyCalls func(c *test.MockClient) + }{ + + "Existing service is found for RAGEngine": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil) + }, + expectedError: nil, + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 2) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + + "Service creation fails": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError()) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(errors.New("cannot create service")) + }, + expectedError: errors.New("cannot create service"), + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 4) + c.AssertNumberOfCalls(t, "Get", 4) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + + "Successfully creates a new service": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError()) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil) + }, + expectedError: nil, + ragengine: *test.MockRAGEngineWithPreset, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 2) + c.AssertNumberOfCalls(t, "Get", 8) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + } + + for k, tc := range testcases { + t.Run(k, func(t *testing.T) { + mockClient := test.NewClient() + tc.callMocks(mockClient) + + reconciler := &RAGEngineReconciler{ + Client: mockClient, + Scheme: test.NewTestScheme(), + } + ctx := context.Background() + + err := reconciler.ensureServices(ctx, &tc.ragengine) + if tc.expectedError == nil { + assert.Check(t, err == nil, "Not expected to return error") + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) + } + if tc.verifyCalls != nil { + tc.verifyCalls(mockClient) + } + }) + } +} diff --git a/pkg/ragengine/manifests/manifests.go b/pkg/ragengine/manifests/manifests.go index 254de2e96..eede312fb 100644 --- a/pkg/ragengine/manifests/manifests.go +++ b/pkg/ragengine/manifests/manifests.go @@ -165,3 +165,52 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar { } return envs } + +func GenerateRAGServiceManifest(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType, serviceRole string) *corev1.Service { + selector := map[string]string{ + kaitov1alpha1.LabelRAGEngineName: ragObj.Name, + } + + var servicePorts []corev1.ServicePort + if serviceRole == "query" { + servicePorts = []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt32(5000), + }, + } + } else if serviceRole == "index" { + servicePorts = []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt32(5000), + }, + } + } + + return &corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: ragObj.Namespace, + OwnerReferences: []v1.OwnerReference{ + { + APIVersion: kaitov1alpha1.GroupVersion.String(), + Kind: "RAGEngine", + UID: ragObj.UID, + Name: ragObj.Name, + Controller: &controller, + }, + }, + }, + Spec: corev1.ServiceSpec{ + Type: serviceType, + Ports: servicePorts, + Selector: selector, + PublishNotReadyAddresses: true, + }, + } +}