Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: RAG engine deployment creation #660

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/v1alpha1/condition_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ const (
// WorkspaceConditionTypeInferenceStatus is the state when Inference service has been ready.
WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady")

// RAGEneineConditionTypeServiceStatus is the state when service has been ready.
RAGEneineConditionTypeServiceStatus = ConditionType("ServiceReady")

// RAGConditionTypeServiceStatus is the state when RAG Engine service has been ready.
RAGConditionTypeServiceStatus = ConditionType("RAGEngineServiceReady")

// WorkspaceConditionTypeTuningJobStatus is the state when the tuning job starts normally.
WorkspaceConditionTypeTuningJobStatus ConditionType = ConditionType("JobStarted")

Expand All @@ -32,4 +38,6 @@ const (
//For inference, the "True" condition means the inference service is ready to serve requests.
//For fine tuning, the "True" condition means the tuning job completes successfully.
WorkspaceConditionTypeSucceeded ConditionType = ConditionType("WorkspaceSucceeded")

RAGEngineConditionTypeSucceeded ConditionType = ConditionType("RAGEngineSucceeded")
)
2 changes: 2 additions & 0 deletions charts/kaito/ragengine/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ spec:
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CLOUD_PROVIDER
value: {{ .Values.cloudProviderName }}
ports:
- name: http-metrics
containerPort: 8080
Expand Down
2 changes: 2 additions & 0 deletions charts/kaito/ragengine/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ resources:
nodeSelector: {}
tolerations: []
affinity: {}
# Values can be "azure" or "aws"
cloudProviderName: "azure"
113 changes: 113 additions & 0 deletions pkg/ragengine/controllers/preset-rag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package controllers

import (
"context"
"fmt"

"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"

kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/ragengine/manifests"
"github.com/kaito-project/kaito/pkg/utils/resources"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/intstr"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
	// ProbePath is the HTTP path queried by the liveness and readiness probes.
	ProbePath = "/health"
	// Port5000 is the port declared on the container and targeted by the probes.
	Port5000 int32 = 5000
)

// Package-level defaults handed to the deployment manifest generator by
// CreatePresetRAG.
var (
	// containerPorts declares the single container port, Port5000.
	containerPorts = []corev1.ContainerPort{
		{ContainerPort: Port5000},
	}

	// livenessProbe polls ProbePath every 10 seconds with an initial delay of
	// 600 seconds (10 minutes).
	livenessProbe = newHealthProbe(600)

	// readinessProbe polls ProbePath every 10 seconds with an initial delay of
	// 30 seconds.
	readinessProbe = newHealthProbe(30)

	// tolerations covers two NoSchedule taints: any taint whose key is
	// resources.CapacityNvidiaGPU (value ignored), and the taint
	// consts.SKUString=consts.GPUString.
	tolerations = []corev1.Toleration{
		{
			Key:      resources.CapacityNvidiaGPU,
			Operator: corev1.TolerationOpExists,
			Effect:   corev1.TaintEffectNoSchedule,
		},
		{
			Key:      consts.SKUString,
			Operator: corev1.TolerationOpEqual,
			Value:    consts.GPUString,
			Effect:   corev1.TaintEffectNoSchedule,
		},
	}
)

// newHealthProbe builds an HTTP GET probe against ProbePath on Port5000 that
// fires every 10 seconds after the given initial delay.
func newHealthProbe(initialDelaySeconds int32) *corev1.Probe {
	return &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			HTTPGet: &corev1.HTTPGetAction{
				Path: ProbePath,
				Port: intstr.FromInt(int(Port5000)),
			},
		},
		InitialDelaySeconds: initialDelaySeconds,
		PeriodSeconds:       10,
	}
}

// CreatePresetRAG builds the Deployment manifest for the given RAGEngine and
// creates it through kubeClient.
//
// The replica count comes from ragEngineObj.Spec.Compute.Count. When a local
// embedding model is configured (Spec.Embedding.Local != nil), the container
// requests and limits the GPU count reported by utils.GetSKUNumGPUs for the
// engine's instance type; otherwise no resource requirements are set.
//
// It returns the generated Deployment object. An "already exists" error from
// the API server is not treated as a failure; in that case the locally
// generated object is still returned. Any other failure (missing count, SKU
// lookup, unparsable GPU quantity, create error) is returned as an error with
// a nil object.
func CreatePresetRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine, revisionNum string, kubeClient client.Client) (client.Object, error) {
	// Count is a pointer; fail with an error rather than panicking on a nil dereference.
	if ragEngineObj.Spec.Compute.Count == nil {
		return nil, fmt.Errorf("ragengine %s/%s has no compute count set", ragEngineObj.Namespace, ragEngineObj.Name)
	}
	replicas := *ragEngineObj.Spec.Compute.Count

	var volumes []corev1.Volume
	var volumeMounts []corev1.VolumeMount

	// ConfigSHMVolume may return empty values; only non-empty ones are attached.
	shmVolume, shmVolumeMount := utils.ConfigSHMVolume(replicas)
	if shmVolume.Name != "" {
		volumes = append(volumes, shmVolume)
	}
	if shmVolumeMount.Name != "" {
		volumeMounts = append(volumeMounts, shmVolumeMount)
	}

	var resourceReq corev1.ResourceRequirements
	if ragEngineObj.Spec.Embedding.Local != nil {
		// NOTE(review): the trailing "1" looks like a fallback GPU count — confirm against utils.GetSKUNumGPUs.
		skuNumGPUs, err := utils.GetSKUNumGPUs(ctx, kubeClient, ragEngineObj.Status.WorkerNodes,
			ragEngineObj.Spec.Compute.InstanceType, "1")
		if err != nil {
			return nil, fmt.Errorf("failed to get SKU num GPUs: %w", err)
		}

		// ParseQuantity instead of MustParse: a malformed value must surface
		// as an error, not as a panic inside the controller.
		gpuQuantity, err := resource.ParseQuantity(skuNumGPUs)
		if err != nil {
			return nil, fmt.Errorf("failed to parse GPU count %q: %w", skuNumGPUs, err)
		}

		gpuResource := corev1.ResourceName(resources.CapacityNvidiaGPU)
		resourceReq = corev1.ResourceRequirements{
			Requests: corev1.ResourceList{gpuResource: gpuQuantity},
			Limits:   corev1.ResourceList{gpuResource: gpuQuantity.DeepCopy()},
		}
	}

	commands := utils.ShellCmd("python3 main.py")
	// TODO: provide this image
	image := "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1"
	imagePullSecretRefs := []corev1.LocalObjectReference{}

	depObj := manifests.GenerateRAGDeploymentManifest(ctx, ragEngineObj, revisionNum, image, imagePullSecretRefs, replicas, commands,
		containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)

	// An already-existing Deployment is not an error for the caller.
	if err := resources.CreateResource(ctx, depObj, kubeClient); client.IgnoreAlreadyExists(err) != nil {
		return nil, err
	}
	return depObj, nil
}
60 changes: 60 additions & 0 deletions pkg/ragengine/controllers/preset-rag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package controllers

import (
"context"
"os"
"strings"
"testing"

"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/kaito-project/kaito/pkg/utils/test"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
)

// TestCreatePresetRAG verifies that CreatePresetRAG produces a Deployment
// whose first container carries the expected command line and image.
func TestCreatePresetRAG(t *testing.T) {
	test.RegisterTestModel()

	testcases := map[string]struct {
		nodeCount      int
		callMocks      func(c *test.MockClient)
		expectedCmd    string
		expectedGPUReq string
		expectedImage  string
		expectedVolume string
	}{
		"test-rag-model": {
			nodeCount: 1,
			callMocks: func(c *test.MockClient) {
				c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
			},
			expectedCmd:   "/bin/sh -c python3 main.py",
			expectedImage: "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1",
		},
	}

	for k, tc := range testcases {
		t.Run(k, func(t *testing.T) {
			if err := os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName); err != nil {
				t.Fatalf("%s: failed to set CLOUD_PROVIDER: %v", k, err)
			}
			mockClient := test.NewClient()
			tc.callMocks(mockClient)

			ragEngineObj := test.MockRAGEngineWithPreset
			createdObject, err := CreatePresetRAG(context.TODO(), ragEngineObj, "1", mockClient)
			// Fail here with a clear message; otherwise a nil object would
			// panic in the type assertion below and hide the real cause.
			if err != nil {
				t.Fatalf("%s: CreatePresetRAG returned unexpected error: %v", k, err)
			}

			deployment, ok := createdObject.(*appsv1.Deployment)
			if !ok {
				t.Fatalf("%s: expected *appsv1.Deployment, got %T", k, createdObject)
			}
			containers := deployment.Spec.Template.Spec.Containers
			if len(containers) == 0 {
				t.Fatalf("%s: deployment has no containers", k)
			}

			workloadCmd := strings.Join(containers[0].Command, " ")
			if workloadCmd != tc.expectedCmd {
				t.Errorf("%s: main cmdline is not expected, got %s, expected %s", k, workloadCmd, tc.expectedCmd)
			}

			image := containers[0].Image
			if image != tc.expectedImage {
				t.Errorf("%s: image is not expected, got %s, expected %s", k, image, tc.expectedImage)
			}
		})
	}
}
67 changes: 67 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,78 @@ func (c *RAGEngineReconciler) ensureFinalizer(ctx context.Context, ragEngineObj
// addRAGEngine drives a RAGEngine towards its desired state in two steps:
// applyRAGEngineResource followed by applyRAG. If either step fails, the
// RAGEngineSucceeded condition is set to False with the error text and the
// error is returned; if that status update itself fails, the update error is
// returned instead. When both steps succeed the condition is set to True.
func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
	if resourceErr := c.applyRAGEngineResource(ctx, ragEngineObj); resourceErr != nil {
		statusErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded,
			metav1.ConditionFalse, "ragengineFailed", resourceErr.Error())
		if statusErr != nil {
			klog.ErrorS(statusErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
			return reconcile.Result{}, statusErr
		}
		// if error is due to machine/nodeClaim instance types unavailability, stop reconcile.
		// NOTE(review): this branch returns the same non-nil error as the
		// generic path below, so the request is presumably still retried by
		// the controller runtime — confirm whether that is intended.
		if resourceErr.Error() == consts.ErrorInstanceTypesUnavailable {
			return reconcile.Result{Requeue: false}, resourceErr
		}
		return reconcile.Result{}, resourceErr
	}

	if ragErr := c.applyRAG(ctx, ragEngineObj); ragErr != nil {
		statusErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded,
			metav1.ConditionFalse, "ragengineFailed", ragErr.Error())
		if statusErr != nil {
			klog.ErrorS(statusErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
			return reconcile.Result{}, statusErr
		}
		return reconcile.Result{}, ragErr
	}

	succeededErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded,
		metav1.ConditionTrue, "ragengineSucceeded", "ragengine succeeds")
	if succeededErr != nil {
		klog.ErrorS(succeededErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
		return reconcile.Result{}, succeededErr
	}
	return reconcile.Result{}, nil
}

// applyRAG makes sure a Deployment exists for the RAGEngine and records the
// outcome in the RAGConditionTypeServiceStatus condition.
//
// If a Deployment named after the RAGEngine already exists in its namespace,
// nothing is created. If it is not found, one is created via CreatePresetRAG
// and resources.CheckResourceStatus is called on it with a 10 minute timeout.
// Any other lookup error is treated as a failure.
//
// On failure the condition is set to False with the error text and the
// original error is returned (or the status-update error, if that update
// fails). On success the same condition is set to True. Both outcomes write
// the same condition type so that a previously recorded failure is cleared
// once the workload is deployed.
func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
	ensureWorkload := func() error {
		deployment := &appsv1.Deployment{}
		err := resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment)
		if err == nil {
			klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj))
			return nil
		}
		if !apierrors.IsNotFound(err) {
			return err
		}

		// Need to create a new workload
		revisionStr := ragEngineObj.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation]
		var workloadObj client.Object
		workloadObj, err = CreatePresetRAG(ctx, ragEngineObj, revisionStr, c.Client)
		if err != nil {
			return err
		}
		return resources.CheckResourceStatus(workloadObj, c.Client, 10*time.Minute)
	}

	if err := ensureWorkload(); err != nil {
		if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGConditionTypeServiceStatus, metav1.ConditionFalse,
			"RAGEngineServiceStatusFailed", err.Error()); updateErr != nil {
			klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
			return updateErr
		}
		return err
	}

	if err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGConditionTypeServiceStatus, metav1.ConditionTrue,
		"RAGEngineServiceSuccess", "Inference has been deployed successfully"); err != nil {
		klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
		return err
	}

	return nil
}

func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
klog.InfoS("deleteRAGEngine", "ragengine", klog.KObj(ragEngineObj))
err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeDeleting, metav1.ConditionTrue, "ragengineDeleted", "ragengine is being deleted")
Expand Down
Loading
Loading