diff --git a/pkg/inference/preset-inference-types.go b/pkg/inference/preset-inference-types.go
index f72eb8bbc..6ee8e9e8a 100644
--- a/pkg/inference/preset-inference-types.go
+++ b/pkg/inference/preset-inference-types.go
@@ -30,19 +30,17 @@ const (
 var (
 	registryName = os.Getenv("PRESET_REGISTRY_NAME")
 
-	presetLlama2AChatImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetLlama2AChat)
-	presetLlama2BChatImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetLlama2BChat)
-	presetLlama2CChatImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetLlama2CChat)
+	presetLlama2AChatImage = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetLlama2AChat)
+	presetLlama2BChatImage = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetLlama2BChat)
+	presetLlama2CChatImage = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetLlama2CChat)
 
-	presetFalcon7bImage         = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon7BModel)
-	presetFalcon7bInstructImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon7BInstructModel)
+	presetFalcon7bImage         = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetFalcon7BModel)
+	presetFalcon7bInstructImage = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetFalcon7BInstructModel)
 
-	presetFalcon40bImage         = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon40BModel)
-	presetFalcon40bInstructImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon40BInstructModel)
+	presetFalcon40bImage         = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetFalcon40BModel)
+	presetFalcon40bInstructImage = registryName + fmt.Sprintf("/%s:0.0.1", kaitov1alpha1.PresetFalcon40BInstructModel)
 
-	baseCommandPresetLlama2AChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2AChat)
-	baseCommandPresetLlama2BChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2BChat)
-	baseCommandPresetLlama2CChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2CChat)
+	baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"
 	// llamaTextInferenceFile = "inference-api.py" TODO: To support Text Generation Llama Models
 	llamaChatInferenceFile = "inference-api.py"
 	llamaRunParams         = map[string]string{
@@ -113,7 +111,7 @@ var (
 		ModelRunParams:         llamaRunParams,
 		InferenceFile:          llamaChatInferenceFile,
 		DeploymentTimeout:      time.Duration(10) * time.Minute,
-		BaseCommand:            baseCommandPresetLlama2AChat,
+		BaseCommand:            baseCommandPresetLlama,
 		WorldSize:              1,
 		DefaultVolumeMountPath: "/dev/shm",
 	},
@@ -129,7 +127,7 @@ var (
 		ModelRunParams:         llamaRunParams,
 		InferenceFile:          llamaChatInferenceFile,
 		DeploymentTimeout:      time.Duration(20) * time.Minute,
-		BaseCommand:            baseCommandPresetLlama2BChat,
+		BaseCommand:            baseCommandPresetLlama,
 		WorldSize:              2,
 		DefaultVolumeMountPath: "/dev/shm",
 	},
@@ -145,7 +143,7 @@ var (
 		ModelRunParams:         llamaRunParams,
 		InferenceFile:          llamaChatInferenceFile,
 		DeploymentTimeout:      time.Duration(30) * time.Minute,
-		BaseCommand:            baseCommandPresetLlama2CChat,
+		BaseCommand:            baseCommandPresetLlama,
 		WorldSize:              8,
 		DefaultVolumeMountPath: "/dev/shm",
 	},
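For reviewers, a minimal sketch of the two patterns this diff converges on: every preset image reference is the registry name plus a model name plus a pinned release tag (instead of the mutable ":latest"), and all llama-2 chat presets share one launch prefix because they run from the same /workspace/llama/llama-2 directory. The presetImage helper, presetImageTag constant, and the placeholder model name below are hypothetical illustrations, not code introduced by this PR (the PR keeps the tag inline in each string); baseCommandPresetLlama matches the value added in the diff.

```go
package main

import (
	"fmt"
	"os"
)

const (
	// presetImageTag pins presets to a released image so a registry push to
	// "latest" cannot silently change what nodes pull. Hypothetical constant:
	// the PR inlines "0.0.1" in each image string rather than factoring it out.
	presetImageTag = "0.0.1"

	// Value taken from the diff: the three per-preset base commands collapse
	// into one because the launch directory no longer depends on the preset.
	baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"
)

// presetImage is a hypothetical helper showing the reference shape the diff
// repeats inline: <registry>/<model>:<pinned tag>.
func presetImage(registry, model string) string {
	return registry + fmt.Sprintf("/%s:%s", model, presetImageTag)
}

func main() {
	registry := os.Getenv("PRESET_REGISTRY_NAME")
	// "example-llama-2-chat" is a placeholder, not one of the real preset
	// names defined in kaitov1alpha1.
	fmt.Println(presetImage(registry, "example-llama-2-chat"))
	fmt.Println(baseCommandPresetLlama)
}
```

One consequence of pinning worth noting: deployed presets no longer pick up re-pushed images automatically, so future image releases will need a corresponding tag bump in this file.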