Skip to content

Commit

Permalink
feat: [SKU modularization] adding azure sku handler (kaito-project#360)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Adds the Azure SKU handler implementation along with its unit test.

**Requirements**

- [x] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:
  • Loading branch information
smritidahal653 authored Apr 26, 2024
1 parent 8b7941a commit 5c3fb02
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
54 changes: 54 additions & 0 deletions pkg/sku/azure_sku_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sku

// Compile-time assertion that *AzureSKUHandler satisfies CloudSKUHandler.
var _ CloudSKUHandler = (*AzureSKUHandler)(nil)

// AzureSKUHandler holds a table of Azure VM SKUs and the GPU configuration
// recorded for each of them.
type AzureSKUHandler struct {
	// supportedSKUs maps a SKU name to its GPUConfig.
	supportedSKUs map[string]GPUConfig
}

// NewAzureSKUHandler returns an AzureSKUHandler pre-populated with the table
// of supported Azure GPU VM SKUs.
//
// Each entry is keyed by the SKU name and repeats that name in GPUConfig.SKU.
// GPUMem is the total GPU memory of the VM, i.e. the per-GPU memory multiplied
// by GPUCount (for example Standard_NC12s_v3: 2 GPUs -> 32). The unit is
// presumably GB; confirm against the GPUConfig definition.
func NewAzureSKUHandler() *AzureSKUHandler {
	return &AzureSKUHandler{
		supportedSKUs: map[string]GPUConfig{
			"Standard_NC6s_v3":          {SKU: "Standard_NC6s_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA V100"},
			"Standard_NC12s_v3":         {SKU: "Standard_NC12s_v3", GPUCount: 2, GPUMem: 32, GPUModel: "NVIDIA V100"},
			"Standard_NC24s_v3":         {SKU: "Standard_NC24s_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"},
			"Standard_NC24rs_v3":        {SKU: "Standard_NC24rs_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"},
			"Standard_NC4as_T4_v3":      {SKU: "Standard_NC4as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"},
			"Standard_NC8as_T4_v3":      {SKU: "Standard_NC8as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"},
			"Standard_NC16as_T4_v3":     {SKU: "Standard_NC16as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"},
			"Standard_NC64as_T4_v3":     {SKU: "Standard_NC64as_T4_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA T4"},
			"Standard_NC24ads_A100_v4":  {SKU: "Standard_NC24ads_A100_v4", GPUCount: 1, GPUMem: 80, GPUModel: "NVIDIA A100"},
			"Standard_NC48ads_A100_v4":  {SKU: "Standard_NC48ads_A100_v4", GPUCount: 2, GPUMem: 160, GPUModel: "NVIDIA A100"},
			"Standard_NC96ads_A100_v4":  {SKU: "Standard_NC96ads_A100_v4", GPUCount: 4, GPUMem: 320, GPUModel: "NVIDIA A100"},
			"Standard_ND96asr_A100_v4":  {SKU: "Standard_ND96asr_A100_v4", GPUCount: 8, GPUMem: 320, GPUModel: "NVIDIA A100"},
			"Standard_NG32ads_V620_v1":  {SKU: "Standard_NG32ads_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"},
			"Standard_NG32adms_V620_v1": {SKU: "Standard_NG32adms_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"},
			"Standard_NV6":              {SKU: "Standard_NV6", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"},
			"Standard_NV12":             {SKU: "Standard_NV12", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"},
			"Standard_NV24":             {SKU: "Standard_NV24", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"},
			"Standard_NV12s_v3":         {SKU: "Standard_NV12s_v3", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"},
			"Standard_NV24s_v3":         {SKU: "Standard_NV24s_v3", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"},
			"Standard_NV48s_v3":         {SKU: "Standard_NV48s_v3", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"},
			"Standard_NV32as_v4":        {SKU: "Standard_NV32as_v4", GPUCount: 1, GPUMem: 16, GPUModel: "AMD Radeon Instinct MI25"},
			// NDm A100 v4 carries 8 x 80 GB A100 GPUs, so the total is 640,
			// consistent with the "total across all GPUs" convention used by
			// every other multi-GPU entry in this table.
			"Standard_ND96amsr_A100_v4": {SKU: "Standard_ND96amsr_A100_v4", GPUCount: 8, GPUMem: 640, GPUModel: "NVIDIA A100"},

			// Not supporting partial gpu skus for now
			// "Standard_NG8ads_V620_v1": {SKU: "Standard_NG8ads_V620_v1", GPUCount: 1.0 / 4.0, GPUMem: 8, GPUModel: "AMD Radeon PRO V620"},
			// "Standard_NG16ads_V620_v1": {SKU: "Standard_NG16ads_V620_v1", GPUCount: 1.0 / 2.0, GPUMem: 16, GPUModel: "AMD Radeon PRO V620"},
			// "Standard_NV4as_v4": {SKU: "Standard_NV4as_v4", GPUCount: 1.0 / 8.0, GPUMem: 2, GPUModel: "AMD Radeon Instinct MI25"},
			// "Standard_NV8as_v4": {SKU: "Standard_NV8as_v4", GPUCount: 1.0 / 4.0, GPUMem: 4, GPUModel: "AMD Radeon Instinct MI25"},
			// "Standard_NV16as_v4": {SKU: "Standard_NV16as_v4", GPUCount: 1.0 / 2.0, GPUMem: 8, GPUModel: "AMD Radeon Instinct MI25"},
		},
	}
}

// GetSupportedSKUs returns the names of every SKU in the handler's table.
// The names are the keys of the underlying map as collected by GetMapKeys,
// so callers should not rely on any particular order.
func (a *AzureSKUHandler) GetSupportedSKUs() []string {
	names := GetMapKeys(a.supportedSKUs)
	return names
}

// GetGPUConfigs returns the handler's SKU-name-to-GPUConfig table.
// The returned map is the handler's own map, not a copy, so writes made by
// the caller are visible to the handler.
func (a *AzureSKUHandler) GetGPUConfigs() map[string]GPUConfig {
	configs := a.supportedSKUs
	return configs
}
36 changes: 36 additions & 0 deletions pkg/sku/cloud_sku_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sku

import (
"testing"
)

// TestAzureSKUHandler checks that the Azure handler reports a non-empty SKU
// list, that every reported SKU has a GPU config, that a known SKU maps to a
// config carrying the same SKU name, and that an unknown SKU is absent.
func TestAzureSKUHandler(t *testing.T) {
	handler := NewAzureSKUHandler()
	configMap := handler.GetGPUConfigs()

	// Test GetSupportedSKUs
	skus := handler.GetSupportedSKUs()
	if len(skus) == 0 {
		t.Errorf("GetSupportedSKUs returned an empty array")
	}
	// Every advertised SKU must be resolvable through GetGPUConfigs.
	for _, name := range skus {
		if _, ok := configMap[name]; !ok {
			t.Errorf("SKU %q returned by GetSupportedSKUs is missing from GPUConfigs", name)
		}
	}

	// Test GetGPUConfigs with a SKU that is supported
	sku := "Standard_NC6s_v3"
	config, exists := configMap[sku]
	if !exists {
		// Stop here: inspecting the zero-value config would only add noise.
		t.Fatalf("Supported SKU missing from GPUConfigs")
	}
	if config.SKU != sku {
		t.Errorf("Incorrect config returned for a supported SKU")
	}

	// Test GetGPUConfigs with a SKU that is not supported.
	// Only presence matters, so the looked-up value is discarded.
	if _, exists := configMap["Unsupported_SKU"]; exists {
		t.Errorf("Unsupported SKU found in GPUConfigs")
	}
}
12 changes: 12 additions & 0 deletions pkg/sku/sku_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sku

// GetMapKeys returns the keys of m as a newly allocated slice.
// The result always has len(m) elements (empty, non-nil for an empty or nil
// map) and follows map iteration order, which Go does not define.
func GetMapKeys(m map[string]GPUConfig) []string {
	names := make([]string, len(m))
	next := 0
	for name := range m {
		names[next] = name
		next++
	}
	return names
}

0 comments on commit 5c3fb02

Please sign in to comment.