Skip to content

Commit

Permalink
feat: Add URL as data source (kaito-project#365)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Add URL as a data source
**Requirements**
- [x] added unit tests and e2e tests (if applicable).
  • Loading branch information
ishaansehgal99 authored Apr 30, 2024
1 parent 962e3d6 commit a9af171
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
39 changes: 37 additions & 2 deletions pkg/tuning/preset-tuning.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package tuning
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/api/resource"
"os"
"strings"

"k8s.io/apimachinery/pkg/api/resource"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
Expand Down Expand Up @@ -158,8 +160,9 @@ func prepareDataSource(ctx context.Context, workspaceObj *kaitov1alpha1.Workspac
case workspaceObj.Tuning.Input.Image != "":
initContainers, volumes, volumeMounts = handleImageDataSource(ctx, workspaceObj)
_, imagePullSecrets = GetDataSrcImageInfo(ctx, workspaceObj)
case len(workspaceObj.Tuning.Input.URLs) > 0:
initContainers, volumes, volumeMounts = handleURLDataSource(ctx, workspaceObj)
// TODO: Future PR include
// case len(workspaceObj.Tuning.Input.URLs) > 0:
// case workspaceObj.Tuning.Input.Volume != nil:
}
return initContainers, imagePullSecrets, volumes, volumeMounts, nil
Expand All @@ -185,6 +188,38 @@ func handleImageDataSource(ctx context.Context, workspaceObj *kaitov1alpha1.Work
return initContainers, volumes, volumeMounts
}

func handleURLDataSource(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) ([]corev1.Container, []corev1.Volume, []corev1.VolumeMount) {
var initContainers []corev1.Container
initContainers = append(initContainers, corev1.Container{
Name: "data-downloader",
Image: "curlimages/curl",
Command: []string{"sh", "-c", `
for url in $DATA_URLS; do
filename=$(basename "$url" | sed 's/[?=&]/_/g')
curl -sSL $url -o $DATA_VOLUME_PATH/$filename
done
`},
VolumeMounts: []corev1.VolumeMount{
{
Name: "data-volume",
MountPath: utils.DefaultDataVolumePath,
},
},
Env: []corev1.EnvVar{
{
Name: "DATA_URLS",
Value: strings.Join(workspaceObj.Tuning.Input.URLs, " "),
},
{
Name: "DATA_VOLUME_PATH",
Value: utils.DefaultDataVolumePath,
},
},
})
volumes, volumeMounts := utils.ConfigDataVolume("")
return initContainers, volumes, volumeMounts
}

func prepareModelRunParameters(ctx context.Context, tuningObj *model.PresetParam) (string, error) {
modelCommand := utils.BuildCmdStr(TuningFile, tuningObj.ModelRunParams)
return modelCommand, nil
Expand Down
55 changes: 53 additions & 2 deletions pkg/tuning/preset-tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ package tuning

import (
"context"
"os"
"strings"
"testing"

"github.com/azure/kaito/pkg/utils"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/utils/test"
Expand All @@ -15,8 +21,6 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/pointer"
"os"
"testing"
)

// Mocking the SupportedGPUConfigs to be used in test scenarios.
Expand All @@ -26,6 +30,10 @@ var mockSupportedGPUConfigs = map[string]kaitov1alpha1.GPUConfig{
"sku3": {GPUCount: 0},
}

func normalize(s string) string {
return strings.Join(strings.Fields(s), " ")
}

func TestGetInstanceGPUCount(t *testing.T) {
kaitov1alpha1.SupportedGPUConfigs = mockSupportedGPUConfigs
testcases := map[string]struct {
Expand Down Expand Up @@ -254,6 +262,49 @@ func TestHandleImageDataSource(t *testing.T) {
}
}

func TestHandleURLDataSource(t *testing.T) {
testcases := map[string]struct {
workspaceObj *kaitov1alpha1.Workspace
expectedInitContainerName string
expectedImage string
expectedCommands string
expectedVolumeName string
expectedVolumeMountPath string
}{
"Handle URL Data Source": {
workspaceObj: &kaitov1alpha1.Workspace{
Tuning: &kaitov1alpha1.TuningSpec{
Input: &kaitov1alpha1.DataSource{
URLs: []string{"http://example.com/data1.zip", "http://example.com/data2.zip"},
},
},
},
expectedInitContainerName: "data-downloader",
expectedImage: "curlimages/curl",
expectedCommands: "filename=$(basename \"$url\" | sed 's/[?=&]/_/g')\ncurl -sSL $url -o $DATA_VOLUME_PATH/$filename",
expectedVolumeName: "data-volume",
expectedVolumeMountPath: utils.DefaultDataVolumePath,
},
}

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
initContainers, volumes, volumeMounts := handleURLDataSource(context.Background(), tc.workspaceObj)

assert.Len(t, initContainers, 1)
assert.Equal(t, tc.expectedInitContainerName, initContainers[0].Name)
assert.Equal(t, tc.expectedImage, initContainers[0].Image)
assert.Contains(t, normalize(initContainers[0].Command[2]), normalize(tc.expectedCommands))

assert.Len(t, volumes, 1)
assert.Equal(t, tc.expectedVolumeName, volumes[0].Name)

assert.Len(t, volumeMounts, 1)
assert.Equal(t, tc.expectedVolumeMountPath, volumeMounts[0].MountPath)
})
}
}

func TestPrepareTuningParameters(t *testing.T) {
ctx := context.TODO()

Expand Down
2 changes: 1 addition & 1 deletion pkg/utils/common-preset.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

const (
DefaultVolumeMountPath = "/dev/shm"
DefaultConfigMapMountPath = "/config"
DefaultConfigMapMountPath = "/mnt/config"
DefaultDataVolumePath = "/mnt/data"
)

Expand Down

0 comments on commit a9af171

Please sign in to comment.