Skip to content

Commit

Permalink
Merge pull request #362 from elezar/add-feature-flags
Browse files Browse the repository at this point in the history
Add support for feature flags
  • Loading branch information
elezar authored Apr 3, 2024
2 parents c374520 + 09341a0 commit 413da20
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 15 deletions.
17 changes: 13 additions & 4 deletions cmd/nvidia-ctk/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func run(c *cli.Context, opts *options) error {
if err != nil {
return fmt.Errorf("invalid --set option %v: %w", set, err)
}
cfgToml.Set(key, value)
if value == nil {
_ = cfgToml.Delete(key)
} else {
cfgToml.Set(key, value)
}
}

if err := opts.EnsureOutputFolder(); err != nil {
Expand Down Expand Up @@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {

kind := field.Kind()
if len(setParts) != 2 {
if kind == reflect.Bool {
if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) {
return key, true, nil
}
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
}

value := setParts[1]
if kind == reflect.Pointer && value != "nil" {
kind = field.Elem().Kind()
}
switch kind {
case reflect.Pointer:
return key, nil, nil
case reflect.Bool:
b, err := strconv.ParseBool(value)
if err != nil {
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
}
return key, b, err
return key, b, nil
case reflect.String:
return key, value, nil
case reflect.Slice:
Expand Down Expand Up @@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro
if !ok {
continue
}
if v != tomlField {
if strings.SplitN(v, ",", 2)[0] != tomlField {
continue
}
if len(paths) == 1 {
Expand Down
3 changes: 3 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type Config struct {
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`

// Features allows for finer control over optional features.
Features features `toml:"features,omitempty"`
}

// GetConfigFilePath returns the path to the config file for the configured system
Expand Down
85 changes: 85 additions & 0 deletions internal/config/features.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package config

type featureName string

const (
FeatureGDS = featureName("gds")
FeatureMOFED = featureName("mofed")
FeatureNVSWITCH = featureName("nvswitch")
FeatureGDRCopy = featureName("gdrcopy")
)

// features specifies a set of named features.
type features struct {
GDS *feature `toml:"gds,omitempty"`
MOFED *feature `toml:"mofed,omitempty"`
NVSWITCH *feature `toml:"nvswitch,omitempty"`
GDRCopy *feature `toml:"gdrcopy,omitempty"`
}

type feature bool

// IsEnabled checks whether a specified named feature is enabled.
// An optional list of environments to check for feature-specific environment
// variables can also be supplied.
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
featureEnvvars := map[featureName]string{
FeatureGDS: "NVIDIA_GDS",
FeatureMOFED: "NVIDIA_MOFED",
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
FeatureGDRCopy: "NVIDIA_GDRCOPY",
}

envvar := featureEnvvars[n]
switch n {
case FeatureGDS:
return fs.GDS.isEnabled(envvar, in...)
case FeatureMOFED:
return fs.MOFED.isEnabled(envvar, in...)
case FeatureNVSWITCH:
return fs.NVSWITCH.isEnabled(envvar, in...)
case FeatureGDRCopy:
return fs.GDRCopy.isEnabled(envvar, in...)
default:
return false
}
}

// isEnabled checks whether a feature is enabled.
// If the enabled value is explicitly set, this is returned, otherwise the
// associated envvar is checked in the specified getenver for the string "enabled"
// A CUDA container / image can be passed here.
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
if f != nil {
return bool(*f)
}
if envvar == "" {
return false
}
for _, in := range ins {
if in.Getenv(envvar) == "enabled" {
return true
}
}
return false
}

type getenver interface {
Getenv(string) string
}
15 changes: 4 additions & 11 deletions internal/modifier/gated.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

const (
nvidiaGDSEnvvar = "NVIDIA_GDS"
nvidiaMOFEDEnvvar = "NVIDIA_MOFED"
nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH"
nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY"
)

// NewFeatureGatedModifier creates the modifiers for optional features.
// These include:
//
Expand All @@ -53,31 +46,31 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
devRoot := cfg.NVIDIAContainerCLIConfig.Root

if image.Getenv(nvidiaGDSEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureGDS, image) {
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
}
discoverers = append(discoverers, d)
}

if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" {
if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)
Expand Down

0 comments on commit 413da20

Please sign in to comment.