diff --git a/pkg/sku/azure_sku_handler.go b/pkg/sku/azure_sku_handler.go new file mode 100644 index 000000000..04d3fe6b7 --- /dev/null +++ b/pkg/sku/azure_sku_handler.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sku + +var _ CloudSKUHandler = &AzureSKUHandler{} + +type AzureSKUHandler struct { + supportedSKUs map[string]GPUConfig +} + +func NewAzureSKUHandler() *AzureSKUHandler { + return &AzureSKUHandler{ + supportedSKUs: map[string]GPUConfig{ + "Standard_NC6s_v3": {SKU: "Standard_NC6s_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA V100"}, + "Standard_NC12s_v3": {SKU: "Standard_NC12s_v3", GPUCount: 2, GPUMem: 32, GPUModel: "NVIDIA V100"}, + "Standard_NC24s_v3": {SKU: "Standard_NC24s_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"}, + "Standard_NC24rs_v3": {SKU: "Standard_NC24rs_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA V100"}, + "Standard_NC4as_T4_v3": {SKU: "Standard_NC4as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, + "Standard_NC8as_T4_v3": {SKU: "Standard_NC8as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, + "Standard_NC16as_T4_v3": {SKU: "Standard_NC16as_T4_v3", GPUCount: 1, GPUMem: 16, GPUModel: "NVIDIA T4"}, + "Standard_NC64as_T4_v3": {SKU: "Standard_NC64as_T4_v3", GPUCount: 4, GPUMem: 64, GPUModel: "NVIDIA T4"}, + "Standard_NC24ads_A100_v4": {SKU: "Standard_NC24ads_A100_v4", GPUCount: 1, GPUMem: 80, GPUModel: "NVIDIA A100"}, + "Standard_NC48ads_A100_v4": {SKU: "Standard_NC48ads_A100_v4", GPUCount: 2, GPUMem: 160, GPUModel: "NVIDIA A100"}, + "Standard_NC96ads_A100_v4": {SKU: "Standard_NC96ads_A100_v4", GPUCount: 4, GPUMem: 320, GPUModel: "NVIDIA A100"}, + "Standard_ND96asr_A100_v4": {SKU: "Standard_ND96asr_A100_v4", GPUCount: 8, GPUMem: 320, GPUModel: "NVIDIA A100"}, + "Standard_NG32ads_V620_v1": {SKU: "Standard_NG32ads_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"}, + "Standard_NG32adms_V620_v1": {SKU: "Standard_NG32adms_V620_v1", GPUCount: 1, GPUMem: 32, GPUModel: "AMD Radeon PRO V620"}, + "Standard_NV6": {SKU: "Standard_NV6", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"}, + "Standard_NV12": {SKU: "Standard_NV12", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"}, + "Standard_NV24": {SKU: "Standard_NV24", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"}, + "Standard_NV12s_v3": {SKU: "Standard_NV12s_v3", GPUCount: 1, GPUMem: 8, GPUModel: "NVIDIA M60"}, + "Standard_NV24s_v3": {SKU: "Standard_NV24s_v3", GPUCount: 2, GPUMem: 16, GPUModel: "NVIDIA M60"}, + "Standard_NV48s_v3": {SKU: "Standard_NV48s_v3", GPUCount: 4, GPUMem: 32, GPUModel: "NVIDIA M60"}, + "Standard_NV32as_v4": {SKU: "Standard_NV32as_v4", GPUCount: 1, GPUMem: 16, GPUModel: "AMD Radeon Instinct MI25"}, + "Standard_ND96amsr_A100_v4": {SKU: "Standard_ND96amsr_A100_v4", GPUCount: 8, GPUMem: 80, GPUModel: "NVIDIA A100"}, + + // Not supporting partial gpu skus for now + // "Standard_NG8ads_V620_v1": {SKU: "Standard_NG8ads_V620_v1", GPUCount: 1.0 / 4.0, GPUMem: 8, GPUModel: "AMD Radeon PRO V620"}, + // "Standard_NG16ads_V620_v1": {SKU: "Standard_NG16ads_V620_v1", GPUCount: 1.0 / 2.0, GPUMem: 16, GPUModel: "AMD Radeon PRO V620"}, + // "Standard_NV4as_v4": {SKU: "Standard_NV4as_v4", GPUCount: 1.0 / 8.0, GPUMem: 2, GPUModel: "AMD Radeon Instinct MI25"}, + // "Standard_NV8as_v4": {SKU: "Standard_NV8as_v4", GPUCount: 1.0 / 4.0, GPUMem: 4, GPUModel: "AMD Radeon Instinct MI25"}, + // "Standard_NV16as_v4": {SKU: "Standard_NV16as_v4", GPUCount: 1.0 / 2.0, GPUMem: 8, GPUModel: "AMD Radeon Instinct MI25"}, + }, + } +} + +func (a *AzureSKUHandler) GetSupportedSKUs() []string { + return GetMapKeys(a.supportedSKUs) +} + +func (a *AzureSKUHandler) GetGPUConfigs() map[string]GPUConfig { + return a.supportedSKUs +} diff --git a/pkg/sku/cloud_sku_handler_test.go b/pkg/sku/cloud_sku_handler_test.go new file mode 100644 index 000000000..2b016f834 --- /dev/null +++ b/pkg/sku/cloud_sku_handler_test.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sku + +import ( + "testing" +) + +func TestAzureSKUHandler(t *testing.T) { + handler := NewAzureSKUHandler() + + // Test GetSupportedSKUs + skus := handler.GetSupportedSKUs() + if len(skus) == 0 { + t.Errorf("GetSupportedSKUs returned an empty array") + } + + // Test GetGPUConfigs with a SKU that is supported + sku := "Standard_NC6s_v3" + configMap := handler.GetGPUConfigs() + config, exists := configMap[sku] + if !exists { + t.Errorf("Supported SKU missing from GPUConfigs") + } + if config.SKU != sku { + t.Errorf("Incorrect config returned for a supported SKU") + } + + // Test GetGPUConfigs with a SKU that is not supported + sku = "Unsupported_SKU" + config, exists = configMap[sku] + if exists { + t.Errorf("Unsupported SKU found in GPUConfigs") + } +} diff --git a/pkg/sku/sku_utils.go b/pkg/sku/sku_utils.go new file mode 100644 index 000000000..4198e3b49 --- /dev/null +++ b/pkg/sku/sku_utils.go @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sku + +func GetMapKeys(m map[string]GPUConfig) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +}