diff --git a/api/v1alpha1/condition_types.go b/api/v1alpha1/condition_types.go
index d317dffdf..e9b08c352 100644
--- a/api/v1alpha1/condition_types.go
+++ b/api/v1alpha1/condition_types.go
@@ -19,6 +19,12 @@ const (
 	// WorkspaceConditionTypeInferenceStatus is the state when Inference service has been ready.
 	WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady")
 
+	// RAGEneineConditionTypeServiceStatus is the "ServiceReady" condition type, the state when service has been ready.
+	// NOTE(review): the identifier appears to misspell "Engine" and overlaps in purpose with
+	// RAGConditionTypeServiceStatus below; confirm which of the two callers are meant to use.
+	RAGEneineConditionTypeServiceStatus = ConditionType("ServiceReady")
+
+	// RAGConditionTypeServiceStatus is the "RAGEngineServiceReady" condition type, the state when
+	// the RAG Engine service has been ready.
+	RAGConditionTypeServiceStatus = ConditionType("RAGEngineServiceReady")
+
 	// WorkspaceConditionTypeTuningJobStatus is the state when the tuning job starts normally.
 	WorkspaceConditionTypeTuningJobStatus ConditionType = ConditionType("JobStarted")
 
@@ -32,4 +38,6 @@ const (
 	//For inference, the "True" condition means the inference service is ready to serve requests.
 	//For fine tuning, the "True" condition means the tuning job completes successfully.
 	WorkspaceConditionTypeSucceeded ConditionType = ConditionType("WorkspaceSucceeded")
+
+	// RAGEngineConditionTypeSucceeded is the "RAGEngineSucceeded" condition type; presumably the
+	// RAGEngine counterpart of WorkspaceConditionTypeSucceeded (overall outcome) — confirm against the controller.
+	RAGEngineConditionTypeSucceeded ConditionType = ConditionType("RAGEngineSucceeded")
 )
diff --git a/charts/kaito/ragengine/templates/deployment.yaml b/charts/kaito/ragengine/templates/deployment.yaml
index d1b13474c..54010dc37 100644
--- a/charts/kaito/ragengine/templates/deployment.yaml
+++ b/charts/kaito/ragengine/templates/deployment.yaml
@@ -41,6 +41,8 @@ spec:
               valueFrom:
                 fieldRef:
                   fieldPath: metadata.namespace
+            - name: CLOUD_PROVIDER
+              value: {{ .Values.cloudProviderName }}
           ports:
             - name: http-metrics
               containerPort: 8080
diff --git a/charts/kaito/ragengine/values.yaml b/charts/kaito/ragengine/values.yaml
index 1d0416dba..0baf44396 100644
--- a/charts/kaito/ragengine/values.yaml
+++ b/charts/kaito/ragengine/values.yaml
@@ -28,3 +28,5 @@ resources:
 nodeSelector: {}
 tolerations: []
 affinity: {}
+# Cloud provider name, passed to the container as the CLOUD_PROVIDER env var. Values can be "azure" or "aws".
+cloudProviderName: "azure"
diff --git a/pkg/ragengine/controllers/preset-rag.go b/pkg/ragengine/controllers/preset-rag.go
new file mode 100644
index 000000000..7a84fe4c9
--- /dev/null
+++ b/pkg/ragengine/controllers/preset-rag.go
@@ -0,0 +1,113 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+package controllers
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/kaito-project/kaito/pkg/utils"
+	"github.com/kaito-project/kaito/pkg/utils/consts"
+
+	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
+	"github.com/kaito-project/kaito/pkg/ragengine/manifests"
+	"github.com/kaito-project/kaito/pkg/utils/resources"
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/resource"
+	"k8s.io/apimachinery/pkg/util/intstr"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+const (
+	// ProbePath is the HTTP path used by both the liveness and readiness probes.
+	ProbePath = "/health"
+	// Port5000 is the container port exposed by the RAG service container.
+	Port5000 = int32(5000)
+)
+
+var (
+	// containerPorts lists the ports exposed by the RAG service container.
+	containerPorts = []corev1.ContainerPort{{
+		ContainerPort: Port5000,
+	},
+	}
+
+	// livenessProbe polls ProbePath on port 5000 every 10 seconds after a long initial delay.
+	livenessProbe = &corev1.Probe{
+		ProbeHandler: corev1.ProbeHandler{
+			HTTPGet: &corev1.HTTPGetAction{
+				Port: intstr.FromInt(5000),
+				Path: ProbePath,
+			},
+		},
+		InitialDelaySeconds: 600, // 10 minutes
+		PeriodSeconds:       10,
+	}
+
+	// readinessProbe polls ProbePath on port 5000 every 10 seconds after a 30 second delay.
+	readinessProbe = &corev1.Probe{
+		ProbeHandler: corev1.ProbeHandler{
+			HTTPGet: &corev1.HTTPGetAction{
+				Port: intstr.FromInt(5000),
+				Path: ProbePath,
+			},
+		},
+		InitialDelaySeconds: 30,
+		PeriodSeconds:       10,
+	}
+
+	// tolerations are attached to the RAG pod so it tolerates the NoSchedule taints
+	// keyed by the NVIDIA GPU capacity name and by the SKU/GPU key-value pair.
+	tolerations = []corev1.Toleration{
+		{
+			Effect:   corev1.TaintEffectNoSchedule,
+			Operator: corev1.TolerationOpExists,
+			Key:      resources.CapacityNvidiaGPU,
+		},
+		{
+			Effect:   corev1.TaintEffectNoSchedule,
+			Value:    consts.GPUString,
+			Key:      consts.SKUString,
+			Operator: corev1.TolerationOpEqual,
+		},
+	}
+)
+
+// CreatePresetRAG builds the Deployment for the given RAGEngine and creates it
+// through kubeClient.
+//
+// GPU requests and limits are only set when a local embedding model is
+// configured; in that case the GPU count is looked up from the instance type.
+// An "already exists" error from the create call is tolerated. On success the
+// generated Deployment object is returned; on failure the object is nil and the
+// error is returned (the GPU lookup error is wrapped so callers can still
+// inspect it with errors.Is / errors.As).
+func CreatePresetRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine, revisionNum string, kubeClient client.Client) (client.Object, error) {
+	var volumes []corev1.Volume
+	var volumeMounts []corev1.VolumeMount
+
+	// ConfigSHMVolume may return empty values; only attach what it actually produced.
+	shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*ragEngineObj.Spec.Compute.Count)
+	if shmVolume.Name != "" {
+		volumes = append(volumes, shmVolume)
+	}
+	if shmVolumeMount.Name != "" {
+		volumeMounts = append(volumeMounts, shmVolumeMount)
+	}
+
+	var resourceReq corev1.ResourceRequirements
+
+	if ragEngineObj.Spec.Embedding.Local != nil {
+		// The trailing "1" is presumably the fallback GPU count — TODO confirm against utils.GetSKUNumGPUs.
+		skuNumGPUs, err := utils.GetSKUNumGPUs(ctx, kubeClient, ragEngineObj.Status.WorkerNodes,
+			ragEngineObj.Spec.Compute.InstanceType, "1")
+		if err != nil {
+			return nil, fmt.Errorf("failed to get SKU num GPUs: %w", err)
+		}
+
+		resourceReq = corev1.ResourceRequirements{
+			Requests: corev1.ResourceList{
+				corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
+			},
+			Limits: corev1.ResourceList{
+				corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
+			},
+		}
+	}
+	commands := utils.ShellCmd("python3 main.py")
+	// TODO: provide this image
+	image := "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1"
+	imagePullSecretRefs := []corev1.LocalObjectReference{}
+
+	depObj := manifests.GenerateRAGDeploymentManifest(ctx, ragEngineObj, revisionNum, image, imagePullSecretRefs, *ragEngineObj.Spec.Compute.Count, commands,
+		containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
+
+	// A Deployment that already exists is not treated as a failure.
+	err := resources.CreateResource(ctx, depObj, kubeClient)
+	if client.IgnoreAlreadyExists(err) != nil {
+		return nil, err
+	}
+	return depObj, nil
+}
diff --git a/pkg/ragengine/controllers/preset-rag_test.go b/pkg/ragengine/controllers/preset-rag_test.go
new file mode 100644
index 000000000..eeee4a500
--- /dev/null
+++ b/pkg/ragengine/controllers/preset-rag_test.go
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+package controllers
+
+import (
+	"context"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/kaito-project/kaito/pkg/utils/consts"
+	"github.com/kaito-project/kaito/pkg/utils/test"
+	"github.com/stretchr/testify/mock"
+	appsv1 "k8s.io/api/apps/v1"
+)
+
+// TestCreatePresetRAG checks that CreatePresetRAG produces a Deployment whose
+// first container carries the expected command line and image.
+func TestCreatePresetRAG(t *testing.T) {
+	test.RegisterTestModel()
+
+	testcases := map[string]struct {
+		nodeCount      int
+		callMocks      func(c *test.MockClient)
+		expectedCmd    string
+		expectedGPUReq string
+		expectedImage  string
+		expectedVolume string
+	}{
+		"test-rag-model": {
+			nodeCount: 1,
+			callMocks: func(c *test.MockClient) {
+				c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
+			},
+			expectedCmd:   "/bin/sh -c python3 main.py",
+			expectedImage: "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1",
+		},
+	}
+
+	for k, tc := range testcases {
+		t.Run(k, func(t *testing.T) {
+			os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
+			mockClient := test.NewClient()
+			tc.callMocks(mockClient)
+
+			ragEngineObj := test.MockRAGEngineWithPreset
+			createdObject, err := CreatePresetRAG(context.TODO(), ragEngineObj, "1", mockClient)
+			// Fail with a readable message instead of panicking on a nil object below.
+			if err != nil {
+				t.Fatalf("%s: CreatePresetRAG returned an unexpected error: %v", k, err)
+			}
+
+			deployment, ok := createdObject.(*appsv1.Deployment)
+			if !ok {
+				t.Fatalf("%s: created object has type %T, expected *appsv1.Deployment", k, createdObject)
+			}
+			if len(deployment.Spec.Template.Spec.Containers) == 0 {
+				t.Fatalf("%s: created deployment has no containers", k)
+			}
+			container := deployment.Spec.Template.Spec.Containers[0]
+
+			workloadCmd := strings.Join(container.Command, " ")
+
+			if workloadCmd != tc.expectedCmd {
+				t.Errorf("%s: main cmdline is not expected, got %s, expected %s", k, workloadCmd, tc.expectedCmd)
+			}
+
+			image := container.Image
+
+			if image != tc.expectedImage {
+				t.Errorf("%s: image is not expected, got %s, expected %s", k, image, tc.expectedImage)
+			}
+		})
+	}
+}
diff --git a/pkg/ragengine/controllers/ragengine_controller.go b/pkg/ragengine/controllers/ragengine_controller.go
index 9a7036182..e58a965ff 100644
--- a/pkg/ragengine/controllers/ragengine_controller.go
+++ b/pkg/ragengine/controllers/ragengine_controller.go
@@ -115,11 +115,78 @@ func (c *RAGEngineReconciler) ensureFinalizer(ctx context.Context, ragEngineObj
 func (c 
*RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
 	err := c.applyRAGEngineResource(ctx, ragEngineObj)
 	if err != nil {
+		if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
+			"ragengineFailed", err.Error()); updateErr != nil {
+			klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
+			return reconcile.Result{}, updateErr
+		}
+		// if error is due to machine/nodeClaim instance types unavailability, stop reconcile.
+		// NOTE(review): a non-nil error is still returned here, and controller-runtime is
+		// documented to requeue on error regardless of Result — confirm this really stops reconciling.
+		if err.Error() == consts.ErrorInstanceTypesUnavailable {
+			return reconcile.Result{Requeue: false}, err
+		}
+		return reconcile.Result{}, err
+	}
+	if err = c.applyRAG(ctx, ragEngineObj); err != nil {
+		if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
+			"ragengineFailed", err.Error()); updateErr != nil {
+			klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
+			return reconcile.Result{}, updateErr
+		}
+		return reconcile.Result{}, err
+	}
+
+	if err = c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionTrue,
+		"ragengineSucceeded", "ragengine succeeds"); err != nil {
+		klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
 		return reconcile.Result{}, err
 	}
 	return reconcile.Result{}, nil
 }
 
+// applyRAG makes sure a Deployment exists for the RAGEngine and records the
+// outcome in the RAGConditionTypeServiceStatus condition.
+//
+// If a Deployment with the RAGEngine's name already exists it is left alone;
+// if it is not found, one is created via CreatePresetRAG and its status is
+// checked with a 10 minute timeout. Any error (including a lookup error other
+// than NotFound) sets the condition to False; success sets the same condition
+// to True, so a previously recorded failure is cleared once the service is up.
+// When the status update itself fails, that update error is returned instead.
+func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
+	var err error
+	// The closure only exists so the lookup/create sequence can bail out early
+	// while the shared err is still handled in one place below.
+	func() {
+		deployment := &appsv1.Deployment{}
+		revisionStr := ragEngineObj.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation]
+
+		err = resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment)
+		if err == nil {
+			klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj))
+			return
+		}
+		if !apierrors.IsNotFound(err) {
+			// Unexpected lookup failure: keep err so it is reported below.
+			return
+		}
+
+		// Need to create a new workload
+		var workloadObj client.Object
+		workloadObj, err = CreatePresetRAG(ctx, ragEngineObj, revisionStr, c.Client)
+		if err != nil {
+			return
+		}
+		err = resources.CheckResourceStatus(workloadObj, c.Client, time.Duration(10)*time.Minute)
+	}()
+
+	if err != nil {
+		if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGConditionTypeServiceStatus, metav1.ConditionFalse,
+			"RAGEngineServiceStatusFailed", err.Error()); updateErr != nil {
+			klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
+			return updateErr
+		}
+		return err
+	}
+
+	// Use the same condition type as the failure path above; writing success to a
+	// different type would leave a stale False condition behind.
+	if err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGConditionTypeServiceStatus, metav1.ConditionTrue,
+		"RAGEngineServiceSuccess", "Inference has been deployed successfully"); err != nil {
+		klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
+		return err
+	}
+
+	return nil
+}
+
 func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
 	klog.InfoS("deleteRAGEngine", "ragengine", klog.KObj(ragEngineObj))
 	err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeDeleting, metav1.ConditionTrue, "ragengineDeleted", "ragengine is being deleted")
diff --git a/pkg/ragengine/manifests/manifests.go b/pkg/ragengine/manifests/manifests.go
new file mode 100644
index 000000000..254de2e96
--- /dev/null
+++ b/pkg/ragengine/manifests/manifests.go
@@ -0,0 +1,167 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+package manifests
+
+import (
+	"context"
+
+	"k8s.io/apimachinery/pkg/util/intstr"
+
+	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
+	"github.com/samber/lo"
+	appsv1 "k8s.io/api/apps/v1"
+	corev1 "k8s.io/api/core/v1"
+	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+// controller is referenced by address to populate OwnerReference.Controller.
+var controller = true
+
+// GenerateRAGDeploymentManifest assembles (but does not create) the Deployment
+// for a RAGEngine.
+//
+// The Deployment is named after the RAGEngine, lives in its namespace, is owned
+// by it, and is annotated with revisionNum. Each entry of the RAGEngine's
+// compute label selector becomes a required node-affinity "In" expression. The
+// single container uses the supplied image, command, probes, ports, resources
+// and volume mounts, plus the environment produced by RAGSetEnv.
+func GenerateRAGDeploymentManifest(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine, revisionNum string, imageName string,
+	imagePullSecretRefs []corev1.LocalObjectReference, replicas int, commands []string, containerPorts []corev1.ContainerPort,
+	livenessProbe, readinessProbe *corev1.Probe, resourceRequirements corev1.ResourceRequirements,
+	tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount) *appsv1.Deployment {
+
+	nodeRequirements := make([]corev1.NodeSelectorRequirement, 0, len(ragEngineObj.Spec.Compute.LabelSelector.MatchLabels))
+	for key, value := range ragEngineObj.Spec.Compute.LabelSelector.MatchLabels {
+		nodeRequirements = append(nodeRequirements, corev1.NodeSelectorRequirement{
+			Key:      key,
+			Operator: corev1.NodeSelectorOpIn,
+			Values:   []string{value},
+		})
+	}
+
+	// The same label map is used for the Deployment selector and the pod template labels.
+	selector := map[string]string{
+		kaitov1alpha1.LabelRAGEngineName: ragEngineObj.Name,
+	}
+	labelselector := &v1.LabelSelector{
+		MatchLabels: selector,
+	}
+	initContainers := []corev1.Container{}
+
+	envs := RAGSetEnv(ragEngineObj)
+
+	return &appsv1.Deployment{
+		ObjectMeta: v1.ObjectMeta{
+			Name:      ragEngineObj.Name,
+			Namespace: ragEngineObj.Namespace,
+			OwnerReferences: []v1.OwnerReference{
+				{
+					APIVersion: kaitov1alpha1.GroupVersion.String(),
+					Kind:       "RAGEngine",
+					UID:        ragEngineObj.UID,
+					Name:       ragEngineObj.Name,
+					Controller: &controller,
+				},
+			},
+			Annotations: map[string]string{
+				kaitov1alpha1.RAGEngineRevisionAnnotation: revisionNum,
+			},
+		},
+		Spec: appsv1.DeploymentSpec{
+			Replicas: lo.ToPtr(int32(replicas)),
+			Strategy: appsv1.DeploymentStrategy{
+				Type: appsv1.RollingUpdateDeploymentStrategyType,
+				RollingUpdate: &appsv1.RollingUpdateDeployment{
+					MaxSurge: &intstr.IntOrString{
+						Type:   intstr.Int,
+						IntVal: 0,
+					},
+					MaxUnavailable: &intstr.IntOrString{
+						Type:   intstr.Int,
+						IntVal: 1,
+					},
+				}, // Configuration for rolling updates: allows no extra pods during the update and permits at most one unavailable pod at a time.
+			},
+			Selector: labelselector,
+			Template: corev1.PodTemplateSpec{
+				ObjectMeta: v1.ObjectMeta{
+					Labels: selector,
+				},
+				Spec: corev1.PodSpec{
+					ImagePullSecrets: imagePullSecretRefs,
+					Affinity: &corev1.Affinity{
+						NodeAffinity: &corev1.NodeAffinity{
+							RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
+								NodeSelectorTerms: []corev1.NodeSelectorTerm{
+									{
+										MatchExpressions: nodeRequirements,
+									},
+								},
+							},
+						},
+					},
+					InitContainers: initContainers,
+					Containers: []corev1.Container{
+						{
+							Name:           ragEngineObj.Name,
+							Image:          imageName,
+							Command:        commands,
+							Resources:      resourceRequirements,
+							LivenessProbe:  livenessProbe,
+							ReadinessProbe: readinessProbe,
+							Ports:          containerPorts,
+							VolumeMounts:   volumeMount,
+							Env:            envs,
+						},
+					},
+					Tolerations: tolerations,
+					Volumes:     volumes,
+				},
+			},
+		},
+	}
+}
+
+// RAGSetEnv derives the RAG service container's environment from the RAGEngine spec.
+//
+// It emits, in order: MODEL_ID and ACCESS_SECRET (local embedding only, each
+// only when non-empty), EMBEDDING_TYPE ("local", "remote", or empty when
+// neither is configured), VECTOR_DB_TYPE (currently always "faiss"),
+// INFERENCE_URL, and INFERENCE_ACCESS_SECRET when set. The two inference
+// variables are skipped when no InferenceService is configured, rather than
+// dereferencing a nil pointer.
+func RAGSetEnv(ragEngineObj *kaitov1alpha1.RAGEngine) []corev1.EnvVar {
+	var envs []corev1.EnvVar
+	var embeddingType string
+	if ragEngineObj.Spec.Embedding.Local != nil {
+		embeddingType = "local"
+		if ragEngineObj.Spec.Embedding.Local.ModelID != "" {
+			modelID := ragEngineObj.Spec.Embedding.Local.ModelID
+			modelIDEnv := corev1.EnvVar{
+				Name:  "MODEL_ID",
+				Value: modelID,
+			}
+			envs = append(envs, modelIDEnv)
+		}
+		if ragEngineObj.Spec.Embedding.Local.ModelAccessSecret != "" {
+			accessSecret := ragEngineObj.Spec.Embedding.Local.ModelAccessSecret
+			accessSecretEnv := corev1.EnvVar{
+				Name:  "ACCESS_SECRET",
+				Value: accessSecret,
+			}
+			envs = append(envs, accessSecretEnv)
+		}
+	} else if ragEngineObj.Spec.Embedding.Remote != nil {
+		embeddingType = "remote"
+		// TODO: Model ID Env
+	}
+	embeddingTypeEnv := corev1.EnvVar{
+		Name:  "EMBEDDING_TYPE",
+		Value: embeddingType,
+	}
+	envs = append(envs, embeddingTypeEnv)
+
+	storageEnv := corev1.EnvVar{
+		Name:  "VECTOR_DB_TYPE",
+		Value: "faiss", // TODO: get storage done
+	}
+	envs = append(envs, storageEnv)
+
+	// InferenceService is a pointer field and may be unset; guard before reading it.
+	if inferenceService := ragEngineObj.Spec.InferenceService; inferenceService != nil {
+		inferenceServiceURLEnv := corev1.EnvVar{
+			Name:  "INFERENCE_URL",
+			Value: inferenceService.URL,
+		}
+		envs = append(envs, inferenceServiceURLEnv)
+
+		if inferenceService.AccessSecret != "" {
+			accessSecretEnv := corev1.EnvVar{
+				Name:  "INFERENCE_ACCESS_SECRET",
+				Value: inferenceService.AccessSecret,
+			}
+			envs = append(envs, accessSecretEnv)
+		}
+	}
+	return envs
+}
diff --git a/pkg/ragengine/manifests/manifests_test.go b/pkg/ragengine/manifests/manifests_test.go
new file mode 100644
index 000000000..c25872bc2
--- /dev/null
+++ b/pkg/ragengine/manifests/manifests_test.go
@@ -0,0 +1,70 @@
+package manifests
+
+import (
+	"context"
+	"reflect"
+
+	"github.com/kaito-project/kaito/pkg/utils/test"
+
+	"testing"
+
+	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
+	v1 "k8s.io/api/core/v1"
+)
+
+// kvInNodeRequirement reports whether nodeReq contains an "In" requirement for
+// key whose first value is val. Requirements with no values never match.
+func kvInNodeRequirement(key, val string, nodeReq []v1.NodeSelectorRequirement) bool {
+	for _, each := range nodeReq {
+		if each.Key == key && len(each.Values) > 0 && each.Values[0] == val && each.Operator == v1.NodeSelectorOpIn {
+			return true
+		}
+	}
+	return false
+}
+
+func TestGenerateRAGDeploymentManifest(t *testing.T) {
+	t.Run("generate RAG deployment", func(t *testing.T) {
+
+		// Mocking the RAGEngine object for the test
+		ragEngine := test.MockRAGEngineWithPreset
+
+		// Calling the function to generate the deployment manifest
+		obj := GenerateRAGDeploymentManifest(context.TODO(), ragEngine, test.MockRAGEngineWithPresetHash,
+			"",                            // imageName
+			nil,                           // imagePullSecretRefs
+			*ragEngine.Spec.Compute.Count, // replicas
+			nil,                           // commands
+			nil,                           // containerPorts
+			nil,                           // livenessProbe
+			nil,                           // readinessProbe
+			v1.ResourceRequirements{},
+			nil, // tolerations
+			nil, 
// volumes
+			nil, // volumeMount
+		)
+
+		// Expected label selector for the deployment
+		appSelector := map[string]string{
+			kaitov1alpha1.LabelRAGEngineName: ragEngine.Name,
+		}
+
+		// Check if the deployment's selector is correct
+		if !reflect.DeepEqual(appSelector, obj.Spec.Selector.MatchLabels) {
+			t.Errorf("RAGEngine workload selector is wrong")
+		}
+
+		// Check if the template labels match the expected labels
+		if !reflect.DeepEqual(appSelector, obj.Spec.Template.ObjectMeta.Labels) {
+			t.Errorf("RAGEngine template label is wrong")
+		}
+
+		// Extract node selector requirements from the deployment manifest
+		// (the first and only node selector term carries all label requirements)
+		nodeReq := obj.Spec.Template.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions
+
+		// Validate if the node requirements match the RAGEngine's label selector
+		for key, value := range ragEngine.Spec.Compute.LabelSelector.MatchLabels {
+			if !kvInNodeRequirement(key, value, nodeReq) {
+				t.Errorf("Node affinity requirements are wrong for key %s and value %s", key, value)
+			}
+		}
+	})
+}
diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go
index d51c3958e..912cd2bc0 100644
--- a/pkg/utils/test/testUtils.go
+++ b/pkg/utils/test/testUtils.go
@@ -72,6 +72,30 @@ var (
 	}
 )
 
+var (
+	// MockRAGEngine is a RAGEngine fixture with a local embedding model and a compute
+	// section; it sets no InferenceService, unlike MockRAGEngineWithPreset.
+	MockRAGEngine = &v1alpha1.RAGEngine{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "testRAGEngine",
+			Namespace: "kaito",
+		},
+		Spec: &v1alpha1.RAGEngineSpec{
+			Compute: &v1alpha1.ResourceSpec{
+				Count:        &gpuNodeCount,
+				InstanceType: "Standard_NC12s_v3",
+				LabelSelector: &metav1.LabelSelector{
+					MatchLabels: map[string]string{
+						"ragengine.kaito.io/name": "testRAGEngine",
+					},
+				},
+			},
+			Embedding: &v1alpha1.EmbeddingSpec{
+				Local: &v1alpha1.LocalEmbeddingSpec{
+					ModelID: "BAAI/bge-small-en-v1.5",
+				},
+			},
+		},
+	}
+)
 var (
 	MockRAGEngineDistributedModel = &v1alpha1.RAGEngine{
 		ObjectMeta: metav1.ObjectMeta{
@@ -119,6 +143,36 @@ var (
 
 var MockWorkspaceWithPresetHash = "89ae127050ec264a5ce84db48ef7226574cdf1299e6bd27fe90b927e34cc8adb"
 
+var (
+	// MockRAGEngineWithPreset is a RAGEngine fixture with a local embedding model, a
+	// compute section, and an InferenceService URL.
+	MockRAGEngineWithPreset = &v1alpha1.RAGEngine{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "testRAGEngine",
+			Namespace: "kaito",
+		},
+		Spec: &v1alpha1.RAGEngineSpec{
+			Compute: &v1alpha1.ResourceSpec{
+				Count:        &gpuNodeCount,
+				InstanceType: "Standard_NC12s_v3",
+				LabelSelector: &metav1.LabelSelector{
+					MatchLabels: map[string]string{
+						"ragengine.kaito.io/name": "testRAGEngine",
+					},
+				},
+			},
+			Embedding: &v1alpha1.EmbeddingSpec{
+				Local: &v1alpha1.LocalEmbeddingSpec{
+					ModelID: "BAAI/bge-small-en-v1.5",
+				},
+			},
+			InferenceService: &v1alpha1.InferenceServiceSpec{
+				URL: "http://localhost:5000/chat",
+			},
+		},
+	}
+)
+
+// MockRAGEngineWithPresetHash is the revision string used together with MockRAGEngineWithPreset.
+// NOTE(review): the literal is 62 hex characters, two fewer than MockWorkspaceWithPresetHash —
+// confirm it was not truncated.
+var MockRAGEngineWithPresetHash = "14485768c1b67a529a71e3c87d9f2e6c1ed747534dea07e268e93475a5e21e"
+
 var (
 	MockWorkspaceWithDeleteOldCR = v1alpha1.Workspace{
 		ObjectMeta: metav1.ObjectMeta{