Skip to content

Commit 4c19a01

Browse files
authored
Merge pull request #52048 from shiv-tyagi/vendor-detection
Use CDI for GPU injection for AMD devices for --gpus
2 parents 018cdea + 13a8626 commit 4c19a01

7 files changed

Lines changed: 152 additions & 22 deletions

File tree

daemon/command/daemon.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,8 @@ func (cli *daemonCLI) start(ctx context.Context) (retErr error) {
285285
cdiCache = daemon.RegisterCDIDriver(cli.Config.CDISpecDirs...)
286286
}
287287

288+
daemon.RegisterGPUDeviceDrivers(cdiCache)
289+
288290
var apiServer apiserver.Server
289291
authz, err := initMiddlewares(ctx, &apiServer, cli.Config, pluginStore)
290292
if err != nil {

daemon/devices.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package daemon
22

33
import (
44
"context"
5+
"errors"
56

67
"github.com/containerd/log"
78
"github.com/moby/moby/api/types/container"
@@ -38,6 +39,18 @@ func registerDeviceDriver(name string, d *deviceDriver) {
3839
deviceDrivers[name] = d
3940
}
4041

42+
// getFirstAvailableVendor picks the preferred GPU vendor out of vendorList.
//
// Known vendors are checked in a fixed priority order ("nvidia.com" first,
// then "amd.com"); the first one that also appears in vendorList is returned.
// The order of vendorList itself has no influence on the result. If none of
// the known vendors is present, an empty string and an error are returned.
func getFirstAvailableVendor(vendorList []string) (string, error) {
	// isAvailable reports whether name occurs anywhere in vendorList.
	isAvailable := func(name string) bool {
		for i := range vendorList {
			if vendorList[i] == name {
				return true
			}
		}
		return false
	}

	// Listed by preference: an earlier entry wins when several are available.
	for _, candidate := range [...]string{"nvidia.com", "amd.com"} {
		if isAvailable(candidate) {
			return candidate, nil
		}
	}
	return "", errors.New("no known GPU vendor found")
}
53+
4154
func (daemon *Daemon) handleDevice(req container.DeviceRequest, spec *specs.Spec) error {
4255
if req.Driver == "" {
4356
// If no driver is explicitly requested, we iterate over the registered

daemon/devices_amd_linux.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
package daemon
22

33
import (
4+
"errors"
5+
"fmt"
6+
"os/exec"
47
"strings"
58

9+
"github.com/moby/moby/v2/daemon/internal/capabilities"
610
"github.com/opencontainers/runtime-spec/specs-go"
11+
"tags.cncf.io/container-device-interface/pkg/cdi"
12+
)
13+
14+
// amdContainerRuntimeExecutableName is the name of the AMD container runtime
// binary; it is looked up on PATH to decide whether the runtime-based AMD GPU
// updater can be used.
const amdContainerRuntimeExecutableName = "amd-container-runtime"
817

918
func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
@@ -25,3 +34,46 @@ func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
2534

2635
return nil
2736
}
37+
38+
func createAMDCDIUpdater(cdiCache *cdi.Cache) func(*specs.Spec, *deviceInstance) error {
39+
return func(s *specs.Spec, dev *deviceInstance) error {
40+
vendor, err := getFirstAvailableVendor(cdiCache.ListVendors())
41+
if err != nil {
42+
return fmt.Errorf("failed to discover GPU vendor from CDI: %w", err)
43+
}
44+
45+
if vendor != "amd.com" {
46+
return errors.New("AMD CDI spec not found")
47+
}
48+
49+
injector := &cdiDeviceInjector{
50+
defaultCDIDeviceKind: "amd.com/gpu",
51+
}
52+
return injector.injectDevices(s, dev)
53+
}
54+
}
55+
56+
func getAMDDeviceDrivers(cdiCache *cdi.Cache) *deviceDriver {
57+
var composite firstSuccessfulUpdater
58+
59+
if cdiCache != nil {
60+
composite = append(composite, createAMDCDIUpdater(cdiCache))
61+
}
62+
63+
if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
64+
composite = append(composite, setAMDGPUs)
65+
}
66+
67+
if len(composite) == 0 {
68+
return nil
69+
}
70+
71+
// We do not support specifying driver with device requests for AMD GPUs.
72+
// Hence only use the composite updater and try cdi and runtime driver in sequence
73+
// based on availability.
74+
capset := capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}
75+
return &deviceDriver{
76+
capset: capset,
77+
updateSpec: composite.updateSpec,
78+
}
79+
}

daemon/devices_linux.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package daemon
2+
3+
import "tags.cncf.io/container-device-interface/pkg/cdi"
4+
5+
// RegisterGPUDeviceDrivers registers GPU device drivers.
6+
// If the cdiCache is provided, it is used to detect presence of CDI specs for AMD GPUs.
7+
// For NVIDIA GPUs, presence of CDI specs is detected by checking for the nvidia-cdi-hook binary.
8+
func RegisterGPUDeviceDrivers(cdiCache *cdi.Cache) {
9+
// Register NVIDIA device drivers.
10+
if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
11+
for name, driver := range nvidiaDrivers {
12+
registerDeviceDriver(name, driver)
13+
}
14+
return
15+
}
16+
17+
// Register AMD driver if AMD CDI spec or helper binary is present.
18+
if amdDriver := getAMDDeviceDrivers(cdiCache); amdDriver != nil {
19+
registerDeviceDriver("amd", amdDriver)
20+
return
21+
}
22+
}

daemon/devices_nonlinux.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//go:build !linux
2+
3+
package daemon
4+
5+
import "tags.cncf.io/container-device-interface/pkg/cdi"
6+
7+
// RegisterGPUDeviceDrivers is a no-op on non-Linux platforms.
8+
func RegisterGPUDeviceDrivers(_ *cdi.Cache) {}

daemon/devices_nvidia_linux.go

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs
2323
const (
2424
nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
2525
nvidiaCDIHookExecutableName = "nvidia-cdi-hook"
26-
amdContainerRuntimeExecutableName = "amd-container-runtime"
2726
)
2827

2928
// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
@@ -36,27 +35,6 @@ var allNvidiaCaps = map[string]struct{}{
3635
"display": {},
3736
}
3837

39-
func init() {
40-
// Register NVIDIA device drivers.
41-
if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
42-
for name, driver := range nvidiaDrivers {
43-
registerDeviceDriver(name, driver)
44-
}
45-
return
46-
}
47-
48-
// Register AMD driver if AMD helper binary is present.
49-
if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
50-
registerDeviceDriver("amd", &deviceDriver{
51-
capset: capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
52-
updateSpec: setAMDGPUs,
53-
})
54-
return
55-
}
56-
57-
// No "gpu" capability
58-
}
59-
6038
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
6139
var composite firstSuccessfulUpdater
6240
nvidiaDrivers := make(map[string]*deviceDriver)

daemon/devices_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package daemon
2+
3+
import (
4+
"testing"
5+
6+
"gotest.tools/v3/assert"
7+
)
8+
9+
func TestGetFirstAvailableVendor(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
vendors []string
13+
expectVendor string
14+
expectError string
15+
}{
16+
{
17+
name: "NVIDIA vendor",
18+
vendors: []string{"nvidia.com"},
19+
expectVendor: "nvidia.com",
20+
},
21+
{
22+
name: "AMD vendor",
23+
vendors: []string{"amd.com"},
24+
expectVendor: "amd.com",
25+
},
26+
{
27+
name: "No vendors",
28+
vendors: nil,
29+
expectError: "no known GPU vendor found",
30+
},
31+
{
32+
name: "Unknown vendor",
33+
vendors: []string{"unknown.com"},
34+
expectError: "no known GPU vendor found",
35+
},
36+
{
37+
name: "Mixed vendor",
38+
vendors: []string{"amd.com", "nvidia.com"},
39+
expectVendor: "nvidia.com",
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(tt.name, func(t *testing.T) {
45+
vendor, err := getFirstAvailableVendor(tt.vendors)
46+
47+
if tt.expectError != "" {
48+
assert.Error(t, err, tt.expectError)
49+
} else {
50+
assert.NilError(t, err)
51+
assert.Equal(t, tt.expectVendor, vendor)
52+
}
53+
})
54+
}
55+
}

0 commit comments

Comments
 (0)