diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index de6963247..fa1774656 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -285,9 +285,14 @@ func filterWorkersByResources(workers []*types.Worker, request *types.ContainerR filteredWorkers := []*types.Worker{} gpuRequestsMap := map[string]int{} requiresGPU := request.RequiresGPU() + requiresAnyGPU := false for index, gpu := range request.GpuRequest { gpuRequestsMap[gpu] = index + if gpu == types.ANY_GPU.String() { + requiresAnyGPU = true + break + } } for _, worker := range workers { @@ -309,7 +314,7 @@ func filterWorkersByResources(workers []*types.Worker, request *types.ContainerR continue } - if requiresGPU { + if requiresGPU && !requiresAnyGPU { // Validate GPU resource availability priorityModifier, validGpu := gpuRequestsMap[worker.Gpu] if !validGpu || worker.FreeGpuCount < request.GpuCount { diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go index 87343206e..a47dd2509 100644 --- a/pkg/scheduler/scheduler_test.go +++ b/pkg/scheduler/scheduler_test.go @@ -375,6 +375,7 @@ func TestSelectGPUWorker(t *testing.T) { assert.NotNil(t, wb) newWorker := &types.Worker{ + Id: uuid.New().String(), Status: types.WorkerStatusPending, FreeCpu: 1000, FreeMemory: 1000, @@ -402,6 +403,12 @@ func TestSelectGPUWorker(t *testing.T) { GpuRequest: []string{"T4"}, } + thirdRequest := &types.ContainerRequest{ + Cpu: 1000, + Memory: 1000, + GpuRequest: []string{"any"}, + } + // CPU request should not be able to select a GPU worker _, err = wb.selectWorker(cpuRequest) assert.Error(t, err) @@ -427,6 +434,29 @@ func TestSelectGPUWorker(t *testing.T) { _, ok = err.(*types.ErrNoSuitableWorkerFound) assert.True(t, ok) + + newWorkerAnyGpu := &types.Worker{ + Id: uuid.New().String(), + Status: types.WorkerStatusPending, + FreeCpu: 1000, + FreeMemory: 1000, + Gpu: "T4", + } + + err = wb.workerRepo.AddWorker(newWorkerAnyGpu) + assert.Nil(t, err) + + // Select a worker for the request + worker, err = wb.selectWorker(thirdRequest) + assert.Nil(t, err) + + // Check if the worker selected has the "T4" GPU + assert.Equal(t, newWorkerAnyGpu.Gpu, worker.Gpu) + assert.Equal(t, newWorkerAnyGpu.Id, worker.Id) + + // Actually schedule the request + err = wb.scheduleRequest(worker, thirdRequest) + assert.Nil(t, err) } func TestSelectCPUWorker(t *testing.T) { diff --git a/pkg/types/gpu.go b/pkg/types/gpu.go index 646576cd3..7171ee01d 100644 --- a/pkg/types/gpu.go +++ b/pkg/types/gpu.go @@ -16,7 +16,8 @@ const ( GPU_A6000 GPUType = "A6000" GPU_RTX4090 GPUType = "RTX4090" - NO_GPU GPUType = "NO_GPU" + NO_GPU GPUType = "NO_GPU" + ANY_GPU GPUType = "any" ) func AllGPUTypes() []GPUType {