From da63f9352a4159cca70f40781fca23c84045cd28 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 10:37:08 -0700 Subject: [PATCH 1/7] feat: spec level validation --- api/v1alpha1/workspace_types.go | 4 ++-- api/v1alpha1/workspace_validation.go | 28 +++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 2f966f647..e451e1066 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -13,7 +13,7 @@ const ( ModelImageAccessModePrivate ModelImageAccessMode = "private" ) -// ResourceSpec desicribes the resource requirement of running the workload. +// ResourceSpec describes the resource requirement of running the workload. // If the number of nodes in the cluster that meet the InstanceType and // LabelSelector requirements is small than the Count, controller // will provision new nodes before deploying the workload. @@ -51,7 +51,7 @@ type PresetMeta struct { // AccessMode specifies whether the containerized model image is accessible via public registry // or private registry. This field defaults to "public" if not specified. // If this field is "private", user needs to provide the private image information in PresetOptions. - // +bebuilder:default:="public" + // +kubebuilder:default:="public" // +optional AccessMode ModelImageAccessMode `json:"accessMode,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 16576f684..732526447 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -35,6 +35,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { if base == nil { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( + w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) @@ -42,6 +43,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( + w.validateUpdate(old).ViaField("spec"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), ) @@ -49,6 +51,15 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { return errs } +func (w *Workspace) validateCreate() (errs *apis.FieldError) { + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + tuningSpecified := w.Tuning.Input != nil + if inferenceSpecified != tuningSpecified { + return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + } + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -96,6 +107,21 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { @@ -151,7 +177,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro if !reflect.DeepEqual(i.Preset, old.Preset) { errs = errs.Also(apis.ErrGeneric("field is immutable", "preset")) } - //inference.template can be changed, but cannot be unset. + // inference.template can be changed, but cannot be set/unset. if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) { errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template")) } From a4f45e6ee20222d2d43edc6f904f2a43c16b3883 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:01:25 -0700 Subject: [PATCH 2/7] feat: Added validation checks for TuningSpec, DataSource, DataDestination --- api/v1alpha1/workspace_types.go | 2 +- api/v1alpha1/workspace_validation.go | 138 +++++++++++++++++++++++---- 2 files changed, 122 insertions(+), 18 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index e451e1066..71e9f829c 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -106,7 +106,7 @@ type DataSource struct { // URLs specifies the links to the public data sources. E.g., files in a public github repository. // +optional URLs []string `json:"urls,omitempty"` - // The directory in the hsot that contains the data. + // The directory in the host that contains the data. // +optional HostPath string `json:"hostPath,omitempty"` // The name of the image that contains the source data. The assumption is that the source data locates in the diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 732526447..81d9353b4 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,6 +37,9 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { errs = errs.Also( w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), + w.Tuning.validateCreate().ViaField("tuning"), + w.Tuning.Input.validateCreate().ViaField("input"), + w.Tuning.Output.validateCreate().ViaField("output"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) } else { @@ -44,8 +47,11 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), + w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), + w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), + w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), + w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) } return errs @@ -55,11 +61,124 @@ func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil if inferenceSpecified != tuningSpecified { - return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + +func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } + // Currently require a preset to specified, in future we can consider defining a template + if r.Preset == nil { + errs = errs.Also(apis.ErrMissingField("Preset")) + } + // TODO: We have to register training plugins and check if it preset exists in plugins here + methodLowerCase := strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.Input, r.Input) { + errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + } + if !reflect.DeepEqual(old.Output, r.Output) { + errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + } + if !reflect.DeepEqual(old.Preset, r.Preset) { + errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) + } + // We will have to consider supporting tuning method and config fields changing + methodLowerCase := strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *DataSource) validateCreate() (errs *apis.FieldError) { + sourcesSpecified := 0 + if len(r.URLs) > 0 { + sourcesSpecified++ + } + if r.HostPath != "" { + sourcesSpecified++ + } + if r.Image != "" { + sourcesSpecified++ + } + + // Ensure exactly one of URLs, HostPath, or Image is specified + if sourcesSpecified != 1 { + errs = errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image")) + } + + return errs +} + +func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.URLs, r.URLs) { + errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) + } + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + // TODO: Ensure ImageSecrets can be changed + return errs +} + +func (r *DataDestination) validateCreate() (errs *apis.FieldError) { + destinationsSpecified := 0 + if r.HostPath != "" { + destinationsSpecified++ + } + if r.Image != "" { + destinationsSpecified++ + } + + // If no destination is specified, return an error + if destinationsSpecified == 0 { + errs = errs.Also(apis.ErrMissingField("At least one of HostPath or Image must be specified")) + } + return errs +} + +func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + // TODO: Ensure ImageSecrets can be changed + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -107,21 +226,6 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } -func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := w.Tuning.Input != nil - - // inference/tuning can be changed, but cannot be set/unset. - if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) - } - return errs -} - func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { From a9bbe7ad0e3867c29c1648f60733aacad29765fa Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:08:46 -0700 Subject: [PATCH 3/7] fix: prevent toggling --- api/v1alpha1/workspace_validation.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 81d9353b4..39fbf87b7 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -73,10 +73,12 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { // Check tuning specified oldTuningSpecified := old.Tuning.Input != nil tuningSpecified := w.Tuning.Input != nil + if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) + } - // inference/tuning can be changed, but cannot be set/unset. - if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs } From d73ef65ec6846dd4a055062935d9a788a05397ce Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 18:06:49 -0700 Subject: [PATCH 4/7] fix: validation fixes --- api/v1alpha1/workspace_validation.go | 47 ++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 39fbf87b7..9034deae3 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/azure/kaito/pkg/utils/plugin" @@ -60,7 +61,10 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil - if inferenceSpecified != tuningSpecified { + if !inferenceSpecified && !tuningSpecified { + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) + } + if inferenceSpecified && tuningSpecified { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs @@ -93,8 +97,9 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { errs = errs.Also(apis.ErrMissingField("Preset")) + } else if presetName := string(r.Preset.Name); !isValidPreset(presetName) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName")) } - // TODO: We have to register training plugins and check if it preset exists in plugins here methodLowerCase := strings.ToLower(string(r.Method)) if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) @@ -112,11 +117,11 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) } - // We will have to consider supporting tuning method and config fields changing - methodLowerCase := strings.ToLower(string(r.Method)) - if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { - errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + oldMethod, newMethod := strings.ToLower(string(old.Method)), strings.ToLower(string(r.Method)) + if !reflect.DeepEqual(oldMethod, newMethod) { + errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method")) } + // Consider supporting config fields changing return errs } @@ -141,7 +146,15 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { } func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.URLs, r.URLs) { + oldURLs := make([]string, len(old.URLs)) + copy(oldURLs, old.URLs) + sort.Strings(old.URLs) + + newURLs := make([]string, len(r.URLs)) + copy(newURLs, r.URLs) + sort.Strings(r.URLs) + + if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) } if old.HostPath != r.HostPath { @@ -150,7 +163,18 @@ func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // TODO: Ensure ImageSecrets can be changed + + oldSecrets := make([]string, len(old.ImagePullSecrets)) + copy(oldSecrets, old.ImagePullSecrets) + sort.Strings(oldSecrets) + + newSecrets := make([]string, len(r.ImagePullSecrets)) + copy(newSecrets, r.ImagePullSecrets) + sort.Strings(newSecrets) + + if !reflect.DeepEqual(oldSecrets, newSecrets) { + errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets")) + } return errs } @@ -177,7 +201,10 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // TODO: Ensure ImageSecrets can be changed + + if old.ImagePushSecret != r.ImagePushSecret { + errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret")) + } return errs } @@ -263,7 +290,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { presetName := string(i.Preset.Name) // Validate preset name if !isValidPreset(presetName) { - errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName")) } // Validate private preset has private image specified if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" && From 3fa0e46f1096bd6ed4cb309c58a365b8166fe18c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 11:50:00 -0700 Subject: [PATCH 5/7] feat: Add UTs for workspace validation --- api/v1alpha1/workspace_validation.go | 4 +- api/v1alpha1/workspace_validation_test.go | 615 ++++++++++++++++++++++ 2 files changed, 617 insertions(+), 2 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 9034deae3..5a8269353 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -148,11 +148,11 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) - sort.Strings(old.URLs) + sort.Strings(oldURLs) newURLs := make([]string, len(r.URLs)) copy(newURLs, r.URLs) - sort.Strings(r.URLs) + sort.Strings(newURLs) if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 0a3fa2de1..6c2f1a650 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -488,6 +488,621 @@ func TestInferenceSpecValidateUpdate(t *testing.T) { } } +func TestWorkspaceValidateCreate(t *testing.T) { + tests := []struct { + name string + workspace *Workspace + wantErr bool + errField string + }{ + { + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{}, + Tuning: TuningSpec{}, + }, + wantErr: true, + errField: "neither", + }, + { + name: "Both Inference and Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: true, + errField: "both", + }, + { + name: "Only Inference specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + wantErr: false, + errField: "", + }, + { + name: "Only Tuning specified", + workspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: false, + errField: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.workspace.validateCreate() + if (errs != nil) != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + if errs != nil && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain field %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestWorkspaceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldWorkspace *Workspace + newWorkspace *Workspace + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "Inference toggled on", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Inference toggled off", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Tuning toggled on", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "Tuning toggled off", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "No toggling", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newWorkspace.validateUpdate(tt.oldWorkspace) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateCreate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + tuningSpec *TuningSpec + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "All fields valid", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: false, + errFields: nil, + }, + { + name: "Missing Input", + tuningSpec: &TuningSpec{ + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Input"}, + }, + { + name: "Missing Output", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Output"}, + }, + { + name: "Missing Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Preset"}, + }, + { + name: "Invalid Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"presetName"}, + }, + { + name: "Invalid Method", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: "invalid-method", + }, + wantErr: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.tuningSpec.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() errors = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateCreate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateUpdate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + oldTuning *TuningSpec + newTuning *TuningSpec + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + expectErrs: false, + }, + { + name: "Input changed", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input2"}, + }, + expectErrs: true, + errFields: []string{"Input"}, + }, + { + name: "Output changed", + oldTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path1"}, + }, + newTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path2"}, + }, + expectErrs: true, + errFields: []string{"Output"}, + }, + { + name: "Preset changed", + oldTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + }, + newTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + }, + expectErrs: true, + errFields: []string{"Preset"}, + }, + { + name: "Method changed", + oldTuning: &TuningSpec{ + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Method: TuningMethodQLora, + }, + expectErrs: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newTuning.validateUpdate(tt.oldTuning) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataSourceValidateCreate(t *testing.T) { + tests := []struct { + name string + dataSource *DataSource + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "URLs specified only", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + }, + wantErr: false, + }, + { + name: "HostPath specified only", + dataSource: &DataSource{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataSource: &DataSource{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "None specified", + dataSource: &DataSource{}, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "URLs and HostPath specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "All fields specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataSource.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataSourceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldSource *DataSource + newSource *DataSource + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldSource: &DataSource{ + URLs: []string{"http://example.com/data1", "http://example.com/data2"}, + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret1", "secret2"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter + }, + wantErr: false, + }, + { + name: "URLs changed", + oldSource: &DataSource{ + URLs: []string{"http://example.com/old"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/new"}, + }, + wantErr: true, + errFields: []string{"URLs"}, + }, + { + name: "HostPath changed", + oldSource: &DataSource{ + HostPath: "/old/path", + }, + newSource: &DataSource{ + HostPath: "/new/path", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldSource: &DataSource{ + Image: "old-image:latest", + }, + newSource: &DataSource{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePullSecrets changed", + oldSource: &DataSource{ + ImagePullSecrets: []string{"old-secret"}, + }, + newSource: &DataSource{ + ImagePullSecrets: []string{"new-secret"}, + }, + wantErr: true, + errFields: []string{"ImagePullSecrets"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newSource.validateUpdate(tt.oldSource) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataDestinationValidateCreate(t *testing.T) { + tests := []struct { + name string + dataDestination *DataDestination + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "No fields specified", + dataDestination: &DataDestination{}, + wantErr: true, + errField: "At least one of HostPath or Image must be specified", + }, + { + name: "HostPath specified only", + dataDestination: &DataDestination{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataDestination: &DataDestination{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "Both fields specified", + dataDestination: &DataDestination{ + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataDestination.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataDestinationValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldDest *DataDestination + newDest *DataDestination + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + wantErr: false, + }, + { + name: "HostPath changed", + oldDest: &DataDestination{ + HostPath: "/data/old", + }, + newDest: &DataDestination{ + HostPath: "/data/new", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldDest: &DataDestination{ + Image: "old-image:latest", + }, + newDest: &DataDestination{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePushSecret changed", + oldDest: &DataDestination{ + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + ImagePushSecret: "new-secret", + }, + wantErr: true, + errFields: []string{"ImagePushSecret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newDest.validateUpdate(tt.oldDest) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + func TestGetSupportedSKUs(t *testing.T) { tests := []struct { name string From 392ff401d4537ede687bac8393f8195abb80fa50 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:07:09 -0700 Subject: [PATCH 6/7] fix: Update CRD to use pointers --- api/v1alpha1/workspace_types.go | 8 +- api/v1alpha1/workspace_validation.go | 55 ++++--- api/v1alpha1/workspace_validation_test.go | 69 ++++---- api/v1alpha1/zz_generated.deepcopy.go | 12 +- .../workspace/crds/kaito.sh_workspaces.yaml | 147 +++++++++++++++++- config/crd/bases/kaito.sh_workspaces.yaml | 11 +- pkg/utils/testUtils.go | 6 +- test/e2e/preset_test.go | 19 ++- test/e2e/utils/utils.go | 30 ++-- 9 files changed, 252 insertions(+), 105 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 71e9f829c..4484b8250 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -150,9 +150,9 @@ type TuningSpec struct { // +optional Config string `json:"config,omitempty"` // Input describes the input used by the tuning method. - Input *DataSource `json:"input,omitempty"` + Input *DataSource `json:"input"` // Output specified where to store the tuning output. - Output *DataDestination `json:"output,omitempty"` + Output *DataDestination `json:"output"` } // WorkspaceStatus defines the observed state of Workspace @@ -181,8 +181,8 @@ type Workspace struct { metav1.ObjectMeta `json:"metadata,omitempty"` Resource ResourceSpec `json:"resource,omitempty"` - Inference InferenceSpec `json:"inference,omitempty"` - Tuning TuningSpec `json:"tuning,omitempty"` + Inference *InferenceSpec `json:"inference,omitempty"` + Tuning *TuningSpec `json:"tuning,omitempty"` Status WorkspaceStatus `json:"status,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 5a8269353..b135f5886 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,51 +37,48 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( w.validateCreate().ViaField("spec"), - w.Inference.validateCreate().ViaField("inference"), - w.Tuning.validateCreate().ViaField("tuning"), - w.Tuning.Input.validateCreate().ViaField("input"), - w.Tuning.Output.validateCreate().ViaField("output"), - w.Resource.validateCreate(w.Inference).ViaField("resource"), + // TODO: Consider validate resource based on Tuning Spec + w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning")) + } } else { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Inference.validateUpdate(&old.Inference).ViaField("inference"), - w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), - w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), - w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning")) + } } return errs } func (w *Workspace) validateCreate() (errs *apis.FieldError) { - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - tuningSpecified := w.Tuning.Input != nil - if !inferenceSpecified && !tuningSpecified { + if w.Inference == nil && w.Tuning == nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) } - if inferenceSpecified && tuningSpecified { + if w.Inference != nil && w.Tuning != nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs } func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := w.Tuning.Input != nil - if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + if (old.Inference == nil && w.Inference != nil) || (old.Inference != nil && w.Inference == nil) { errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) } - if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + if (old.Tuning == nil && w.Tuning != nil) || (old.Tuning != nil && w.Tuning == nil) { errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs @@ -90,9 +87,13 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateCreate().ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateCreate().ViaField("Output")) } // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { @@ -108,11 +109,15 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { } func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.Input, r.Input) { - errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) } - if !reflect.DeepEqual(old.Output, r.Output) { - errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output")) } if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 6c2f1a650..d1cea034d 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -496,19 +496,16 @@ func TestWorkspaceValidateCreate(t *testing.T) { errField string }{ { - name: "Neither Inference nor Tuning specified", - workspace: &Workspace{ - Inference: InferenceSpec{}, - Tuning: TuningSpec{}, - }, - wantErr: true, - errField: "neither", + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{}, + wantErr: true, + errField: "neither", }, { name: "Both Inference and Tuning specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, - Tuning: TuningSpec{Input: &DataSource{}}, + Inference: &InferenceSpec{}, + Tuning: &TuningSpec{}, }, wantErr: true, errField: "both", @@ -516,7 +513,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Inference specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, wantErr: false, errField: "", @@ -524,7 +521,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Tuning specified", workspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, wantErr: false, errField: "", @@ -553,12 +550,10 @@ func TestWorkspaceValidateUpdate(t *testing.T) { errFields []string // Fields we expect to have errors }{ { - name: "Inference toggled on", - oldWorkspace: &Workspace{ - Inference: InferenceSpec{}, - }, + name: "Inference toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, expectErrs: true, errFields: []string{"inference"}, @@ -566,21 +561,17 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Inference toggled off", oldWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, - }, - newWorkspace: &Workspace{ - Inference: InferenceSpec{}, + Inference: &InferenceSpec{Preset: &PresetSpec{}}, }, - expectErrs: true, - errFields: []string{"inference"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"inference"}, }, { - name: "Tuning toggled on", - oldWorkspace: &Workspace{ - Tuning: TuningSpec{}, - }, + name: "Tuning toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: true, errFields: []string{"tuning"}, @@ -588,21 +579,19 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Tuning toggled off", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, - }, - newWorkspace: &Workspace{ - Tuning: TuningSpec{}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, - expectErrs: true, - errFields: []string{"tuning"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"tuning"}, }, { name: "No toggling", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: false, }, @@ -639,7 +628,7 @@ func TestTuningSpecValidateCreate(t *testing.T) { { name: "All fields valid", tuningSpec: &TuningSpec{ - Input: &DataSource{Name: "valid-input"}, + Input: &DataSource{Name: "valid-input", HostPath: "valid-input"}, Output: &DataDestination{HostPath: "valid-output"}, Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, Method: TuningMethodLora, @@ -749,13 +738,15 @@ func TestTuningSpecValidateUpdate(t *testing.T) { { name: "Input changed", oldTuning: &TuningSpec{ - Input: &DataSource{Name: "input1"}, + Input: &DataSource{Name: "input", HostPath: "inputpath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, newTuning: &TuningSpec{ - Input: &DataSource{Name: "input2"}, + Input: &DataSource{Name: "input", HostPath: "randompath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, expectErrs: true, - errFields: []string{"Input"}, + errFields: []string{"HostPath"}, }, { name: "Output changed", diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index a9d662c0f..6c3ee1eb8 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -249,8 +249,16 @@ func (in *Workspace) DeepCopyInto(out *Workspace) { out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Resource.DeepCopyInto(&out.Resource) - in.Inference.DeepCopyInto(&out.Inference) - in.Tuning.DeepCopyInto(&out.Tuning) + if in.Inference != nil { + in, out := &in.Inference, &out.Inference + *out = new(InferenceSpec) + (*in).DeepCopyInto(*out) + } + if in.Tuning != nil { + in, out := &in.Tuning, &out.Tuning + *out = new(TuningSpec) + (*in).DeepCopyInto(*out) + } in.Status.DeepCopyInto(&out.Status) } diff --git a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml index 40908f609..a4103a897 100644 --- a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml +++ b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml @@ -47,11 +47,56 @@ spec: type: string inference: properties: + adapters: + description: Adapters are integrated into the base model for inference. + Users can specify multiple adapters for the model and the respective + weight of using each of them. + items: + properties: + source: + description: Source describes where to obtain the adapter data. + properties: + hostPath: + description: The directory in the host that contains the + data. + type: string + image: + description: The name of the image that contains the source + data. The assumption is that the source data locates in + the `data` directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names + in the same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will + be used as a container name. It must be a valid DNS subdomain + value, + type: string + urls: + description: URLs specifies the links to the public data + sources. E.g., files in a public github repository. + items: + type: string + type: array + type: object + strength: + description: Strength specifies the default multiplier for applying + the adapter weights to the raw model weights. It is usually + a float number between 0 and 1. It is defined as a string + type to be language agnostic. + type: string + type: object + type: array preset: - description: Preset describles the model that will be deployed with - preset configurations. + description: Preset describes the base model that will be deployed + with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -72,7 +117,7 @@ spec: type: string imagePullSecrets: description: ImagePullSecrets is a list of secret names in - the same namespace used for pulling the image. + the same namespace used for pulling the model image. items: type: string type: array @@ -95,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -245,6 +290,100 @@ spec: type: string type: array type: object + tuning: + properties: + config: + description: Config specifies the name of the configmap in the same + namespace that contains the arguments used by the tuning method. + If not specified, a default configmap is used based on the specified + method. + type: string + input: + description: Input describes the input used by the tuning method. + properties: + hostPath: + description: The directory in the host that contains the data. + type: string + image: + description: The name of the image that contains the source data. + The assumption is that the source data locates in the `data` + directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in the + same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will be used + as a container name. It must be a valid DNS subdomain value, + type: string + urls: + description: URLs specifies the links to the public data sources. + E.g., files in a public github repository. + items: + type: string + type: array + type: object + method: + description: Method specifies the Parameter-Efficient Fine-Tuning(PEFT) + method, such as lora, qlora, used for the tuning. + type: string + output: + description: Output specified where to store the tuning output. + properties: + hostPath: + description: The directory in the host that contains the output + data. + type: string + image: + description: Name of the image where the output data is pushed + to. + type: string + imagePushSecret: + description: ImagePushSecret is the name of the secret in the + same namespace that contains the authentication information + that is needed for running `docker push`. + type: string + type: object + preset: + description: Preset describes which model to load for tuning. + properties: + accessMode: + default: public + description: AccessMode specifies whether the containerized model + image is accessible via public registry or private registry. + This field defaults to "public" if not specified. If this field + is "private", user needs to provide the private image information + in PresetOptions. + enum: + - public + - private + type: string + name: + description: Name of the supported models with preset configurations. + type: string + presetOptions: + properties: + image: + description: Image is the name of the containerized model + image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in + the same namespace used for pulling the model image. + items: + type: string + type: array + type: object + required: + - name + type: object + required: + - input + - output + type: object type: object served: true storage: true diff --git a/config/crd/bases/kaito.sh_workspaces.yaml b/config/crd/bases/kaito.sh_workspaces.yaml index b3af23a76..a4103a897 100644 --- a/config/crd/bases/kaito.sh_workspaces.yaml +++ b/config/crd/bases/kaito.sh_workspaces.yaml @@ -57,7 +57,7 @@ spec: description: Source describes where to obtain the adapter data. properties: hostPath: - description: The directory in the hsot that contains the + description: The directory in the host that contains the data. type: string image: @@ -96,6 +96,7 @@ spec: with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -139,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -301,7 +302,7 @@ spec: description: Input describes the input used by the tuning method. properties: hostPath: - description: The directory in the hsot that contains the data. + description: The directory in the host that contains the data. type: string image: description: The name of the image that contains the source data. @@ -350,6 +351,7 @@ spec: description: Preset describes which model to load for tuning. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -378,6 +380,9 @@ spec: required: - name type: object + required: + - input + - output type: object type: object served: true diff --git a/pkg/utils/testUtils.go b/pkg/utils/testUtils.go index f88b35a4f..5ef34af1d 100644 --- a/pkg/utils/testUtils.go +++ b/pkg/utils/testUtils.go @@ -35,7 +35,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-distributed-model", @@ -60,7 +60,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-model", @@ -85,7 +85,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Template: &corev1.PodTemplateSpec{}, }, } diff --git a/test/e2e/preset_test.go b/test/e2e/preset_test.go index eb0333df4..e8f262ef0 100644 --- a/test/e2e/preset_test.go +++ b/test/e2e/preset_test.go @@ -26,13 +26,13 @@ import ( ) const ( - PresetLlama2AChat = "llama-2-7b-chat" - PresetLlama2BChat = "llama-2-13b-chat" - PresetFalcon7BModel = "falcon-7b" - PresetFalcon40BModel = "falcon-40b" - PresetMistral7BModel = "mistral-7b" + PresetLlama2AChat = "llama-2-7b-chat" + PresetLlama2BChat = "llama-2-13b-chat" + PresetFalcon7BModel = "falcon-7b" + PresetFalcon40BModel = "falcon-40b" + PresetMistral7BModel = "mistral-7b" PresetMistral7BInstructModel = "mistral-7b-instruct" - PresetPhi2Model = "phi-2" + PresetPhi2Model = "phi-2" ) func createFalconWorkspaceWithPresetPublicMode(numOfNode int) *kaitov1alpha1.Workspace { @@ -348,17 +348,17 @@ var _ = Describe("Workspace Preset", func() { fmt.Print("Error: RUN_LLAMA_13B ENV Variable not set") runLlama13B = false } - + aiModelsRegistry = utils.GetEnv("AI_MODELS_REGISTRY") aiModelsRegistrySecret = utils.GetEnv("AI_MODELS_REGISTRY_SECRET") - + // Load stable model versions configs, err := utils.GetModelConfigInfo("/home/runner/work/kaito/kaito/presets/models/supported_models.yaml") if err != nil { fmt.Printf("Failed to load model configs: %v\n", err) os.Exit(1) } - + modelInfo, err = utils.ExtractModelVersion(configs) if err != nil { fmt.Printf("Failed to extract stable model versions: %v\n", err) @@ -404,7 +404,6 @@ var _ = Describe("Workspace Preset", func() { validateWorkspaceReadiness(workspaceObj) }) - It("should create a Phi-2 workspace with preset public mode successfully", func() { numOfNode := 1 workspaceObj := createPhi2WorkspaceWithPresetPublicMode(numOfNode) diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 3914f00eb..38388374f 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -60,23 +60,23 @@ func ExtractModelVersion(configs map[string]interface{}) (map[string]string, err } for _, modelItem := range models { - model, ok := modelItem.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("model item is not a map") - } + model, ok := modelItem.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("model item is not a map") + } - modelName, ok := model["name"].(string) - if !ok { - return nil, fmt.Errorf("model name is not a string or not found") - } + modelName, ok := model["name"].(string) + if !ok { + return nil, fmt.Errorf("model name is not a string or not found") + } - modelTag, ok := model["tag"].(string) // Using 'tag' as the version - if !ok { - return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) - } + modelTag, ok := model["tag"].(string) // Using 'tag' as the version + if !ok { + return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) + } - modelsInfo[modelName] = modelTag - } + modelsInfo[modelName] = modelTag + } return modelsInfo, nil } @@ -117,7 +117,7 @@ func GenerateWorkspaceManifest(name, namespace, imageName string, resourceCount workspaceInference.Template = podTemplate } - workspace.Inference = workspaceInference + workspace.Inference = &workspaceInference return workspace } From 6c347c99fd4447c7bcff414207d8fad5bee67ce4 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:33:01 -0700 Subject: [PATCH 7/7] fix: Add name flag --- api/v1alpha1/workspace_validation.go | 9 +++++++-- api/v1alpha1/workspace_validation_test.go | 13 ++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index b135f5886..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -41,6 +41,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) } if w.Tuning != nil { @@ -54,6 +55,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) } if w.Tuning != nil { @@ -112,7 +114,7 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) } else { - errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) + errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) @@ -150,7 +152,10 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { return errs } -func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { +func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) { + if isTuning && !reflect.DeepEqual(old.Name, r.Name) { + errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name")) + } oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) sort.Strings(oldURLs) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index d1cea034d..11631e67b 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -898,6 +898,17 @@ func TestDataSourceValidateUpdate(t *testing.T) { }, wantErr: false, }, + { + name: "Name changed", + oldSource: &DataSource{ + Name: "original-dataset", + }, + newSource: &DataSource{ + Name: "new-dataset", + }, + wantErr: true, + errFields: []string{"Name"}, + }, { name: "URLs changed", oldSource: &DataSource{ @@ -946,7 +957,7 @@ func TestDataSourceValidateUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - errs := tt.newSource.validateUpdate(tt.oldSource) + errs := tt.newSource.validateUpdate(tt.oldSource, true) hasErrs := errs != nil if hasErrs != tt.wantErr {