diff --git a/pkg/cdi/cache.go b/pkg/cdi/cache.go index c2f7fe3..f1e4fd0 100644 --- a/pkg/cdi/cache.go +++ b/pkg/cdi/cache.go @@ -28,6 +28,7 @@ import ( "github.com/fsnotify/fsnotify" oci "github.com/opencontainers/runtime-spec/specs-go" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -280,30 +281,31 @@ func (c *Cache) highestPrioritySpecDir() (string, int) { // priority Spec directory. If name has a "json" or "yaml" extension it // choses the encoding. Otherwise the default YAML encoding is used. func (c *Cache) WriteSpec(raw *cdi.Spec, name string) error { - var ( - specDir string - path string - prio int - spec *Spec - err error - ) - - specDir, prio = c.highestPrioritySpecDir() + specDir, _ := c.highestPrioritySpecDir() if specDir == "" { return errors.New("no Spec directories to write to") } - path = filepath.Join(specDir, name) - if ext := filepath.Ext(path); ext != ".json" && ext != ".yaml" { - path += defaultSpecExt + // Ideally we would like to pass the configured spec validator to the + // producer, but we would need to handle the synchronisation. + // Instead we call `validateSpec` here which is a no-op if no validator is + // configured. + if err := validateSpec(raw); err != nil { + return err } - spec, err = newSpec(raw, path, prio) + p, err := producer.New( + producer.WithOverwrite(true), + ) if err != nil { return err } - return spec.write(true) + path := filepath.Join(specDir, name) + if _, err := p.SaveSpec(raw, path); err != nil { + return err + } + return nil } // RemoveSpec removes a Spec with the given name from the highest diff --git a/pkg/cdi/container-edits.go b/pkg/cdi/container-edits.go index a7ac70d..c5da38c 100644 --- a/pkg/cdi/container-edits.go +++ b/pkg/cdi/container-edits.go @@ -26,6 +26,7 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" ocigen "github.com/opencontainers/runtime-tools/generate" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -167,32 +168,7 @@ func (e *ContainerEdits) Validate() error { if e == nil || e.ContainerEdits == nil { return nil } - - if err := ValidateEnv(e.Env); err != nil { - return fmt.Errorf("invalid container edits: %w", err) - } - for _, d := range e.DeviceNodes { - if err := (&DeviceNode{d}).Validate(); err != nil { - return err - } - } - for _, h := range e.Hooks { - if err := (&Hook{h}).Validate(); err != nil { - return err - } - } - for _, m := range e.Mounts { - if err := (&Mount{m}).Validate(); err != nil { - return err - } - } - if e.IntelRdt != nil { - if err := (&IntelRdt{e.IntelRdt}).Validate(); err != nil { - return err - } - } - - return nil + return validator.Default.ValidateAny(e.ContainerEdits) } // Append other edits into this one. If called with a nil receiver, @@ -220,43 +196,6 @@ func (e *ContainerEdits) Append(o *ContainerEdits) *ContainerEdits { return e } -// isEmpty returns true if these edits are empty. This is valid in a -// global Spec context but invalid in a Device context. -func (e *ContainerEdits) isEmpty() bool { - if e == nil { - return false - } - if len(e.Env) > 0 { - return false - } - if len(e.DeviceNodes) > 0 { - return false - } - if len(e.Hooks) > 0 { - return false - } - if len(e.Mounts) > 0 { - return false - } - if len(e.AdditionalGIDs) > 0 { - return false - } - if e.IntelRdt != nil { - return false - } - return true -} - -// ValidateEnv validates the given environment variables. -func ValidateEnv(env []string) error { - for _, v := range env { - if strings.IndexByte(v, byte('=')) <= 0 { - return fmt.Errorf("invalid environment variable %q", v) - } - } - return nil -} - // DeviceNode is a CDI Spec DeviceNode wrapper, used for validating DeviceNodes. type DeviceNode struct { *cdi.DeviceNode @@ -264,27 +203,7 @@ type DeviceNode struct { // Validate a CDI Spec DeviceNode. func (d *DeviceNode) Validate() error { - validTypes := map[string]struct{}{ - "": {}, - "b": {}, - "c": {}, - "u": {}, - "p": {}, - } - - if d.Path == "" { - return errors.New("invalid (empty) device path") - } - if _, ok := validTypes[d.Type]; !ok { - return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) - } - for _, bit := range d.Permissions { - if bit != 'r' && bit != 'w' && bit != 'm' { - return fmt.Errorf("device %q: invalid permissions %q", - d.Path, d.Permissions) - } - } - return nil + return validator.Default.ValidateAny(d.DeviceNode) } // Hook is a CDI Spec Hook wrapper, used for validating hooks. @@ -294,16 +213,7 @@ type Hook struct { // Validate a hook. func (h *Hook) Validate() error { - if _, ok := validHookNames[h.HookName]; !ok { - return fmt.Errorf("invalid hook name %q", h.HookName) - } - if h.Path == "" { - return fmt.Errorf("invalid hook %q with empty path", h.HookName) - } - if err := ValidateEnv(h.Env); err != nil { - return fmt.Errorf("invalid hook %q: %w", h.HookName, err) - } - return nil + return validator.Default.ValidateAny(h.Hook) } // Mount is a CDI Mount wrapper, used for validating mounts. @@ -313,13 +223,7 @@ type Mount struct { // Validate a mount. func (m *Mount) Validate() error { - if m.HostPath == "" { - return errors.New("invalid mount, empty host path") - } - if m.ContainerPath == "" { - return errors.New("invalid mount, empty container path") - } - return nil + return validator.Default.ValidateAny(m.Mount) } // IntelRdt is a CDI IntelRdt wrapper. @@ -337,11 +241,7 @@ func ValidateIntelRdt(i *cdi.IntelRdt) error { // Validate validates the IntelRdt configuration. func (i *IntelRdt) Validate() error { - // ClosID must be a valid Linux filename - if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { - return errors.New("invalid ClosID") - } - return nil + return validator.Default.ValidateAny(i.IntelRdt) } // Ensure OCI Spec hooks are not nil so we can add hooks. diff --git a/pkg/cdi/device.go b/pkg/cdi/device.go index 2e5fa57..9ac050f 100644 --- a/pkg/cdi/device.go +++ b/pkg/cdi/device.go @@ -17,10 +17,8 @@ package cdi import ( - "fmt" - oci "github.com/opencontainers/runtime-spec/specs-go" - "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -67,22 +65,5 @@ func (d *Device) edits() *ContainerEdits { // Validate the device. func (d *Device) validate() error { - if err := parser.ValidateDeviceName(d.Name); err != nil { - return err - } - name := d.Name - if d.spec != nil { - name = d.GetQualifiedName() - } - if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { - return err - } - edits := d.edits() - if edits.isEmpty() { - return fmt.Errorf("invalid device, empty device edits") - } - if err := edits.Validate(); err != nil { - return fmt.Errorf("invalid device %q: %w", d.Name, err) - } - return nil + return validator.Default.ValidateAny(d.Device) } diff --git a/pkg/cdi/producer/api.go b/pkg/cdi/producer/api.go index 8fbe4b1..037eb8b 100644 --- a/pkg/cdi/producer/api.go +++ b/pkg/cdi/producer/api.go @@ -16,6 +16,8 @@ package producer +import cdi "tags.cncf.io/container-device-interface/specs-go" + type specFormat string const ( @@ -27,3 +29,8 @@ const ( // SpecFormatYAML defines a CDI spec formatted as YAML. SpecFormatYAML = specFormat(".yaml") ) + +// a specValidator is used to validate a CDI spec. +type specValidator interface { + Validate(*cdi.Spec) error +} diff --git a/pkg/cdi/producer/options.go b/pkg/cdi/producer/options.go index 75d1746..014562f 100644 --- a/pkg/cdi/producer/options.go +++ b/pkg/cdi/producer/options.go @@ -16,7 +16,11 @@ package producer -import "fmt" +import ( + "fmt" + + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" +) // An Option defines a functional option for constructing a producer. type Option func(*Producer) error @@ -34,6 +38,17 @@ func WithSpecFormat(format specFormat) Option { } } +// WithSpecValidator sets a validator to be used when writing an output spec. +func WithSpecValidator(v specValidator) Option { + return func(p *Producer) error { + if v == nil { + v = validator.Disabled + } + p.validator = v + return nil + } +} + // WithOverwrite specifies whether a producer should overwrite a CDI spec when // saving to file. func WithOverwrite(overwrite bool) Option { diff --git a/pkg/cdi/producer/producer.go b/pkg/cdi/producer/producer.go index dcecd8f..63eeab9 100644 --- a/pkg/cdi/producer/producer.go +++ b/pkg/cdi/producer/producer.go @@ -17,8 +17,10 @@ package producer import ( + "fmt" "path/filepath" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -26,12 +28,14 @@ import ( type Producer struct { format specFormat failIfExists bool + validator specValidator } // New creates a new producer with the supplied options. func New(opts ...Option) (*Producer, error) { p := &Producer{ - format: DefaultSpecFormat, + format: DefaultSpecFormat, + validator: validator.Default, } for _, opt := range opts { err := opt(p) @@ -47,8 +51,11 @@ func New(opts ...Option) (*Producer, error) { // extension takes precedence over the format with which the Producer was // configured. func (p *Producer) SaveSpec(s *cdi.Spec, filename string) (string, error) { - filename = p.normalizeFilename(filename) + if err := p.Validate(s); err != nil { + return "", fmt.Errorf("spec validation failed: %w", err) + } + filename = p.normalizeFilename(filename) sp := spec{ Spec: s, format: p.specFormatFromFilename(filename), @@ -61,6 +68,15 @@ func (p *Producer) SaveSpec(s *cdi.Spec, filename string) (string, error) { return filename, nil } +// Validate performs a validation on a CDI spec using the configured validator. +// If no validator is configured, the spec is considered unconditionaly valid. +func (p *Producer) Validate(s *cdi.Spec) error { + if p == nil || p.validator == nil { + return nil + } + return p.validator.Validate(s) +} + // specFormatFromFilename determines the CDI spec format for the given filename. func (p *Producer) specFormatFromFilename(filename string) specFormat { switch filepath.Ext(filename) { diff --git a/pkg/cdi/producer/validator/api.go b/pkg/cdi/producer/validator/api.go new file mode 100644 index 0000000..7650dff --- /dev/null +++ b/pkg/cdi/producer/validator/api.go @@ -0,0 +1,23 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +// Validators as constants. +const ( + Default = defaultValidator("default") + Disabled = disabledValidator("disabled") +) diff --git a/pkg/cdi/producer/validator/validator-default.go b/pkg/cdi/producer/validator/validator-default.go new file mode 100644 index 0000000..a4e2187 --- /dev/null +++ b/pkg/cdi/producer/validator/validator-default.go @@ -0,0 +1,245 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +import ( + "errors" + "fmt" + "strings" + + "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/parser" + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +type defaultValidator string + +// ValidateAny implements a generic validation handler for the defaultValidator. +func (v defaultValidator) ValidateAny(o interface{}) error { + switch o := o.(type) { + case *cdi.ContainerEdits: + return v.validateEdits(o) + case *cdi.Device: + return v.validateDevice("", "", o) + case *cdi.DeviceNode: + return v.validateDeviceNode(o) + case *cdi.Hook: + return v.validateHook(o) + case *cdi.IntelRdt: + return v.validateIntelRdt(o) + case *cdi.Mount: + return v.validateMount(o) + case *cdi.Spec: + return v.Validate(o) + default: + return fmt.Errorf("unsupported validation type: %T", o) + } +} + +// Validate performs a default validation on a CDI spec. +func (v defaultValidator) Validate(s *cdi.Spec) error { + if err := cdi.ValidateVersion(s); err != nil { + return err + } + vendor, class := parser.ParseQualifier(s.Kind) + if err := parser.ValidateVendorName(vendor); err != nil { + return err + } + if err := parser.ValidateClassName(class); err != nil { + return err + } + if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { + return err + } + if err := v.validateEdits(&s.ContainerEdits); err != nil { + return err + } + + seen := make(map[string]bool) + for _, d := range s.Devices { + if seen[d.Name] { + return fmt.Errorf("invalid spec, multiple device %q", d.Name) + } + seen[d.Name] = true + if err := v.validateDevice(vendor, class, &d); err != nil { + return fmt.Errorf("invalid device %q: %w", d.Name, err) + } + } + return nil +} + +func (v defaultValidator) validateDevice(vendor string, class string, d *cdi.Device) error { + if err := parser.ValidateDeviceName(d.Name); err != nil { + return err + } + + name := parser.QualifiedName(vendor, class, d.Name) + if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { + return err + } + + if err := v.assertNonEmptyEdits(&d.ContainerEdits); err != nil { + return err + } + if err := v.validateEdits(&d.ContainerEdits); err != nil { + return err + } + return nil +} + +func (v defaultValidator) assertNonEmptyEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if len(e.Env) > 0 { + return nil + } + if len(e.DeviceNodes) > 0 { + return nil + } + if len(e.Hooks) > 0 { + return nil + } + if len(e.Mounts) > 0 { + return nil + } + if len(e.AdditionalGIDs) > 0 { + return nil + } + if e.IntelRdt != nil { + return nil + } + return errors.New("empty container edits") +} + +func (v defaultValidator) validateEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if err := v.validateEnv(e.Env); err != nil { + return fmt.Errorf("invalid container edits: %w", err) + } + for _, d := range e.DeviceNodes { + if err := v.validateDeviceNode(d); err != nil { + return err + } + } + for _, h := range e.Hooks { + if err := v.validateHook(h); err != nil { + return err + } + } + for _, m := range e.Mounts { + if err := v.validateMount(m); err != nil { + return err + } + } + if err := v.validateIntelRdt(e.IntelRdt); err != nil { + return err + } + return nil +} + +func (v defaultValidator) validateEnv(env []string) error { + for _, v := range env { + if strings.IndexByte(v, byte('=')) <= 0 { + return fmt.Errorf("invalid environment variable %q", v) + } + } + return nil +} + +func (v defaultValidator) validateDeviceNode(d *cdi.DeviceNode) error { + validTypes := map[string]struct{}{ + "": {}, + "b": {}, + "c": {}, + "u": {}, + "p": {}, + } + + if d.Path == "" { + return errors.New("invalid (empty) device path") + } + if _, ok := validTypes[d.Type]; !ok { + return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) + } + for _, bit := range d.Permissions { + if bit != 'r' && bit != 'w' && bit != 'm' { + return fmt.Errorf("device %q: invalid permissions %q", + d.Path, d.Permissions) + } + } + return nil +} + +func (v defaultValidator) validateHook(h *cdi.Hook) error { + const ( + // PrestartHook is the name of the OCI "prestart" hook. + PrestartHook = "prestart" + // CreateRuntimeHook is the name of the OCI "createRuntime" hook. + CreateRuntimeHook = "createRuntime" + // CreateContainerHook is the name of the OCI "createContainer" hook. + CreateContainerHook = "createContainer" + // StartContainerHook is the name of the OCI "startContainer" hook. + StartContainerHook = "startContainer" + // PoststartHook is the name of the OCI "poststart" hook. + PoststartHook = "poststart" + // PoststopHook is the name of the OCI "poststop" hook. + PoststopHook = "poststop" + ) + validHookNames := map[string]struct{}{ + PrestartHook: {}, + CreateRuntimeHook: {}, + CreateContainerHook: {}, + StartContainerHook: {}, + PoststartHook: {}, + PoststopHook: {}, + } + + if _, ok := validHookNames[h.HookName]; !ok { + return fmt.Errorf("invalid hook name %q", h.HookName) + } + if h.Path == "" { + return fmt.Errorf("invalid hook %q with empty path", h.HookName) + } + if err := v.validateEnv(h.Env); err != nil { + return fmt.Errorf("invalid hook %q: %w", h.HookName, err) + } + return nil +} + +func (v defaultValidator) validateMount(m *cdi.Mount) error { + if m.HostPath == "" { + return errors.New("invalid mount, empty host path") + } + if m.ContainerPath == "" { + return errors.New("invalid mount, empty container path") + } + return nil +} + +func (v defaultValidator) validateIntelRdt(i *cdi.IntelRdt) error { + if i == nil { + return nil + } + // ClosID must be a valid Linux filename + if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { + return errors.New("invalid ClosID") + } + return nil +} diff --git a/pkg/cdi/producer/validator/validator-disabled.go b/pkg/cdi/producer/validator/validator-disabled.go new file mode 100644 index 0000000..9323dc7 --- /dev/null +++ b/pkg/cdi/producer/validator/validator-disabled.go @@ -0,0 +1,29 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +import ( + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +// A disabledValidator performs no validation. +type disabledValidator string + +// Validate always passes for a disabledValidator. +func (v disabledValidator) Validate(*cdi.Spec) error { + return nil +} diff --git a/pkg/cdi/spec.go b/pkg/cdi/spec.go index d617046..0846be1 100644 --- a/pkg/cdi/spec.go +++ b/pkg/cdi/spec.go @@ -26,8 +26,8 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" "sigs.k8s.io/yaml" - "tags.cncf.io/container-device-interface/internal/validation" "tags.cncf.io/container-device-interface/pkg/cdi/producer" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -176,22 +176,12 @@ func MinimumRequiredVersion(spec *cdi.Spec) (string, error) { // Validate the Spec. func (s *Spec) validate() (map[string]*Device, error) { - if err := cdi.ValidateVersion(s.Spec); err != nil { - return nil, err - } - if err := parser.ValidateVendorName(s.vendor); err != nil { - return nil, err - } - if err := parser.ValidateClassName(s.class); err != nil { - return nil, err - } - if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { - return nil, err - } - if err := s.edits().Validate(); err != nil { + if err := validator.Default.Validate(s.Spec); err != nil { return nil, err } + // TODO: The validator above should perform the same validation as below but + // we still need to construct the device map. devices := make(map[string]*Device) for _, d := range s.Devices { dev, err := newDevice(s, d)