forked from kaito-project/kaito
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: [SKU modularization] adding azure sku handler (kaito-project#360)
**Reason for Change**: Adding implementation for Azure sku handler along with unit test **Requirements** - [x] added unit tests and e2e tests (if applicable). **Issue Fixed**: <!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next line. --> **Notes for Reviewers**:
- Loading branch information
1 parent
8b7941a
commit 5c3fb02
Showing
3 changed files
with
102 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
package sku | ||
|
||
var _ CloudSKUHandler = &AzureSKUHandler{} | ||
|
||
type AzureSKUHandler struct { | ||
supportedSKUs map[string]GPUConfig | ||
} | ||
|
||
func NewAzureSKUHandler() *AzureSKUHandler { | ||
return &AzureSKUHandler{ | ||
supportedSKUs: map[string]GPUConfig{ | ||
"Standard_NC6s_v3": {SKU: "Standard_NC6s_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA V100"}, | ||
"Standard_NC12s_v3": {SKU: "Standard_NC12s_v3", GPUCount: 2, GPUMem: 32, GPUModel: "NVIDIA V100"}, | ||
"Standard_NC24s_v3": {SKU: "Standard_NC24s_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"}, | ||
"Standard_NC24rs_v3": {SKU: "Standard_NC24rs_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"}, | ||
"Standard_NC4as_T4_v3": {SKU: "Standard_NC4as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, | ||
"Standard_NC8as_T4_v3": {SKU: "Standard_NC8as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, | ||
"Standard_NC16as_T4_v3": {SKU: "Standard_NC16as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, | ||
"Standard_NC64as_T4_v3": {SKU: "Standard_NC64as_T4_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA T4"}, | ||
"Standard_NC24ads_A100_v4": {SKU: "Standard_NC24ads_A100_v4", GPUCount: 1, GPUMem: 80, GPUModel: "NVIDIA A100"}, | ||
"Standard_NC48ads_A100_v4": {SKU: "Standard_NC48ads_A100_v4", GPUCount: 2, GPUMem: 160, GPUModel: "NVIDIA A100"}, | ||
"Standard_NC96ads_A100_v4": {SKU: "Standard_NC96ads_A100_v4", GPUCount: 4, GPUMem: 320, GPUModel: "NVIDIA A100"}, | ||
"Standard_ND96asr_A100_v4": {SKU: "Standard_ND96asr_A100_v4", GPUCount: 8, GPUMem: 320, GPUModel: "NVIDIA A100"}, | ||
"Standard_NG32ads_V620_v1": {SKU: "Standard_NG32ads_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"}, | ||
"Standard_NG32adms_V620_v1": {SKU: "Standard_NG32adms_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"}, | ||
"Standard_NV6": {SKU: "Standard_NV6", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV12": {SKU: "Standard_NV12", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV24": {SKU: "Standard_NV24", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV12s_v3": {SKU: "Standard_NV12s_v3", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV24s_v3": {SKU: "Standard_NV24s_v3", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV48s_v3": {SKU: "Standard_NV48s_v3", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"}, | ||
"Standard_NV32as_v4": {SKU: "Standard_NV32as_v4", GPUCount: 1, GPUMem: 16, GPUModel: "AMD Radeon Instinct MI25"}, | ||
"Standard_ND96amsr_A100_v4": {SKU: "Standard_ND96amsr_A100_v4", GPUCount: 8, GPUMem: 80, GPUModel: "NVIDIA A100"}, | ||
|
||
// Not supporting partial gpu skus for now | ||
// "Standard_NG8ads_V620_v1": {SKU: "Standard_NG8ads_V620_v1", GPUCount: 1.0 / 4.0, GPUMem: 8, GPUModel: "AMD Radeon PRO V620"}, | ||
// "Standard_NG16ads_V620_v1": {SKU: "Standard_NG16ads_V620_v1", GPUCount: 1.0 / 2.0, GPUMem: 16, GPUModel: "AMD Radeon PRO V620"}, | ||
// "Standard_NV4as_v4": {SKU: "Standard_NV4as_v4", GPUCount: 1.0 / 8.0, GPUMem: 2, GPUModel: "AMD Radeon Instinct MI25"}, | ||
// "Standard_NV8as_v4": {SKU: "Standard_NV8as_v4", GPUCount: 1.0 / 4.0, GPUMem: 4, GPUModel: "AMD Radeon Instinct MI25"}, | ||
// "Standard_NV16as_v4": {SKU: "Standard_NV16as_v4", GPUCount: 1.0 / 2.0, GPUMem: 8, GPUModel: "AMD Radeon Instinct MI25"}, | ||
}, | ||
} | ||
} | ||
|
||
func (a *AzureSKUHandler) GetSupportedSKUs() []string { | ||
return GetMapKeys(a.supportedSKUs) | ||
} | ||
|
||
func (a *AzureSKUHandler) GetGPUConfigs() map[string]GPUConfig { | ||
return a.supportedSKUs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
package sku | ||
|
||
import ( | ||
"testing" | ||
) | ||
|
||
func TestAzureSKUHandler(t *testing.T) { | ||
handler := NewAzureSKUHandler() | ||
|
||
// Test GetSupportedSKUs | ||
skus := handler.GetSupportedSKUs() | ||
if len(skus) == 0 { | ||
t.Errorf("GetSupportedSKUs returned an empty array") | ||
} | ||
|
||
// Test GetGPUConfigs with a SKU that is supported | ||
sku := "Standard_NC6s_v3" | ||
configMap := handler.GetGPUConfigs() | ||
config, exists := configMap[sku] | ||
if !exists { | ||
t.Errorf("Supported SKU missing from GPUConfigs") | ||
} | ||
if config.SKU != sku { | ||
t.Errorf("Incorrect config returned for a supported SKU") | ||
} | ||
|
||
// Test GetGPUConfigs with a SKU that is not supported | ||
sku = "Unsupported_SKU" | ||
config, exists = configMap[sku] | ||
if exists { | ||
t.Errorf("Unsupported SKU found in GPUConfigs") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
package sku | ||
|
||
func GetMapKeys(m map[string]GPUConfig) []string { | ||
keys := make([]string, 0, len(m)) | ||
for k := range m { | ||
keys = append(keys, k) | ||
} | ||
return keys | ||
} |