Skip to content

Commit

Permalink
feat: Part 4 - Add image tag to models (#221)
Browse files Browse the repository at this point in the history
This PR adds image tag into models.go, so that we can control what
preset image Kaito controller is using (establishes controllable link
between kaito and preset).

PR also adds minor fix to e2e-preset-tests ensuring image tag is bumped.
  • Loading branch information
ishaansehgal99 authored Jan 30, 2024
1 parent ae881d2 commit a70138e
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 11 deletions.
25 changes: 15 additions & 10 deletions .github/workflows/e2e-preset-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,18 @@ jobs:
echo "Image $IMAGE_NAME:$TAG not found in $ACR_NAME."
fi
- name: Check if Image is Test and Prod ACRs
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'true'
run: |
echo "Skipping: Image already exists in both Test and Prod ACRs, remember to bump tag"
- name: Set up kubectl context
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
az aks get-credentials --resource-group llm-test --name GitRunner
- name: Get Nodepool Name
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
id: get_nodepool_name
run: |
NAME_SUFFIX=${{ matrix.name }}
Expand All @@ -142,7 +147,7 @@ jobs:
echo "NODEPOOL_NAME=$TRUNCATED_NAME_SUFFIX" >> $GITHUB_OUTPUT
- name: Create Nodepool
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
NODEPOOL_EXIST=$(az aks nodepool show \
--name ${{ steps.get_nodepool_name.outputs.NODEPOOL_NAME }} \
Expand Down Expand Up @@ -177,11 +182,11 @@ jobs:
fi
- name: Create Service
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: kubectl apply -f presets/test/manifests/${{ matrix.name }}/${{ matrix.name }}-service.yaml

- name: Retrieve External Service IP
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
id: get_ip
run: |
while [[ -z $SERVICE_IP ]]; do
Expand All @@ -192,30 +197,30 @@ jobs:
echo "SERVICE_IP=$SERVICE_IP" >> $GITHUB_OUTPUT
- name: Replace IP and Deploy Statefulset to K8s
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
sed -i "s/MASTER_ADDR_HERE/${{ steps.get_ip.outputs.SERVICE_IP }}/g" presets/test/manifests/${{ matrix.name }}/${{ matrix.name }}-statefulset.yaml
sed -i "s/TAG_HERE/${{ matrix.tag }}/g" presets/test/manifests/${{ matrix.name }}/${{ matrix.name }}-statefulset.yaml
sed -i "s/REPO_HERE/${{ secrets.ACR_AMRT_USERNAME }}/g" presets/test/manifests/${{ matrix.name }}/${{ matrix.name }}-statefulset.yaml
kubectl apply -f presets/test/manifests/${{ matrix.name }}/${{ matrix.name }}-statefulset.yaml
- name: Wait for Statefulset to be ready
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
kubectl rollout status statefulset/${{ matrix.name }}
- name: Test home endpoint
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
curl http://${{ steps.get_ip.outputs.SERVICE_IP }}:80/
- name: Test healthz endpoint
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
curl http://${{ steps.get_ip.outputs.SERVICE_IP }}:80/healthz
- name: Test inference endpoint
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true'
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && steps.check_prod_image.outputs.IMAGE_EXISTS == 'false'
run: |
if [[ "${{ matrix.name }}" == *"llama"* && "${{ matrix.name }}" == *"-chat"* ]]; then
echo "Testing inference for ${{ matrix.name }}"
Expand Down
3 changes: 2 additions & 1 deletion pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl

func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) {
imageName := string(workspaceObj.Inference.Preset.Name)
imageTag := inferenceObj.Tag
imagePullSecretRefs := []corev1.LocalObjectReference{}
if inferenceObj.ImageAccessMode == "private" {
imageName = string(workspaceObj.Inference.Preset.PresetOptions.Image)
Expand All @@ -104,7 +105,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in
}

registryName := os.Getenv("PRESET_REGISTRY_NAME")
imageName = registryName + fmt.Sprintf("/kaito-%s:0.0.1", imageName)
imageName = registryName + fmt.Sprintf("/kaito-%s:%s", imageName, imageTag)
return imageName, imagePullSecretRefs
}

Expand Down
1 change: 1 addition & 0 deletions pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ type PresetInferenceParam struct {
BaseCommand string
// WorldSize defines the number of processes required for distributed inference.
WorldSize int
Tag string // The model image tag
}
11 changes: 11 additions & 0 deletions presets/models/falcon/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ var (
PresetFalcon7BInstructModel = PresetFalcon7BModel + "-instruct"
PresetFalcon40BInstructModel = PresetFalcon40BModel + "-instruct"

PresetFalconTagMap = map[string]string{
"Falcon7B": "0.0.1",
"Falcon7BInstruct": "0.0.1",
"Falcon40B": "0.0.1",
"Falcon40BInstruct": "0.0.1",
}

baseCommandPresetFalcon = "accelerate launch --use_deepspeed"
falconRunParams = map[string]string{}
)
Expand All @@ -56,6 +63,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam {
ModelRunParams: falconRunParams,
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
Tag: PresetFalconTagMap["Falcon7B"],
}

}
Expand All @@ -79,6 +87,7 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam {
ModelRunParams: falconRunParams,
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
Tag: PresetFalconTagMap["Falcon7BInstruct"],
}

}
Expand All @@ -102,6 +111,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam {
ModelRunParams: falconRunParams,
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
Tag: PresetFalconTagMap["Falcon40B"],
}

}
Expand All @@ -125,6 +135,7 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam {
ModelRunParams: falconRunParams,
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
Tag: PresetFalconTagMap["Falcon40BInstruct"],
}
}

Expand Down
9 changes: 9 additions & 0 deletions presets/models/llama2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ func init() {
}

var (
PresetLlamaTagMap = map[string]string{
"llama-2-7b": "0.0.1",
"llama-2-13b": "0.0.1",
"llama-2-70b": "0.0.1",
}

baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"
llamaRunParams = map[string]string{
"max_seq_len": "512",
Expand All @@ -52,6 +58,7 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(10) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 1,
Tag: PresetLlamaTagMap["llama-2-7b"],
}

}
Expand All @@ -77,6 +84,7 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(20) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 2,
Tag: PresetLlamaTagMap["llama-2-13b"],
}
}
func (*llama2Text13b) SupportDistributedInference() bool {
Expand All @@ -101,6 +109,7 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 8,
Tag: PresetLlamaTagMap["llama-2-70b"],
}
}
func (*llama2Text70b) SupportDistributedInference() bool {
Expand Down
9 changes: 9 additions & 0 deletions presets/models/llama2chat/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ func init() {
}

var (
PresetLlamaTagMap = map[string]string{
"llama-2-7b-chat": "0.0.1",
"llama-2-13b-chat": "0.0.1",
"llama-2-70b-chat": "0.0.1",
}

baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"
llamaRunParams = map[string]string{
"max_seq_len": "512",
Expand All @@ -52,6 +58,7 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(10) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 1,
Tag: PresetLlamaTagMap["llama-2-7b-chat"],
}

}
Expand All @@ -77,6 +84,7 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(20) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 2,
Tag: PresetLlamaTagMap["llama-2-13b-chat"],
}
}
func (*llama2Chat13b) SupportDistributedInference() bool {
Expand All @@ -101,6 +109,7 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetLlama,
WorldSize: 8,
Tag: PresetLlamaTagMap["llama-2-70b-chat"],
}
}
func (*llama2Chat70b) SupportDistributedInference() bool {
Expand Down

0 comments on commit a70138e

Please sign in to comment.