Skip to content

Commit

Permalink
RAG ensure service
Browse files Browse the repository at this point in the history
Signed-off-by: Bangqi Zhu <[email protected]>
  • Loading branch information
Bangqi Zhu committed Dec 12, 2024
1 parent 83f25cd commit 2f9c2ba
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 0 deletions.
58 changes: 58 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka
}
return reconcile.Result{}, err
}
if err := c.ensureServices(ctx, ragEngineObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragEngineFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update ragEngine status", "ragEngine", klog.KObj(ragEngineObj))
return reconcile.Result{}, updateErr
}
return reconcile.Result{}, err
}
if err = c.applyRAG(ctx, ragEngineObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragengineFailed", err.Error()); updateErr != nil {
Expand All @@ -144,6 +152,56 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka
return reconcile.Result{}, nil
}

func (c *RAGEngineReconciler) ensureServices(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine) error {
serviceType := corev1.ServiceTypeClusterIP
ragAnnotations := ragObj.GetAnnotations()

if len(ragAnnotations) != 0 {
val, found := ragAnnotations[kaitov1alpha1.AnnotationEnableLB]
if found && val == "True" {
serviceType = corev1.ServiceTypeLoadBalancer
}
}

// Ensure QueryService
queryServiceName := ragObj.Spec.QueryServiceName
if queryServiceName == "" {
queryServiceName = fmt.Sprintf("%s-query", ragObj.Name)
}
if err := c.ensureService(ctx, ragObj, queryServiceName, serviceType, "query"); err != nil {
return err
}

// Ensure IndexService
indexServiceName := ragObj.Spec.IndexServiceName
if indexServiceName == "" {
indexServiceName = fmt.Sprintf("%s-index", ragObj.Name)
}
if err := c.ensureService(ctx, ragObj, indexServiceName, serviceType, "index"); err != nil {
return err
}

return nil
}

func (c *RAGEngineReconciler) ensureService(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType, serviceRole string) error {
existingSVC := &corev1.Service{}
err := resources.GetResource(ctx, serviceName, ragObj.Namespace, c.Client, existingSVC)
if err != nil {
if !apierrors.IsNotFound(err) {
return err
}
} else {
return nil
}
serviceObj := manifests.GenerateRAGServiceManifest(ctx, ragObj, serviceName, serviceType, serviceRole)
if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil {
return err
}

return nil
}

func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
var err error
func() {
Expand Down
81 changes: 81 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,84 @@ func TestApplyRAG(t *testing.T) {
})
}
}

func TestEnsureService(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
callMocks func(c *test.MockClient)
expectedError error
ragengine v1alpha1.RAGEngine
verifyCalls func(c *test.MockClient)
}{

"Existing service is found for RAGEngine": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 2)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Service creation fails": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(errors.New("cannot create service"))
},
expectedError: errors.New("cannot create service"),
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 4)
c.AssertNumberOfCalls(t, "Get", 4)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Successfully creates a new service": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 2)
c.AssertNumberOfCalls(t, "Get", 8)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.ensureServices(ctx, &tc.ragengine)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
49 changes: 49 additions & 0 deletions pkg/ragengine/manifests/manifests.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,52 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar {
}
return envs
}

func GenerateRAGServiceManifest(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType, serviceRole string) *corev1.Service {
selector := map[string]string{
kaitov1alpha1.LabelRAGEngineName: ragObj.Name,
}

var servicePorts []corev1.ServicePort
if serviceRole == "query" {
servicePorts = []corev1.ServicePort{
{
Name: "http",
Protocol: corev1.ProtocolTCP,
Port: 80,
TargetPort: intstr.FromInt32(5000),
},
}
} else if serviceRole == "index" {
servicePorts = []corev1.ServicePort{
{
Name: "http",
Protocol: corev1.ProtocolTCP,
Port: 80,
TargetPort: intstr.FromInt32(5000),
},
}
}

return &corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: ragObj.Namespace,
OwnerReferences: []v1.OwnerReference{
{
APIVersion: kaitov1alpha1.GroupVersion.String(),
Kind: "RAGEngine",
UID: ragObj.UID,
Name: ragObj.Name,
Controller: &controller,
},
},
},
Spec: corev1.ServiceSpec{
Type: serviceType,
Ports: servicePorts,
Selector: selector,
PublishNotReadyAddresses: true,
},
}
}

0 comments on commit 2f9c2ba

Please sign in to comment.