feat: add inference config api
- API change: add workspace.inference.config
- generate a default config if no user config is specified,
  by copying it from the release namespace

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh committed Dec 20, 2024
1 parent 2c1d5bf commit 2eafbb8
Showing 12 changed files with 269 additions and 29 deletions.
4 changes: 4 additions & 0 deletions api/v1alpha1/workspace_types.go
@@ -82,6 +82,10 @@ type InferenceSpec struct {
// +kubebuilder:validation:Schemaless
// +optional
Template *v1.PodTemplateSpec `json:"template,omitempty"`
// Config specifies the name of a custom ConfigMap that contains inference arguments.
// If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
// +optional
Config string `json:"config,omitempty"`
// Adapters are integrated into the base model for inference.
// Users can specify multiple adapters for the model and the respective weight of using each of them.
// +optional
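For illustration, a Workspace that opts into a custom config might look like the sketch below (resource and preset names are hypothetical; inference.config is the only new field, and the named ConfigMap must live in the Workspace's own namespace):

apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-example
  namespace: team-a
resource:
  instanceType: Standard_NC12s_v3
  labelSelector:
    matchLabels:
      apps: example
inference:
  preset:
    name: falcon-7b
  config: my-inference-params  # ConfigMap in namespace team-a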
7 changes: 4 additions & 3 deletions api/v1alpha1/workspace_validation.go
@@ -29,9 +29,10 @@ const (
N_SERIES_PREFIX = "Standard_N"
D_SERIES_PREFIX = "Standard_D"

DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
MaxAdaptersNumber = 10
DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
DefaultInferenceConfigTemplate = "inference-params-template"
MaxAdaptersNumber = 10
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
5 changes: 5 additions & 0 deletions charts/kaito/workspace/crds/kaito.sh_workspaces.yaml
@@ -94,6 +94,11 @@ spec:
type: string
type: object
type: array
config:
description: |-
Config specifies the name of a custom ConfigMap that contains inference arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
type: string
preset:
description: Preset describes the base model that will be deployed
with preset configurations.
5 changes: 5 additions & 0 deletions config/crd/bases/kaito.sh_workspaces.yaml
@@ -94,6 +94,11 @@ spec:
type: string
type: object
type: array
config:
description: |-
Config specifies the name of a custom ConfigMap that contains inference arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
type: string
preset:
description: Preset describes the base model that will be deployed
with preset configurations.
9 changes: 8 additions & 1 deletion pkg/model/interface.go
@@ -3,9 +3,11 @@
package model

import (
"path"
"time"

"github.com/kaito-project/kaito/pkg/utils"
corev1 "k8s.io/api/core/v1"
)

type Model interface {
@@ -21,6 +23,8 @@ type RuntimeName string
const (
RuntimeNameHuggingfaceTransformers RuntimeName = "transformers"
RuntimeNameVLLM RuntimeName = "vllm"

ConfigfileNameVLLM = "inference_config.yaml"
)

// PresetParam defines the preset inference parameters for a model.
@@ -133,7 +137,7 @@ func (v *VLLMParam) DeepCopy() VLLMParam {

// builds the container command:
// e.g. torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string) []string {
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string, configVolume *corev1.VolumeMount) []string {
switch runtime {
case RuntimeNameHuggingfaceTransformers:
torchCommand := utils.BuildCmdStr(p.Transformers.BaseCommand, p.Transformers.TorchRunParams, p.Transformers.TorchRunRdzvParams)
@@ -146,6 +150,9 @@ func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string
if !p.DisableTensorParallelism {
p.VLLM.ModelRunParams["tensor-parallel-size"] = skuNumGPUs
}
if configVolume != nil {
p.VLLM.ModelRunParams["kaito-config-file"] = path.Join(configVolume.MountPath, ConfigfileNameVLLM)
}
modelCommand := utils.BuildCmdStr(p.VLLM.BaseCommand, p.VLLM.ModelRunParams)
return utils.ShellCmd(modelCommand)
default:
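The vLLM branch above wires the mounted config file into the model command. A minimal, runnable sketch of just that step (constant and values assumed; the real logic lives inside GetInferenceCommand):

package main

import (
	"fmt"
	"path"
)

func main() {
	const configfileNameVLLM = "inference_config.yaml" // mirrors ConfigfileNameVLLM
	mountPath := "/mnt/config"                         // assumed ConfigMap mount path
	modelRunParams := map[string]string{"tensor-parallel-size": "2"}

	// Same shape as: p.VLLM.ModelRunParams["kaito-config-file"] = path.Join(configVolume.MountPath, ConfigfileNameVLLM)
	modelRunParams["kaito-config-file"] = path.Join(mountPath, configfileNameVLLM)

	fmt.Println(modelRunParams["kaito-config-file"]) // /mnt/config/inference_config.yaml
}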
20 changes: 20 additions & 0 deletions pkg/utils/test/testUtils.go
@@ -4,6 +4,8 @@
package test

import (
"os"

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/model"
@@ -948,3 +950,21 @@ func NotFoundError() error {
func IsAlreadyExistsError() error {
return &apierrors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonAlreadyExists}}
}

// SaveEnv saves the current value (or absence) of an environment variable and returns a function that restores it
func SaveEnv(key string) func() {
envVal, envExists := os.LookupEnv(key)
return func() {
if envExists {
err := os.Setenv(key, envVal)
if err != nil {
return
}
} else {
err := os.Unsetenv(key)
if err != nil {
return
}
}
}
}
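A hypothetical usage of SaveEnv in a test (assuming the usual os/testing imports and this package imported as test): snapshot the variable, mutate it freely, and restore it on exit.

func TestWithCloudProvider(t *testing.T) {
	restore := test.SaveEnv("CLOUD_PROVIDER") // snapshot current value, or its absence
	defer restore()                           // restore when the test returns

	os.Setenv("CLOUD_PROVIDER", "azure")
	// ... exercise code that reads CLOUD_PROVIDER ...
}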
4 changes: 3 additions & 1 deletion pkg/workspace/controllers/workspace_controller_test.go
@@ -624,6 +624,7 @@ func TestApplyInferenceWithPreset(t *testing.T) {
},
"Create preset inference because inference workload did not exist": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(test.NotFoundError()).Times(4)
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil).Run(func(args mock.Arguments) {
depObj := &appsv1.Deployment{}
@@ -692,9 +693,10 @@
}
ctx := context.Background()

os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
err := reconciler.applyInference(ctx, &tc.workspace)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
assert.Check(t, err == nil, fmt.Sprintf("Not expected to return error: %v", err))
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
83 changes: 81 additions & 2 deletions pkg/workspace/inference/preset-inferences.go
@@ -17,6 +17,7 @@ import (
"github.com/kaito-project/kaito/pkg/utils/resources"
"github.com/kaito-project/kaito/pkg/workspace/manifests"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/klog/v2"
@@ -131,8 +132,13 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
model model.Model, kubeClient client.Client) (client.Object, error) {
inferenceParam := model.GetInferenceParameters().DeepCopy()

configVolume, err := EnsureInferenceConfigMap(ctx, workspaceObj, kubeClient)
if err != nil {
return nil, err
}

if model.SupportDistributedInference() {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceParam); err != nil { //
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceParam); err != nil {
klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
return nil, err
}
@@ -157,6 +163,12 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
// additional volume
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount

// Add config volume mount
cmVolume, cmVolumeMount := utils.ConfigCMVolume(configVolume.Name)
volumes = append(volumes, cmVolume)
volumeMounts = append(volumeMounts, cmVolumeMount)

// add share memory for cross process communication
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(skuGPUCount)
if shmVolume.Name != "" {
@@ -173,7 +185,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work

// inference command
runtimeName := kaitov1alpha1.GetWorkspaceRuntimeName(workspaceObj)
commands := inferenceParam.GetInferenceCommand(runtimeName, skuNumGPUs)
commands := inferenceParam.GetInferenceCommand(runtimeName, skuNumGPUs, &cmVolumeMount)

image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceParam)

@@ -191,3 +203,70 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}
return depObj, nil
}

// EnsureInferenceConfigMap handles two scenarios:
// 1. User provided config (workspaceObj.Inference.Config):
// - Check if it exists in the target namespace
// - If not found, return error as this is user-specified
//
// 2. No user config specified:
// - Use the default config template (inference-params-template)
// - Check if it exists in the target namespace
// - If not, copy from release namespace to target namespace
func EnsureInferenceConfigMap(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
kubeClient client.Client) (*corev1.ConfigMap, error) {

// If user specified a config, use that
if workspaceObj.Inference.Config != "" {
userCM := &corev1.ConfigMap{}
err := resources.GetResource(ctx, workspaceObj.Inference.Config, workspaceObj.Namespace, kubeClient, userCM)
if err != nil {
if errors.IsNotFound(err) {
return nil, fmt.Errorf("user specified ConfigMap %s not found in namespace %s",
workspaceObj.Inference.Config, workspaceObj.Namespace)
}
return nil, err
}

return userCM, nil
}

// Otherwise use default template
configMapName := kaitov1alpha1.DefaultInferenceConfigTemplate

// Check if default configmap already exists in target namespace
existingCM := &corev1.ConfigMap{}
err := resources.GetResource(ctx, configMapName, workspaceObj.Namespace, kubeClient, existingCM)
if err != nil {
if !errors.IsNotFound(err) {
return nil, err
}
} else {
klog.Infof("Default ConfigMap already exists in target namespace: %s, no action taken.", workspaceObj.Namespace)
return existingCM, nil
}

// Copy default template from release namespace if not found
releaseNamespace, err := utils.GetReleaseNamespace()
if err != nil {
return nil, fmt.Errorf("failed to get release namespace: %v", err)
}

templateCM := &corev1.ConfigMap{}
err = resources.GetResource(ctx, configMapName, releaseNamespace, kubeClient, templateCM)
if err != nil {
return nil, fmt.Errorf("failed to get default ConfigMap from template namespace: %v", err)
}

templateCM.Namespace = workspaceObj.Namespace
templateCM.ResourceVersion = "" // Clear metadata not needed for creation
templateCM.UID = "" // Clear UID

err = resources.CreateResource(ctx, templateCM, kubeClient)
if err != nil {
return nil, fmt.Errorf("failed to create default ConfigMap in target namespace %s: %v",
workspaceObj.Namespace, err)
}

return templateCM, nil
}
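utils.ConfigCMVolume is called above but not shown in this diff; as a rough sketch under assumed names and mount path, a ConfigMap-backed volume helper of that shape could be:

// configCMVolume builds a Volume backed by the named ConfigMap plus a matching mount.
func configCMVolume(cmName string) (corev1.Volume, corev1.VolumeMount) {
	volume := corev1.Volume{
		Name: "config-volume",
		VolumeSource: corev1.VolumeSource{
			ConfigMap: &corev1.ConfigMapVolumeSource{
				LocalObjectReference: corev1.LocalObjectReference{Name: cmName},
			},
		},
	}
	mount := corev1.VolumeMount{
		Name:      volume.Name,
		MountPath: "/mnt/config", // GetInferenceCommand joins this path with inference_config.yaml
	}
	return volume, mount
}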