diff --git a/api/v1alpha1/ragengine_validation.go b/api/v1alpha1/ragengine_validation.go index e5ce6f284..9ffff50cf 100644 --- a/api/v1alpha1/ragengine_validation.go +++ b/api/v1alpha1/ragengine_validation.go @@ -6,8 +6,15 @@ package v1alpha1 import ( "context" "fmt" + "net/url" + "os" + "regexp" + "strings" + "github.com/kaito-project/kaito/pkg/utils" + "github.com/kaito-project/kaito/pkg/utils/consts" admissionregistrationv1 "k8s.io/api/admissionregistration/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/klog/v2" "knative.dev/pkg/apis" ) @@ -32,5 +39,88 @@ func (w *RAGEngine) validateCreate() (errs *apis.FieldError) { if w.Spec.InferenceService == nil { errs = errs.Also(apis.ErrGeneric("InferenceService must be specified", "")) } + errs = errs.Also(w.Spec.InferenceService.validateCreate()) + if w.Spec.Embedding == nil { + errs = errs.Also(apis.ErrGeneric("Embedding must be specified", "")) + return errs + } + if w.Spec.Embedding.Local == nil && w.Spec.Embedding.Remote == nil { + errs = errs.Also(apis.ErrGeneric("Either remote embedding or local embedding must be specified, not neither", "")) + } + if w.Spec.Embedding.Local != nil && w.Spec.Embedding.Remote != nil { + errs = errs.Also(apis.ErrGeneric("Either remote embedding or local embedding must be specified, but not both", "")) + } + errs = errs.Also(w.Spec.Compute.validateRAGCreate()) + if w.Spec.Embedding.Local != nil { + w.Spec.Embedding.Local.validateCreate().ViaField("embedding") + } + if w.Spec.Embedding.Remote != nil { + w.Spec.Embedding.Remote.validateCreate().ViaField("embedding") + } + + return errs +} + +func (r *ResourceSpec) validateRAGCreate() (errs *apis.FieldError) { + instanceType := string(r.InstanceType) + + skuHandler, err := utils.GetSKUHandler() + if err != nil { + errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get SKU handler: %v", err), "instanceType")) + return errs + } + gpuConfigs := skuHandler.GetGPUConfigs() + + if _, exists := gpuConfigs[instanceType]; !exists { + provider := os.Getenv("CLOUD_PROVIDER") + // Check for other instance types pattern matches if cloud provider is Azure + if provider != consts.AzureCloudName || (!strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX)) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, skuHandler.GetSupportedSKUs()), "instanceType")) + } + } + + // Validate labelSelector + if _, err := metav1.LabelSelectorAsMap(r.LabelSelector); err != nil { + errs = errs.Also(apis.ErrInvalidValue(err.Error(), "labelSelector")) + } + + return errs +} + +func (e *LocalEmbeddingSpec) validateCreate() (errs *apis.FieldError) { + if e.Image == "" && e.ModelID == "" { + errs = errs.Also(apis.ErrGeneric("Either image or modelID must be specified, not neither", "")) + } + if e.Image != "" && e.ModelID != "" { + errs = errs.Also(apis.ErrGeneric("Either image or modelID must be specified, but not both", "")) + } + if e.Image != "" { + re := regexp.MustCompile(`^(.+/[^:/]+):([^:/]+)$`) + if !re.MatchString(e.Image) { + errs = errs.Also(apis.ErrInvalidValue("Invalid image format, require full input image URL", "Image")) + } else { + // Executes if image is of correct format + err := utils.ExtractAndValidateRepoName(e.Image) + if err != nil { + errs = errs.Also(apis.ErrInvalidValue(err.Error(), "Image")) + } + } + } + return errs +} + +func (e *RemoteEmbeddingSpec) validateCreate() (errs *apis.FieldError) { + _, err := url.ParseRequestURI(e.URL) + if err != nil { + errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("URL input error: %v", err), "remote url")) + } + return errs +} + +func (e *InferenceServiceSpec) validateCreate() (errs *apis.FieldError) { + _, err := url.ParseRequestURI(e.URL) + if err != nil { + errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("URL input error: %v", err), "remote url")) + } return errs } diff --git a/api/v1alpha1/ragengine_validation_test.go b/api/v1alpha1/ragengine_validation_test.go new file mode 100644 index 000000000..37e101c64 --- /dev/null +++ b/api/v1alpha1/ragengine_validation_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package v1alpha1 + +import ( + "os" + "strings" + "testing" + + "github.com/kaito-project/kaito/pkg/utils/consts" +) + +func TestRAGEngineValidateCreate(t *testing.T) { + tests := []struct { + name string + ragEngine *RAGEngine + wantErr bool + errField string + }{ + { + name: "Both Local and Remote Embedding specified", + ragEngine: &RAGEngine{ + Spec: &RAGEngineSpec{ + Compute: &ResourceSpec{ + InstanceType: "Standard_NC12s_v3", + }, + InferenceService: &InferenceServiceSpec{URL: "http://example.com"}, + Embedding: &EmbeddingSpec{ + Local: &LocalEmbeddingSpec{ + ModelID: "BAAI/bge-small-en-v1.5", + }, + Remote: &RemoteEmbeddingSpec{URL: "http://remote-embedding.com"}, + }, + }, + }, + wantErr: true, + errField: "Either remote embedding or local embedding must be specified, but not both", + }, + { + name: "Embedding not specified", + ragEngine: &RAGEngine{ + Spec: &RAGEngineSpec{ + Compute: &ResourceSpec{ + InstanceType: "Standard_NC12s_v3", + }, + InferenceService: &InferenceServiceSpec{URL: "http://example.com"}, + }, + }, + wantErr: true, + errField: "Embedding must be specified", + }, + { + name: "None of Local and Remote Embedding specified", + ragEngine: &RAGEngine{ + Spec: &RAGEngineSpec{ + Compute: &ResourceSpec{ + InstanceType: "Standard_NC12s_v3", + }, + InferenceService: &InferenceServiceSpec{URL: "http://example.com"}, + Embedding: &EmbeddingSpec{}, + }, + }, + wantErr: true, + errField: "Either remote embedding or local embedding must be specified, not neither", + }, + { + name: "Only Local Embedding specified", + ragEngine: &RAGEngine{ + Spec: &RAGEngineSpec{ + Compute: &ResourceSpec{ + InstanceType: "Standard_NC12s_v3", + }, + InferenceService: &InferenceServiceSpec{URL: "http://example.com"}, + Embedding: &EmbeddingSpec{ + Local: &LocalEmbeddingSpec{ + ModelID: "BAAI/bge-small-en-v1.5", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "Only Remote Embedding specified", + ragEngine: &RAGEngine{ + Spec: &RAGEngineSpec{ + Compute: &ResourceSpec{ + InstanceType: "Standard_NC12s_v3", + }, + InferenceService: &InferenceServiceSpec{URL: "http://example.com"}, + Embedding: &EmbeddingSpec{ + Remote: &RemoteEmbeddingSpec{URL: "http://remote-embedding.com"}, + }, + }, + }, + wantErr: false, + }, + } + os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ragEngine.validateCreate() + hasErr := err != nil + + if hasErr != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", err, tt.wantErr) + } + + if hasErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, err.Error()) + } + }) + } +} + +func TestLocalEmbeddingValidateCreate(t *testing.T) { + tests := []struct { + name string + localEmbedding *LocalEmbeddingSpec + wantErr bool + errField string + }{ + { + name: "Neither Image nor ModelID specified", + localEmbedding: &LocalEmbeddingSpec{}, + wantErr: true, + errField: "Either image or modelID must be specified, not neither", + }, + { + name: "Both Image and ModelID specified", + localEmbedding: &LocalEmbeddingSpec{ + Image: "image-path", + ModelID: "model-id", + }, + wantErr: true, + errField: "Either image or modelID must be specified, but not both", + }, + { + name: "Invalid Image Format", + localEmbedding: &LocalEmbeddingSpec{ + Image: "invalid-image-format", + }, + wantErr: true, + errField: "Invalid image format", + }, + { + name: "Valid Image Specified", + localEmbedding: &LocalEmbeddingSpec{ + Image: "myrepo/myimage:tag", + }, + wantErr: false, + }, + { + name: "Valid ModelID Specified", + localEmbedding: &LocalEmbeddingSpec{ + ModelID: "valid-model-id", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.localEmbedding.validateCreate() + hasErr := err != nil + + if hasErr != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", err, tt.wantErr) + } + + if hasErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, err.Error()) + } + }) + } +} + +func TestRemoteEmbeddingValidateCreate(t *testing.T) { + tests := []struct { + name string + remoteEmbedding *RemoteEmbeddingSpec + wantErr bool + errField string + }{ + { + name: "Invalid URL Specified", + remoteEmbedding: &RemoteEmbeddingSpec{ + URL: "invalid-url", + }, + wantErr: true, + errField: "URL input error", + }, + { + name: "Valid URL Specified", + remoteEmbedding: &RemoteEmbeddingSpec{ + URL: "http://example.com", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.remoteEmbedding.validateCreate() + hasErr := err != nil + + if hasErr != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", err, tt.wantErr) + } + + if hasErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, err.Error()) + } + }) + } +} + +func TestInferenceServiceValidateCreate(t *testing.T) { + tests := []struct { + name string + inferenceService *InferenceServiceSpec + wantErr bool + errField string + }{ + { + name: "Invalid URL Specified", + inferenceService: &InferenceServiceSpec{ + URL: "invalid-url", + }, + wantErr: true, + errField: "URL input error", + }, + { + name: "Valid URL Specified", + inferenceService: &InferenceServiceSpec{ + URL: "http://example.com", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.inferenceService.validateCreate() + hasErr := err != nil + + if hasErr != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", err, tt.wantErr) + } + + if hasErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, err.Error()) + } + }) + } +}