Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/config/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ type features struct {
// If this feature flag is not set to 'true' only host-rooted config paths
// (i.e. paths starting with an '@' are considered valid)
AllowLDConfigFromContainer *feature `toml:"allow-ldconfig-from-container,omitempty"`
// AllowUnknownOCISpecFields allows the nvidia-container-runtime to ignore
// unknown fields when loading the config (OCI spec) associated with a
// container.
// If this is enabled, these fields are silently dropped.
AllowUnknownOCISpecFields *feature `toml:"allow-unknown-oci-spec-fields,omitempty"`
// DisableCUDACompatLibHook, when enabled skips the injection of a specific
// hook to process CUDA compatibility libraries.
//
Expand Down
39 changes: 39 additions & 0 deletions internal/oci/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package oci

import "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"

type options struct {
logger logger.Interface
allowUnkownFields bool
}

type Option func(*options)

func WithLogger(logger logger.Interface) Option {
return func(o *options) {
o.logger = logger
}
}

func WithAllowUnknownFields(allowUnknownFields bool) Option {
return func(o *options) {
o.allowUnkownFields = allowUnknownFields
}
}
17 changes: 11 additions & 6 deletions internal/oci/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"fmt"

"github.com/opencontainers/runtime-spec/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// SpecModifier defines an interface for modifying a (raw) OCI spec
Expand Down Expand Up @@ -49,17 +47,24 @@ type Spec interface {

// NewSpec creates fileSpec based on the command line arguments passed to the
// application using the specified logger.
func NewSpec(logger logger.Interface, args []string) (Spec, error) {
func NewSpec(args []string, opts ...Option) (Spec, error) {
o := &options{
allowUnkownFields: false,
}
for _, opt := range opts {
opt(o)
}

bundleDir, err := GetBundleDir(args)
if err != nil {
return nil, fmt.Errorf("error getting bundle directory: %v", err)
}
logger.Debugf("Using bundle directory: %v", bundleDir)
o.logger.Debugf("Using bundle directory: %v", bundleDir)

ociSpecPath := GetSpecFilePath(bundleDir)
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
o.logger.Infof("Using OCI specification file path: %v", ociSpecPath)

ociSpec := NewFileSpec(ociSpecPath)
ociSpec := NewFileSpec(ociSpecPath, !o.allowUnkownFields)

return ociSpec, nil
}
Expand Down
23 changes: 19 additions & 4 deletions internal/oci/spec_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,28 @@ import (
type fileSpec struct {
memorySpec
path string
loader
}

type loader bool

const (
strictLoader = loader(true)
)

var _ Spec = (*fileSpec)(nil)

// NewFileSpec creates an object that encapsulates a file-backed OCI spec.
// This can be used to read from the file, modify the spec, and write to the
// same file.
func NewFileSpec(filepath string) Spec {
func NewFileSpec(filepath string, isStrict bool) Spec {
var loader loader
if isStrict {
loader = strictLoader
}
oci := fileSpec{
path: filepath,
path: filepath,
loader: loader,
}

return &oci
Expand All @@ -52,7 +64,7 @@ func (s *fileSpec) Load() (*specs.Spec, error) {
}
defer specFile.Close()

spec, err := LoadFrom(specFile)
spec, err := s.loadFrom(specFile)
if err != nil {
return nil, fmt.Errorf("error loading OCI specification from file: %v", err)
}
Expand All @@ -61,8 +73,11 @@ func (s *fileSpec) Load() (*specs.Spec, error) {
}

// LoadFrom reads the contents of the OCI spec from the specified io.Reader.
func LoadFrom(reader io.Reader) (*specs.Spec, error) {
func (isStrict loader) loadFrom(reader io.Reader) (*specs.Spec, error) {
decoder := json.NewDecoder(reader)
if isStrict {
decoder.DisallowUnknownFields()
}

var spec specs.Spec

Expand Down
2 changes: 1 addition & 1 deletion internal/oci/spec_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestLoadFrom(t *testing.T) {

for i, tc := range testCases {
var spec *specs.Spec
spec, err := LoadFrom(bytes.NewReader(tc.contents))
spec, err := strictLoader.loadFrom(bytes.NewReader(tc.contents))

if tc.isError {
require.Error(t, err, "%d: %v", i, tc)
Expand Down
2 changes: 1 addition & 1 deletion internal/oci/spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestMaintainSpec(t *testing.T) {
for _, f := range files {
inputSpecPath := filepath.Join(moduleRoot, "tests/input", f)

spec := NewFileSpec(inputSpecPath).(*fileSpec)
spec := NewFileSpec(inputSpecPath, true).(*fileSpec)

_, err := spec.Load()
require.NoError(t, err)
Expand Down
42 changes: 25 additions & 17 deletions internal/oci/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,10 @@ func ReadContainerState(reader io.Reader) (*State, error) {
return &s, nil
}

// LoadSpec loads the OCI spec associated with the container state
func (s *State) LoadSpec() (*specs.Spec, error) {
specFilePath := GetSpecFilePath(s.Bundle)
specFile, err := os.Open(specFilePath)
if err != nil {
return nil, fmt.Errorf("failed to open OCI spec file: %v", err)
}
defer specFile.Close()

spec, err := LoadFrom(specFile)
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
return spec, nil
}

// GetContainerRoot returns the root for the container from the associated spec. If the spec is not yet loaded, it is
// loaded and cached.
func (s *State) GetContainerRoot() (string, error) {
spec, err := s.LoadSpec()
spec, err := s.loadMinimalSpec()
if err != nil {
return "", err
}
Expand All @@ -91,3 +75,27 @@ func (s *State) GetContainerRoot() (string, error) {

return filepath.Join(s.Bundle, containerRoot), nil
}

// loadMinimalSpec loads a reduced OCI spec associated with the container state.
func (s *State) loadMinimalSpec() (*minimalSpec, error) {
specFilePath := GetSpecFilePath(s.Bundle)
specFile, err := os.Open(specFilePath)
if err != nil {
return nil, fmt.Errorf("failed to open OCI spec file: %v", err)
}
defer specFile.Close()

ms := &minimalSpec{}
if err := json.NewDecoder(specFile).Decode(ms); err != nil {
return nil, fmt.Errorf("failed to load minimal OCI spec: %v", err)
}
return ms, nil
}

// A minimalSpec is used to return desired properties from the container config.
// We define this here instead of using specs.Spec as is to avoid decoding
// unneeded fields in container lifecycle hooks.
type minimalSpec struct {
// Root configures the container's root filesystem.
Root *specs.Root `json:"root,omitempty"`
}
5 changes: 4 additions & 1 deletion internal/runtime/runtime_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
return lowLevelRuntime, nil
}

ociSpec, err := oci.NewSpec(logger, argv)
ociSpec, err := oci.NewSpec(argv,
oci.WithLogger(logger),
oci.WithAllowUnknownFields(cfg.Features.AllowUnknownOCISpecFields.IsEnabled()),
)
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
Expand Down