Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Part 2 - Add validation checks for TuningSpec, DataSource, DataDestination #304

Merged
merged 10 commits into from
Mar 20, 2024
14 changes: 7 additions & 7 deletions api/v1alpha1/workspace_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
ModelImageAccessModePrivate ModelImageAccessMode = "private"
)

// ResourceSpec desicribes the resource requirement of running the workload.
// ResourceSpec describes the resource requirement of running the workload.
// If the number of nodes in the cluster that meet the InstanceType and
// LabelSelector requirements is small than the Count, controller
// will provision new nodes before deploying the workload.
Expand Down Expand Up @@ -51,7 +51,7 @@ type PresetMeta struct {
// AccessMode specifies whether the containerized model image is accessible via public registry
// or private registry. This field defaults to "public" if not specified.
// If this field is "private", user needs to provide the private image information in PresetOptions.
// +bebuilder:default:="public"
// +kubebuilder:default:="public"
// +optional
AccessMode ModelImageAccessMode `json:"accessMode,omitempty"`
}
Expand Down Expand Up @@ -106,7 +106,7 @@ type DataSource struct {
// URLs specifies the links to the public data sources. E.g., files in a public github repository.
// +optional
URLs []string `json:"urls,omitempty"`
// The directory in the hsot that contains the data.
// The directory in the host that contains the data.
// +optional
HostPath string `json:"hostPath,omitempty"`
// The name of the image that contains the source data. The assumption is that the source data locates in the
Expand Down Expand Up @@ -150,9 +150,9 @@ type TuningSpec struct {
// +optional
Config string `json:"config,omitempty"`
// Input describes the input used by the tuning method.
Input *DataSource `json:"input,omitempty"`
Input *DataSource `json:"input"`
// Output specified where to store the tuning output.
Output *DataDestination `json:"output,omitempty"`
Output *DataDestination `json:"output"`
}

// WorkspaceStatus defines the observed state of Workspace
Expand Down Expand Up @@ -181,8 +181,8 @@ type Workspace struct {
metav1.ObjectMeta `json:"metadata,omitempty"`

Resource ResourceSpec `json:"resource,omitempty"`
Inference InferenceSpec `json:"inference,omitempty"`
Tuning TuningSpec `json:"tuning,omitempty"`
Inference *InferenceSpec `json:"inference,omitempty"`
Tuning *TuningSpec `json:"tuning,omitempty"`
Status WorkspaceStatus `json:"status,omitempty"`
}

Expand Down
179 changes: 174 additions & 5 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"reflect"
"sort"
"strings"

"github.com/azure/kaito/pkg/utils/plugin"
Expand Down Expand Up @@ -35,16 +36,184 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
if base == nil {
klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(
w.Inference.validateCreate().ViaField("inference"),
w.Resource.validateCreate(w.Inference).ViaField("resource"),
w.validateCreate().ViaField("spec"),
// TODO: Consider validate resource based on Tuning Spec
w.Resource.validateCreate(*w.Inference).ViaField("resource"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
old := base.(*Workspace)
errs = errs.Also(
w.validateUpdate(old).ViaField("spec"),
w.Resource.validateUpdate(&old.Resource).ViaField("resource"),
w.Inference.validateUpdate(&old.Inference).ViaField("inference"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning"))
}
}
return errs
}

func (w *Workspace) validateCreate() (errs *apis.FieldError) {
if w.Inference == nil && w.Tuning == nil {
errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", ""))
}
if w.Inference != nil && w.Tuning != nil {
errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", ""))
}
return errs
}

func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
if (old.Inference == nil && w.Inference != nil) || (old.Inference != nil && w.Inference == nil) {
errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference"))
}

if (old.Tuning == nil && w.Tuning != nil) || (old.Tuning != nil && w.Tuning == nil) {
errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning"))
}
return errs
}

func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
if r.Input == nil {
errs = errs.Also(apis.ErrMissingField("Input"))
} else {
errs = errs.Also(r.Input.validateCreate().ViaField("Input"))
}
if r.Output == nil {
errs = errs.Also(apis.ErrMissingField("Output"))
} else {
errs = errs.Also(r.Output.validateCreate().ViaField("Output"))
}
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
return errs
}

func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) {
if r.Input == nil {
errs = errs.Also(apis.ErrMissingField("Input"))
} else {
errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input"))
}
if r.Output == nil {
errs = errs.Also(apis.ErrMissingField("Output"))
} else {
errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output"))
}
if !reflect.DeepEqual(old.Preset, r.Preset) {
errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset"))
}
oldMethod, newMethod := strings.ToLower(string(old.Method)), strings.ToLower(string(r.Method))
if !reflect.DeepEqual(oldMethod, newMethod) {
errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method"))
}
// Consider supporting config fields changing
return errs
}

func (r *DataSource) validateCreate() (errs *apis.FieldError) {
sourcesSpecified := 0
if len(r.URLs) > 0 {
sourcesSpecified++
}
if r.HostPath != "" {
sourcesSpecified++
}
if r.Image != "" {
sourcesSpecified++
}

// Ensure exactly one of URLs, HostPath, or Image is specified
if sourcesSpecified != 1 {
errs = errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image"))
}

return errs
}

func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) {
if isTuning && !reflect.DeepEqual(old.Name, r.Name) {
errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name"))
}
oldURLs := make([]string, len(old.URLs))
copy(oldURLs, old.URLs)
sort.Strings(oldURLs)

newURLs := make([]string, len(r.URLs))
copy(newURLs, r.URLs)
sort.Strings(newURLs)

if !reflect.DeepEqual(oldURLs, newURLs) {
errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs"))
}
if old.HostPath != r.HostPath {
errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath"))
}
if old.Image != r.Image {
errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
}

oldSecrets := make([]string, len(old.ImagePullSecrets))
copy(oldSecrets, old.ImagePullSecrets)
sort.Strings(oldSecrets)

newSecrets := make([]string, len(r.ImagePullSecrets))
copy(newSecrets, r.ImagePullSecrets)
sort.Strings(newSecrets)

if !reflect.DeepEqual(oldSecrets, newSecrets) {
errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets"))
}
return errs
}

func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
destinationsSpecified := 0
if r.HostPath != "" {
destinationsSpecified++
}
if r.Image != "" {
destinationsSpecified++
}

// If no destination is specified, return an error
if destinationsSpecified == 0 {
errs = errs.Also(apis.ErrMissingField("At least one of HostPath or Image must be specified"))
}
return errs
}

func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) {
if old.HostPath != r.HostPath {
errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath"))
}
if old.Image != r.Image {
errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
}

if old.ImagePushSecret != r.ImagePushSecret {
errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret"))
}
return errs
}
Expand Down Expand Up @@ -131,7 +300,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
presetName := string(i.Preset.Name)
// Validate preset name
if !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName"))
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
Expand All @@ -151,7 +320,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro
if !reflect.DeepEqual(i.Preset, old.Preset) {
errs = errs.Also(apis.ErrGeneric("field is immutable", "preset"))
}
//inference.template can be changed, but cannot be unset.
// inference.template can be changed, but cannot be set/unset.
if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) {
errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template"))
}
Expand Down
Loading
Loading