Skip to content

Commit

Permalink
RAG ensure service
Browse files Browse the repository at this point in the history
Signed-off-by: Bangqi Zhu <[email protected]>
  • Loading branch information
Bangqi Zhu committed Dec 23, 2024
1 parent 7da6586 commit b635b2f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 1 deletion.
49 changes: 49 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka
}
return reconcile.Result{}, err
}
if err := c.ensureServices(ctx, ragEngineObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragEngineFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update ragEngine status", "ragEngine", klog.KObj(ragEngineObj))
return reconcile.Result{}, updateErr
}
return reconcile.Result{}, err
}
if err = c.applyRAG(ctx, ragEngineObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragengineFailed", err.Error()); updateErr != nil {
Expand All @@ -144,6 +152,47 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka
return reconcile.Result{}, nil
}

func (c *RAGEngineReconciler) ensureServices(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine) error {
serviceType := corev1.ServiceTypeClusterIP
ragAnnotations := ragObj.GetAnnotations()

if len(ragAnnotations) != 0 {
val, found := ragAnnotations[kaitov1alpha1.AnnotationEnableLB]
if found && val == "True" {
serviceType = corev1.ServiceTypeLoadBalancer
}
}

// Ensure Service for index and query
// TODO: ServiceName currently does not accept customization for now

queryServiceName := ragObj.Name

if err := c.ensureService(ctx, ragObj, queryServiceName, serviceType); err != nil {
return err
}

return nil
}

func (c *RAGEngineReconciler) ensureService(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType) error {
existingSVC := &corev1.Service{}
err := resources.GetResource(ctx, serviceName, ragObj.Namespace, c.Client, existingSVC)
if err != nil {
if !apierrors.IsNotFound(err) {
return err
}
} else {
return nil
}
serviceObj := manifests.GenerateRAGServiceManifest(ctx, ragObj, serviceName, serviceType)
if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil {
return err
}

return nil
}

func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
var err error
func() {
Expand Down
81 changes: 81 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,84 @@ func TestApplyRAG(t *testing.T) {
})
}
}

func TestEnsureService(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
callMocks func(c *test.MockClient)
expectedError error
ragengine v1alpha1.RAGEngine
verifyCalls func(c *test.MockClient)
}{

"Existing service is found for RAGEngine": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Service creation fails": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(errors.New("cannot create service"))
},
expectedError: errors.New("cannot create service"),
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 4)
c.AssertNumberOfCalls(t, "Get", 4)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Successfully creates a new service": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
ragengine: *test.MockRAGEngineWithPreset,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 4)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.ensureServices(ctx, &tc.ragengine)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
39 changes: 38 additions & 1 deletion pkg/ragengine/manifests/manifests.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar {
envs = append(envs, stoageEnv)
inferenceServiceURL := ragEngineObj.Spec.InferenceService.URL
inferenceServiceURLEnv := corev1.EnvVar{
Name: "INFERENCE_URL",
Name: "LLM_INFERENCE_URL",
Value: inferenceServiceURL,
}
envs = append(envs, inferenceServiceURLEnv)
Expand All @@ -165,3 +165,40 @@ func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar {
}
return envs
}

func GenerateRAGServiceManifest(ctx context.Context, ragObj *kaitov1alpha1.RAGEngine, serviceName string, serviceType corev1.ServiceType) *corev1.Service {
selector := map[string]string{
kaitov1alpha1.LabelRAGEngineName: ragObj.Name,
}

servicePorts := []corev1.ServicePort{
{
Name: "http",
Protocol: corev1.ProtocolTCP,
Port: 80,
TargetPort: intstr.FromInt32(5000),
},
}

return &corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: ragObj.Namespace,
OwnerReferences: []v1.OwnerReference{
{
APIVersion: kaitov1alpha1.GroupVersion.String(),
Kind: "RAGEngine",
UID: ragObj.UID,
Name: ragObj.Name,
Controller: &controller,
},
},
},
Spec: corev1.ServiceSpec{
Type: serviceType,
Ports: servicePorts,
Selector: selector,
PublishNotReadyAddresses: true,
},
}
}

0 comments on commit b635b2f

Please sign in to comment.