Skip to content

Commit 80a51ff

Browse files
elezarclaude
authored andcommitted
wsl: report a single "all" device to kubelet
On WSL, all GPUs are accessed through /dev/dxg. Replace the per-GPU wslDevice (which reported one device per physical GPU with individual UUIDs) with a stateless wslAllGPUsDevice that always returns UUID "all" and path "/dev/dxg". This causes the device map to collapse to a single entry per resource, so kubelet sees exactly one GPU device on WSL. When allocated, this flows naturally through all strategy paths (envvar, CDI, volume mounts) to set NVIDIA_VISIBLE_DEVICES=all, which is what nvidia-container-runtime on WSL expects. Co-Authored-By: Claude Sonnet 4.6 <[email protected]> Signed-off-by: Evan Lezar <[email protected]> (cherry picked from commit 1bb3658)
1 parent d63d160 commit 80a51ff

5 files changed

Lines changed: 70 additions & 34 deletions

File tree

internal/rm/device_map.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,18 @@ type deviceMapBuilder struct {
4141
type DeviceMap map[spec.ResourceName]Devices
4242

4343
// NewDeviceMap creates a device map for the specified NVML library and config.
44-
func NewDeviceMap(infolib info.Interface, devicelib device.Interface, config *spec.Config) (DeviceMap, error) {
44+
func NewDeviceMap(devicelib device.Interface, config *spec.Config, platform info.Platform) (DeviceMap, error) {
45+
newGPUDevice := newNvmlGPUDevice
46+
if platform == info.PlatformWSL {
47+
newGPUDevice = newWslAllGPUsDevice
48+
}
49+
4550
b := deviceMapBuilder{
4651
Interface: devicelib,
4752
migStrategy: config.Flags.MigStrategy,
4853
resources: &config.Resources,
4954
replicatedResources: config.Sharing.ReplicatedResources(),
50-
newGPUDevice: newNvmlGPUDevice,
51-
}
52-
53-
if infolib.ResolvePlatform() == info.PlatformWSL {
54-
b.newGPUDevice = newWslGPUDevice
55+
newGPUDevice: newGPUDevice,
5556
}
5657

5758
return b.build()

internal/rm/device_map_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,30 @@ import (
2525
spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
2626
)
2727

28+
func TestWslDeviceMapHasSingleAllDevice(t *testing.T) {
29+
// Simulate building a GPU device map with 3 GPUs on WSL.
30+
// Because newWslAllGPUsDevice always returns index/UUID "all", the map
31+
// should collapse to exactly one device entry per resource.
32+
devices := make(DeviceMap)
33+
resourceName := spec.ResourceName("nvidia.com/gpu")
34+
35+
for i := 0; i < 3; i++ {
36+
index, info := newWslAllGPUsDevice(i, nil)
37+
err := devices.setEntry(resourceName, index, info)
38+
require.NoError(t, err)
39+
}
40+
41+
gpuDevices, ok := devices[resourceName]
42+
require.True(t, ok)
43+
require.Len(t, gpuDevices, 1)
44+
45+
dev, ok := gpuDevices["all"]
46+
require.True(t, ok)
47+
require.Equal(t, "all", dev.ID)
48+
require.Equal(t, "all", dev.Index)
49+
require.Equal(t, []string{"/dev/dxg"}, dev.Paths)
50+
}
51+
2852
func TestDeviceMapInsert(t *testing.T) {
2953
device0 := Device{Device: pluginapi.Device{ID: "0"}}
3054
device0withIndex := Device{Device: pluginapi.Device{ID: "0"}, Index: "index"}

internal/rm/nvml_devices.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ func newNvmlGPUDevice(i int, gpu nvml.Device) (string, deviceInfo) {
4949
return index, nvmlDevice{gpu}
5050
}
5151

52-
func newWslGPUDevice(i int, gpu nvml.Device) (string, deviceInfo) {
53-
index := fmt.Sprintf("%v", i)
54-
return index, wslDevice{gpu}
52+
func newWslAllGPUsDevice(_ int, _ nvml.Device) (string, deviceInfo) {
53+
return "all", wslAllGPUsDevice{}
5554
}
5655

5756
func newMigDevice(i int, j int, mig nvml.Device) (string, nvmlMigDevice) {

internal/rm/nvml_manager.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ func NewNVMLResourceManagers(infolib info.Interface, nvmllib nvml.Interface, dev
4848
}
4949
}()
5050

51-
deviceMap, err := NewDeviceMap(infolib, devicelib, config)
51+
platform := infolib.ResolvePlatform()
52+
53+
deviceMap, err := NewDeviceMap(devicelib, config, platform)
5254
if err != nil {
5355
return nil, fmt.Errorf("error building device map: %v", err)
5456
}
@@ -58,15 +60,25 @@ func NewNVMLResourceManagers(infolib info.Interface, nvmllib nvml.Interface, dev
5860
if len(devices) == 0 {
5961
continue
6062
}
61-
r := &nvmlResourceManager{
62-
resourceManager: resourceManager{
63-
config: config,
64-
resource: resourceName,
65-
devices: devices,
66-
},
67-
nvml: nvmllib,
63+
64+
resources := resourceManager{
65+
config: config,
66+
resource: resourceName,
67+
devices: devices,
6868
}
69-
rms = append(rms, r)
69+
70+
var rm ResourceManager
71+
switch platform {
72+
case info.PlatformWSL:
73+
rm = &resources
74+
default:
75+
rm = &nvmlResourceManager{
76+
resourceManager: resources,
77+
nvml: nvmllib,
78+
}
79+
}
80+
81+
rms = append(rms, rm)
7082
}
7183

7284
return rms, nil

internal/rm/wsl_devices.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,31 @@
1616

1717
package rm
1818

19-
type wslDevice nvmlDevice
19+
type wslAllGPUsDevice struct{}
2020

21-
var _ deviceInfo = (*wslDevice)(nil)
21+
var _ deviceInfo = (*wslAllGPUsDevice)(nil)
2222

23-
// GetUUID returns the UUID of the device
24-
func (d wslDevice) GetUUID() (string, error) {
25-
return nvmlDevice(d).GetUUID()
23+
// GetUUID returns "all" to represent all GPUs accessible via /dev/dxg on WSL.
24+
func (d wslAllGPUsDevice) GetUUID() (string, error) {
25+
return "all", nil
2626
}
2727

28-
// GetPaths returns the paths for a tegra device.
29-
func (d wslDevice) GetPaths() ([]string, error) {
28+
// GetPaths returns the WSL GPU device path.
29+
func (d wslAllGPUsDevice) GetPaths() ([]string, error) {
3030
return []string{"/dev/dxg"}, nil
3131
}
3232

33-
// GetNumaNode returns the NUMA node associated with the GPU device
34-
func (d wslDevice) GetNumaNode() (bool, int, error) {
35-
return nvmlDevice(d).GetNumaNode()
33+
// GetNumaNode returns no NUMA node association for WSL devices.
34+
func (d wslAllGPUsDevice) GetNumaNode() (bool, int, error) {
35+
return false, 0, nil
3636
}
3737

38-
// GetTotalMemory returns the total memory available on the device.
39-
func (d wslDevice) GetTotalMemory() (uint64, error) {
40-
return nvmlDevice(d).GetTotalMemory()
38+
// GetTotalMemory returns 0 as memory info is not available for WSL devices.
39+
func (d wslAllGPUsDevice) GetTotalMemory() (uint64, error) {
40+
return 0, nil
4141
}
4242

43-
// GetComputeCapability returns the CUDA compute capability for the device.
44-
func (d wslDevice) GetComputeCapability() (string, error) {
45-
return nvmlDevice(d).GetComputeCapability()
43+
// GetComputeCapability returns an empty string as compute capability is not available for WSL devices.
44+
func (d wslAllGPUsDevice) GetComputeCapability() (string, error) {
45+
return "", nil
4646
}

0 commit comments

Comments
 (0)