diff --git a/internal/config/features.go b/internal/config/features.go index 913f14525..7ec55aa1a 100644 --- a/internal/config/features.go +++ b/internal/config/features.go @@ -25,6 +25,11 @@ type features struct { // If this feature flag is not set to 'true' only host-rooted config paths // (i.e. paths starting with an '@' are considered valid) AllowLDConfigFromContainer *feature `toml:"allow-ldconfig-from-container,omitempty"` + // AllowUnknownOCISpecFields allows the nvidia-container-runtime to ignore + // unknown fields when loading the config (OCI spec) associated with a + // container. + // If this is enabled, these fields are silently dropped. + AllowUnknownOCISpecFields *feature `toml:"allow-unknown-oci-spec-fields,omitempty"` // DisableCUDACompatLibHook, when enabled skips the injection of a specific // hook to process CUDA compatibility libraries. // diff --git a/internal/oci/options.go b/internal/oci/options.go new file mode 100644 index 000000000..f9e469aff --- /dev/null +++ b/internal/oci/options.go @@ -0,0 +1,39 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package oci + +import "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + +type options struct { + logger logger.Interface + allowUnkownFields bool +} + +type Option func(*options) + +func WithLogger(logger logger.Interface) Option { + return func(o *options) { + o.logger = logger + } +} + +func WithAllowUnknownFields(allowUnknownFields bool) Option { + return func(o *options) { + o.allowUnkownFields = allowUnknownFields + } +} diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 1e2c144a7..73030b76d 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -20,8 +20,6 @@ import ( "fmt" "github.com/opencontainers/runtime-spec/specs-go" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) // SpecModifier defines an interface for modifying a (raw) OCI spec @@ -49,17 +47,24 @@ type Spec interface { // NewSpec creates fileSpec based on the command line arguments passed to the // application using the specified logger. -func NewSpec(logger logger.Interface, args []string) (Spec, error) { +func NewSpec(args []string, opts ...Option) (Spec, error) { + o := &options{ + allowUnkownFields: false, + } + for _, opt := range opts { + opt(o) + } + bundleDir, err := GetBundleDir(args) if err != nil { return nil, fmt.Errorf("error getting bundle directory: %v", err) } - logger.Debugf("Using bundle directory: %v", bundleDir) + o.logger.Debugf("Using bundle directory: %v", bundleDir) ociSpecPath := GetSpecFilePath(bundleDir) - logger.Infof("Using OCI specification file path: %v", ociSpecPath) + o.logger.Infof("Using OCI specification file path: %v", ociSpecPath) - ociSpec := NewFileSpec(ociSpecPath) + ociSpec := NewFileSpec(ociSpecPath, !o.allowUnkownFields) return ociSpec, nil } diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go index 8784ae922..d108e56c0 100644 --- a/internal/oci/spec_file.go +++ b/internal/oci/spec_file.go @@ -28,16 +28,28 @@ import ( type fileSpec struct { memorySpec path string + loader } +type loader bool + +const ( + strictLoader = loader(true) +) + var _ Spec = (*fileSpec)(nil) // NewFileSpec creates an object that encapsulates a file-backed OCI spec. // This can be used to read from the file, modify the spec, and write to the // same file. -func NewFileSpec(filepath string) Spec { +func NewFileSpec(filepath string, isStrict bool) Spec { + var loader loader + if isStrict { + loader = strictLoader + } oci := fileSpec{ - path: filepath, + path: filepath, + loader: loader, } return &oci @@ -52,7 +64,7 @@ func (s *fileSpec) Load() (*specs.Spec, error) { } defer specFile.Close() - spec, err := LoadFrom(specFile) + spec, err := s.loadFrom(specFile) if err != nil { return nil, fmt.Errorf("error loading OCI specification from file: %v", err) } @@ -61,8 +73,11 @@ func (s *fileSpec) Load() (*specs.Spec, error) { } // LoadFrom reads the contents of the OCI spec from the specified io.Reader. -func LoadFrom(reader io.Reader) (*specs.Spec, error) { +func (isStrict loader) loadFrom(reader io.Reader) (*specs.Spec, error) { decoder := json.NewDecoder(reader) + if isStrict { + decoder.DisallowUnknownFields() + } var spec specs.Spec diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go index e1c1fe0f3..ca1ef6258 100644 --- a/internal/oci/spec_file_test.go +++ b/internal/oci/spec_file_test.go @@ -44,7 +44,7 @@ func TestLoadFrom(t *testing.T) { for i, tc := range testCases { var spec *specs.Spec - spec, err := LoadFrom(bytes.NewReader(tc.contents)) + spec, err := strictLoader.loadFrom(bytes.NewReader(tc.contents)) if tc.isError { require.Error(t, err, "%d: %v", i, tc) diff --git a/internal/oci/spec_test.go b/internal/oci/spec_test.go index f11493785..e083966a8 100644 --- a/internal/oci/spec_test.go +++ b/internal/oci/spec_test.go @@ -21,7 +21,7 @@ func TestMaintainSpec(t *testing.T) { for _, f := range files { inputSpecPath := filepath.Join(moduleRoot, "tests/input", f) - spec := NewFileSpec(inputSpecPath).(*fileSpec) + spec := NewFileSpec(inputSpecPath, true).(*fileSpec) _, err := spec.Load() require.NoError(t, err) diff --git a/internal/oci/state.go b/internal/oci/state.go index 2bb4e6e53..411bb306c 100644 --- a/internal/oci/state.go +++ b/internal/oci/state.go @@ -56,26 +56,10 @@ func ReadContainerState(reader io.Reader) (*State, error) { return &s, nil } -// LoadSpec loads the OCI spec associated with the container state -func (s *State) LoadSpec() (*specs.Spec, error) { - specFilePath := GetSpecFilePath(s.Bundle) - specFile, err := os.Open(specFilePath) - if err != nil { - return nil, fmt.Errorf("failed to open OCI spec file: %v", err) - } - defer specFile.Close() - - spec, err := LoadFrom(specFile) - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - return spec, nil -} - // GetContainerRoot returns the root for the container from the associated spec. If the spec is not yet loaded, it is // loaded and cached. func (s *State) GetContainerRoot() (string, error) { - spec, err := s.LoadSpec() + spec, err := s.loadMinimalSpec() if err != nil { return "", err } @@ -91,3 +75,27 @@ func (s *State) GetContainerRoot() (string, error) { return filepath.Join(s.Bundle, containerRoot), nil } + +// loadMinimalSpec loads a reduced OCI spec associated with the container state. +func (s *State) loadMinimalSpec() (*minimalSpec, error) { + specFilePath := GetSpecFilePath(s.Bundle) + specFile, err := os.Open(specFilePath) + if err != nil { + return nil, fmt.Errorf("failed to open OCI spec file: %v", err) + } + defer specFile.Close() + + ms := &minimalSpec{} + if err := json.NewDecoder(specFile).Decode(ms); err != nil { + return nil, fmt.Errorf("failed to load minimal OCI spec: %v", err) + } + return ms, nil +} + +// A minimalSpec is used to return desired properties from the container config. +// We define this here instead of using specs.Spec as is to avoid decoding +// unneeded fields in container lifecycle hooks. +type minimalSpec struct { + // Root configures the container's root filesystem. + Root *specs.Root `json:"root,omitempty"` +} diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index dc6424c36..bb232a27e 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -42,7 +42,10 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv return lowLevelRuntime, nil } - ociSpec, err := oci.NewSpec(logger, argv) + ociSpec, err := oci.NewSpec(argv, + oci.WithLogger(logger), + oci.WithAllowUnknownFields(cfg.Features.AllowUnknownOCISpecFields.IsEnabled()), + ) if err != nil { return nil, fmt.Errorf("error constructing OCI specification: %v", err) }