Skip to content

Commit f626d43

Browse files
authored
Merge pull request #65 from mresvanis/update-mdev-type-name-parsing
Update mdev type parsing to handle multiple spaces in type strings
2 parents 532907f + 1c2d932 commit f626d43

2 files changed

Lines changed: 85 additions & 12 deletions

File tree

pkg/nvmdev/nvmdev.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,27 @@ func (m mdev) Type() (string, error) {
226226
if err != nil {
227227
return "", fmt.Errorf("unable to read mdev_type name for mdev %s: %v", m, err)
228228
}
229-
// file in the format: [NVIDIA|GRID] <vGPU type>
230-
mdevTypeStr := strings.TrimSpace(string(mdevType))
231-
mdevTypeSplit := strings.SplitN(mdevTypeStr, " ", 2)
232-
if len(mdevTypeSplit) != 2 {
233-
return "", fmt.Errorf("unable to parse mdev_type name %s for mdev %s", mdevTypeStr, m)
229+
typeName, err := parseMdevTypeName(string(mdevType))
230+
if err != nil {
231+
return "", fmt.Errorf("unable to parse mdev_type name for mdev %s: %v", m, err)
234232
}
235233

236-
return mdevTypeSplit[1], nil
234+
return typeName, nil
235+
}
236+
237+
// parseMdevTypeName extracts the vGPU type name from a string that may contain
238+
// product prefixes.
239+
// Examples:
240+
// - "NVIDIA A100-4C" -> "A100-4C".
241+
// - "NVIDIA RTX Pro 6000 Blackwell DC-48C" -> "DC-48C"
242+
func parseMdevTypeName(rawName string) (string, error) {
243+
nameStr := strings.TrimSpace(rawName)
244+
nameSplit := strings.Split(nameStr, " ")
245+
typeName := nameSplit[len(nameSplit)-1]
246+
if typeName == "" {
247+
return "", fmt.Errorf("unable to parse mdev_type name from: %s", rawName)
248+
}
249+
return typeName, nil
237250
}
238251

239252
func (m mdev) driver() (string, error) {
@@ -280,13 +293,10 @@ func (m *nvmdev) NewParentDevice(devicePath string) (*ParentDevice, error) {
280293
if err != nil {
281294
return nil, fmt.Errorf("unable to read file %s: %v", path, err)
282295
}
283-
// file in the format: [NVIDIA|GRID] <vGPU type>
284-
nameStr := strings.TrimSpace(string(name))
285-
nameSplit := strings.SplitN(nameStr, " ", 2)
286-
if len(nameSplit) != 2 {
287-
return nil, fmt.Errorf("unable to parse mdev_type name %s at path %s", nameStr, path)
296+
nameStr, err := parseMdevTypeName(string(name))
297+
if err != nil {
298+
return nil, fmt.Errorf("unable to parse mdev_type name at path %s: %v", path, err)
288299
}
289-
nameStr = nameSplit[len(nameSplit)-1]
290300

291301
mdevTypesMap[nameStr] = filepath.Dir(path)
292302
}

pkg/nvmdev/nvmdev_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,66 @@ func TestNvmdev(t *testing.T) {
6161
pf = mdevA100.GetPhysicalFunction()
6262
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
6363
}
64+
65+
func TestParseMdevTypeName(t *testing.T) {
66+
testCases := []struct {
67+
name string
68+
mdevTypeStr string
69+
expectedType string
70+
expectError bool
71+
}{
72+
{
73+
name: "NVIDIA prefix format",
74+
mdevTypeStr: "NVIDIA A100-4C",
75+
expectedType: "A100-4C",
76+
expectError: false,
77+
},
78+
{
79+
name: "GRID prefix format",
80+
mdevTypeStr: "GRID V100-8Q",
81+
expectedType: "V100-8Q",
82+
expectError: false,
83+
},
84+
{
85+
name: "Multi-word NVIDIA prefix format",
86+
mdevTypeStr: "NVIDIA RTX Pro 6000 Blackwell A100-4C",
87+
expectedType: "A100-4C",
88+
expectError: false,
89+
},
90+
{
91+
name: "Complex multi-word prefix",
92+
mdevTypeStr: "NVIDIA RTX A6000 Ada Generation H100-8C",
93+
expectedType: "H100-8C",
94+
expectError: false,
95+
},
96+
{
97+
name: "Single word only",
98+
mdevTypeStr: "A100-4C",
99+
expectedType: "A100-4C",
100+
expectError: false,
101+
},
102+
{
103+
name: "Empty string",
104+
mdevTypeStr: "",
105+
expectError: true,
106+
},
107+
{
108+
name: "Only spaces",
109+
mdevTypeStr: " ",
110+
expectError: true,
111+
},
112+
}
113+
114+
for _, tc := range testCases {
115+
t.Run(tc.name, func(t *testing.T) {
116+
actualType, err := parseMdevTypeName(tc.mdevTypeStr)
117+
118+
if tc.expectError {
119+
require.Error(t, err)
120+
} else {
121+
require.NoError(t, err)
122+
require.Equal(t, tc.expectedType, actualType)
123+
}
124+
})
125+
}
126+
}

0 commit comments

Comments
 (0)