Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
19 changes: 13 additions & 6 deletions cmd/nvidia-ctk-installer/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ type Options struct {
// mount.
ExecutablePath string
// EnableCDI indicates whether CDI should be enabled.
EnableCDI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
EnableCDI bool
EnableNRI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
NRIPluginIndex string
NRISocket string

ConfigSources []string
}
Expand Down Expand Up @@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
cfg.EnableCDI()
}

if o.EnableNRI {
cfg.EnableNRI()
}

return nil
}

Expand Down
146 changes: 146 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package nri

import (
"context"
"fmt"
"os"

"github.com/containerd/nri/pkg/api"
nriplugin "github.com/containerd/nri/pkg/stub"
"sigs.k8s.io/yaml"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// Compile-time check that *Plugin satisfies the nriplugin.Plugin interface.
var (
_ nriplugin.Plugin = (*Plugin)(nil)
)

const (
// nodeResourceCDIDeviceKey is the primary prefix of the annotation key used
// to request CDI devices. getAnnotation checks keys with this prefix before
// the equivalent nriCDIDeviceKey ones.
nodeResourceCDIDeviceKey = "cdi-devices.noderesource.dev"
// nriCDIDeviceKey is the older prefix of the annotation key used to request
// CDI devices. It is consulted as a fallback after nodeResourceCDIDeviceKey.
nriCDIDeviceKey = "cdi-devices.nri.io"
// defaultNRISocket is the NRI socket path that Start falls back to when it
// is given an empty socket path.
defaultNRISocket = "/var/run/nri/nri.sock"
)

// Plugin is an NRI plugin that adds annotated CDI devices to containers when
// they are created.
type Plugin struct {
// logger receives the plugin's debug and info messages.
logger logger.Interface

// stub is the NRI stub created by Start; NewPlugin leaves it unset.
stub nriplugin.Stub
}

// NewPlugin returns a Plugin for injecting CDI devices that logs through the
// supplied logger. The NRI stub is not created here; it is set up by Start.
func NewPlugin(logger logger.Interface) *Plugin {
	plugin := new(Plugin)
	plugin.logger = logger
	return plugin
}

// CreateContainer handles container creation requests.
//
// It starts from an empty container adjustment, adds the CDI devices requested
// through the pod's annotations and returns the result. No container updates
// are ever returned. If the annotations cannot be processed, the error is
// returned and no adjustment is made.
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
	adjustment := new(api.ContainerAdjustment)

	err := p.injectCDIDevices(pod, ctr, adjustment)
	if err != nil {
		return nil, nil, err
	}

	return adjustment, nil, nil
}

// injectCDIDevices adds every CDI device requested for ctr in the pod's
// annotations to the adjustment a.
//
// An annotation that cannot be parsed is reported as an error. When no devices
// are requested, the adjustment is left untouched and a debug message is
// logged; otherwise an info message is logged per injected device.
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
	requested, err := parseCDIDevices(ctr.Name, pod.Annotations)
	if err != nil {
		return err
	}

	// The log prefix is the same for every message below, so build it once.
	prefix := containerName(pod, ctr)

	if len(requested) == 0 {
		p.logger.Debugf("%s: no CDI devices annotated...", prefix)
		return nil
	}

	for i := range requested {
		device := &api.CDIDevice{Name: requested[i]}
		a.AddCDIDevice(device)
		p.logger.Infof("%s: injected CDI device %q...", prefix, requested[i])
	}

	return nil
}

// parseCDIDevices returns the CDI device names requested for the container
// named ctr in the given annotations.
//
// The annotation is located with getAnnotation using the
// nodeResourceCDIDeviceKey and nriCDIDeviceKey prefixes, and its value is
// decoded as a YAML list of strings. A missing or empty annotation yields
// (nil, nil); a value that does not decode yields an error quoting the
// offending value.
func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
	raw := getAnnotation(annotations, nodeResourceCDIDeviceKey, nriCDIDeviceKey, ctr)
	if len(raw) == 0 {
		return nil, nil
	}

	var names []string
	err := yaml.Unmarshal(raw, &names)
	if err != nil {
		return nil, fmt.Errorf("invalid CDI device annotation %q: %w", string(raw), err)
	}

	return names, nil
}

// getAnnotation looks up the annotation value that applies to the container
// named ctr and returns it as bytes, or nil if no candidate key is present.
//
// Candidate keys are tried from most to least specific, and within each level
// mainKey is tried before oldKey:
//
//	<key>/container.<ctr>
//	<key>/pod
//	<key>
//
// The first key that exists in annotations wins, even if its value is empty.
func getAnnotation(annotations map[string]string, mainKey, oldKey, ctr string) []byte {
	const podSuffix = "/pod"
	containerSuffix := "/container." + ctr

	candidates := [...]string{
		mainKey + containerSuffix,
		oldKey + containerSuffix,
		mainKey + podSuffix,
		oldKey + podSuffix,
		mainKey,
		oldKey,
	}

	for i := range candidates {
		value, found := annotations[candidates[i]]
		if found {
			return []byte(value)
		}
	}

	return nil
}

// containerName builds the name used for a container in log messages:
// "<pod name>/<container name>" when a pod is given, otherwise just the
// container name.
func containerName(pod *api.PodSandbox, container *api.Container) string {
	name := container.Name
	if pod == nil {
		return name
	}
	return pod.Name + "/" + name
}

// Start starts the NRI plugin.
//
// nriSocketPath is the NRI socket to use; when empty, defaultNRISocket is used
// instead. nriPluginIdx is the plugin index passed to the NRI stub. The socket
// path must exist on disk, otherwise an error is returned before any stub is
// created. The stub is then created with this plugin as its handler, stored on
// the plugin and started with ctx. Any failure is returned wrapped with context
// about the step that failed.
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
	if nriSocketPath == "" {
		nriSocketPath = defaultNRISocket
	}

	if _, statErr := os.Stat(nriSocketPath); statErr != nil {
		return fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, statErr)
	}

	stub, err := nriplugin.New(p,
		nriplugin.WithPluginIdx(nriPluginIdx),
		nriplugin.WithSocketPath(nriSocketPath),
	)
	// The result is stored even on failure, mirroring a direct assignment.
	p.stub = stub
	if err != nil {
		return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
	}

	if err := p.stub.Start(ctx); err != nil {
		return fmt.Errorf("plugin exited with error: %w", err)
	}

	return nil
}

// Stop stops the NRI plugin. It is a no-op on a nil Plugin or on one whose
// stub has not been created.
func (p *Plugin) Stop() {
	if p == nil || p.stub == nil {
		return
	}
	p.stub.Stop()
}
23 changes: 23 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ const (
// defaultRuntimeName specifies the NVIDIA runtime to be used as the default runtime if setting the default runtime is enabled
defaultRuntimeName = "nvidia"
defaultHostRootMount = "/host"
defaultNRIPluginIdx = "10"
defaultNRISocket = "/var/run/nri/nri.sock"

runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
)
Expand Down Expand Up @@ -94,6 +96,27 @@ func Flags(opts *Options) []cli.Flag {
Destination: &opts.EnableCDI,
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
},
&cli.BoolFlag{
Name: "enable-nri-in-runtime",
Usage: "Enable NRI in the configured runtime",
Destination: &opts.EnableNRI,
Value: true,
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
},
&cli.StringFlag{
Name: "nri-plugin-index",
Usage: "Specify the plugin index to register to NRI",
Value: defaultNRIPluginIdx,
Destination: &opts.NRIPluginIndex,
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
},
&cli.StringFlag{
Name: "nri-socket",
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
Value: defaultNRISocket,
Destination: &opts.NRISocket,
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
Expand Down
51 changes: 46 additions & 5 deletions cmd/nvidia-ctk-installer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"os/signal"
"path/filepath"
"syscall"
"time"

"github.com/urfave/cli/v3"
"golang.org/x/sys/unix"

"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
Expand All @@ -26,6 +28,9 @@ const (
toolkitSubDir = "toolkit"

defaultRuntime = "docker"

retryBackoff = 2 * time.Second
maxRetryAttempts = 5
)

var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
Expand Down Expand Up @@ -70,13 +75,15 @@ func main() {
type app struct {
logger logger.Interface

toolkit *toolkit.Installer
nriPlugin *nri.Plugin
toolkit *toolkit.Installer
}

// NewApp creates the CLI app for the specified options.
func NewApp(logger logger.Interface) *cli.Command {
a := app{
logger: logger,
logger: logger,
nriPlugin: nri.NewPlugin(logger),
}
return a.build()
}
Expand All @@ -93,8 +100,8 @@ func (a app) build() *cli.Command {
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
return ctx, a.Before(cmd, &options)
},
Action: func(_ context.Context, cmd *cli.Command) error {
return a.Run(cmd, &options)
Action: func(ctx context.Context, cmd *cli.Command) error {
return a.Run(ctx, cmd, &options)
},
Flags: []cli.Flag{
&cli.BoolFlag{
Expand Down Expand Up @@ -194,7 +201,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
// If the application is run as a daemon, the application waits and unconfigures
// the runtime on termination.
func (a *app) Run(c *cli.Command, o *options) error {
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
err := a.initialize(o.pidFile)
if err != nil {
return fmt.Errorf("unable to initialize: %v", err)
Expand Down Expand Up @@ -222,6 +229,11 @@ func (a *app) Run(c *cli.Command, o *options) error {
}

if !o.noDaemon {
if o.runtimeOptions.EnableNRI {
if err = a.startNRIPluginServer(ctx, o.runtimeOptions); err != nil {
a.logger.Errorf("unable to start NRI plugin server: %v", err)
}
}
err = a.waitForSignal()
if err != nil {
return fmt.Errorf("unable to wait for signal: %v", err)
Expand Down Expand Up @@ -287,9 +299,38 @@ func (a *app) waitForSignal() error {
return nil
}

// startNRIPluginServer starts the NRI plugin, retrying when the start fails.
//
// At most maxRetryAttempts attempts are made, waiting retryBackoff between
// consecutive attempts. The wait is abandoned as soon as ctx is done, so a
// cancelled context is not held up by the remaining backoff.
//
// It returns nil once the plugin has started, the last start error once all
// attempts are exhausted, or an error wrapping ctx.Err() if ctx ends while
// waiting to retry.
func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) error {
	a.logger.Infof("Starting the NRI Plugin server....")

	var err error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		err = a.nriPlugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
		if err == nil {
			return nil
		}
		// No point in backing off after the final attempt.
		if attempt == maxRetryAttempts {
			break
		}

		// Back off before the next attempt, but stop waiting if ctx ends.
		select {
		case <-ctx.Done():
			return fmt.Errorf("starting NRI plugin interrupted after %d/%d attempts: %w (last error: %v)",
				attempt, maxRetryAttempts, ctx.Err(), err)
		case <-time.After(retryBackoff):
		}
	}

	if err != nil {
		a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
		return err
	}
	return nil
}

func (a *app) shutdown(pidFile string) {
a.logger.Infof("Shutting Down")

if a.nriPlugin != nil {
a.logger.Infof("Stopping NRI plugin server...")
a.nriPlugin.Stop()
}

err := os.Remove(pidFile)
if err != nil {
a.logger.Warningf("Unable to remove pidfile: %v", err)
Expand Down
1 change: 1 addition & 0 deletions cmd/nvidia-ctk-installer/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ version = 2
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
"--restart-mode=none",
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
"--enable-nri-in-runtime=false",
}

err := app.Run(context.Background(), append(testArgs, tc.args...))
Expand Down
14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/NVIDIA/go-nvlib v0.9.0
github.com/NVIDIA/go-nvml v0.13.0-1
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
github.com/google/uuid v1.6.0
github.com/moby/sys/mountinfo v0.7.2
github.com/moby/sys/reexec v0.1.0
Expand All @@ -19,24 +20,31 @@ require (
github.com/urfave/cli/v3 v3.6.1
golang.org/x/mod v0.30.0
golang.org/x/sys v0.38.0
sigs.k8s.io/yaml v1.4.0
tags.cncf.io/container-device-interface v1.0.2-0.20251114135136-1b24d969689f
tags.cncf.io/container-device-interface/specs-go v1.0.0
)

require (
cyphar.com/go-pathrs v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/ttrpc v1.2.7 // indirect
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/knqyf263/go-plugin v0.9.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/moby/sys/capability v0.4.0 // indirect
github.com/opencontainers/cgroups v0.0.4 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
google.golang.org/grpc v1.57.1 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
)
Loading
Loading