Skip to content

Commit 977ad5b

Browse files
yongruilink8s-publishing-bot
authored andcommitted
Use IsZero instead of IsNil for union ratcheting check
- Use reflect.ValueOf(oldObj).IsZero() instead of IsNil() so union validation works with non-nilable T (e.g. value types) - Remove hasOldValue guard from inner loop conditionals; only check at the final ratcheting skip point - Add doc comments explaining T is "any" rather than "comparable" because union members can be slices - Add value-type subtests for Union and DiscriminatedUnion Co-authored-by: Tim Hockin <[email protected]> Kubernetes-commit: 3d39627cd9815f39fdf8243a60cd0768538e1b1f
1 parent a128230 commit 977ad5b

2 files changed

Lines changed: 74 additions & 32 deletions

File tree

pkg/api/validate/union.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ type UnionValidationOptions struct {
6161
// )...)
6262
// return errs
6363
// }
64+
//
65+
// Note that T is "any", rather than "comparable", because union-members can be
66+
// slices, meaning T might be a struct with a slice, meaning it is not
67+
// comparable.
6468
func Union[T any](_ context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj T, union *UnionMembership, isSetFns ...ExtractorFn[T, bool]) field.ErrorList {
6569
options := UnionValidationOptions{
6670
ErrorForEmpty: func(fldPath *field.Path, allFields []string) *field.Error {
@@ -99,6 +103,10 @@ func Union[T any](_ context.Context, op operation.Operation, fldPath *field.Path
99103
//
100104
// It is not an error for the discriminatorValue to be unknown. That must be
101105
// validated on its own.
106+
//
107+
// Note that T is "any", rather than "comparable", because union-members can be
108+
// slices, meaning T might be a struct with a slice, meaning it is not
109+
// comparable.
102110
func DiscriminatedUnion[T any, D ~string](_ context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj T, union *UnionMembership, discriminatorExtractor ExtractorFn[T, D], isSetFns ...ExtractorFn[T, bool]) (errs field.ErrorList) {
103111
if len(union.members) != len(isSetFns) {
104112
return field.ErrorList{
@@ -107,10 +115,10 @@ func DiscriminatedUnion[T any, D ~string](_ context.Context, op operation.Operat
107115
len(isSetFns), len(union.members))),
108116
}
109117
}
110-
hasOldValue := !reflect.ValueOf(oldObj).IsNil()
118+
hasOldValue := !reflect.ValueOf(oldObj).IsZero() // because T is any, rather than comparable
111119
var changed bool
112120
discriminatorValue := discriminatorExtractor(obj)
113-
if op.Type == operation.Update && hasOldValue {
121+
if op.Type == operation.Update {
114122
oldDiscriminatorValue := discriminatorExtractor(oldObj)
115123
changed = discriminatorValue != oldDiscriminatorValue
116124
}
@@ -119,7 +127,7 @@ func DiscriminatedUnion[T any, D ~string](_ context.Context, op operation.Operat
119127
member := union.members[i]
120128
isDiscriminatedMember := string(discriminatorValue) == member.discriminatorValue
121129
newIsSet := fieldIsSet(obj)
122-
if op.Type == operation.Update && hasOldValue && !changed {
130+
if op.Type == operation.Update && !changed {
123131
oldIsSet := fieldIsSet(oldObj)
124132
changed = changed || newIsSet != oldIsSet
125133
}
@@ -197,12 +205,12 @@ func unionValidate[T any](op operation.Operation, fldPath *field.Path,
197205
}
198206
}
199207

200-
hasOldValue := !reflect.ValueOf(oldObj).IsNil()
208+
hasOldValue := !reflect.ValueOf(oldObj).IsZero() // because T is any, rather than comparable
201209
var specifiedFields []string
202210
var changed bool
203211
for i, fieldIsSet := range isSetFns {
204212
newIsSet := fieldIsSet(obj)
205-
if op.Type == operation.Update && hasOldValue && !changed {
213+
if op.Type == operation.Update && !changed {
206214
oldIsSet := fieldIsSet(oldObj)
207215
changed = changed || newIsSet != oldIsSet
208216
}

pkg/api/validate/union_test.go

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,34 @@ func TestUnion(t *testing.T) {
6767
members = append(members, NewUnionMember(f))
6868
}
6969

70-
// Create mock extractors that return predefined values instead of
71-
// actually extracting from the object.
72-
extractors := make([]ExtractorFn[*testMember, bool], len(tc.fieldValues))
73-
for i, val := range tc.fieldValues {
74-
extractors[i] = func(_ *testMember) bool { return val }
75-
}
70+
t.Run("pointer", func(t *testing.T) {
71+
// Create mock extractors that return predefined values instead of
72+
// actually extracting from the object.
73+
extractors := make([]ExtractorFn[*testMember, bool], len(tc.fieldValues))
74+
for i, val := range tc.fieldValues {
75+
extractors[i] = func(_ *testMember) bool { return val }
76+
}
7677

77-
got := Union(context.Background(), operation.Operation{}, nil, &testMember{}, nil,
78-
NewUnionMembership(members...), extractors...)
79-
if !reflect.DeepEqual(got, tc.expected) {
80-
t.Errorf("got %v want %v", got, tc.expected)
81-
}
78+
got := Union(context.Background(), operation.Operation{}, nil, &testMember{}, nil,
79+
NewUnionMembership(members...), extractors...)
80+
if !reflect.DeepEqual(got, tc.expected) {
81+
t.Errorf("got %v want %v", got, tc.expected)
82+
}
83+
})
84+
t.Run("value", func(t *testing.T) {
85+
// Create mock extractors that return predefined values instead of
86+
// actually extracting from the object.
87+
extractors := make([]ExtractorFn[testMember, bool], len(tc.fieldValues))
88+
for i, val := range tc.fieldValues {
89+
extractors[i] = func(_ testMember) bool { return val }
90+
}
91+
92+
got := Union(context.Background(), operation.Operation{}, nil, testMember{}, testMember{},
93+
NewUnionMembership(members...), extractors...)
94+
if !reflect.DeepEqual(got, tc.expected) {
95+
t.Errorf("got %v want %v", got, tc.expected)
96+
}
97+
})
8298
})
8399
}
84100
}
@@ -131,26 +147,44 @@ func TestDiscriminatedUnion(t *testing.T) {
131147
}
132148

133149
for _, tc := range testCases {
150+
members := []UnionMember{}
151+
for _, f := range tc.fields {
152+
members = append(members, NewDiscriminatedUnionMember(f[0], f[1]))
153+
}
154+
134155
t.Run(tc.name, func(t *testing.T) {
135-
members := []UnionMember{}
136-
for _, f := range tc.fields {
137-
members = append(members, NewDiscriminatedUnionMember(f[0], f[1]))
138-
}
156+
t.Run("pointer", func(t *testing.T) {
157+
discriminatorExtractor := func(_ *testMember) string { return tc.discriminatorValue }
139158

140-
discriminatorExtractor := func(_ *testMember) string { return tc.discriminatorValue }
159+
// Create mock extractors that return predefined values instead of
160+
// actually extracting from the object.
161+
extractors := make([]ExtractorFn[*testMember, bool], len(tc.fieldValues))
162+
for i, val := range tc.fieldValues {
163+
extractors[i] = func(_ *testMember) bool { return val }
164+
}
141165

142-
// Create mock extractors that return predefined values instead of
143-
// actually extracting from the object.
144-
extractors := make([]ExtractorFn[*testMember, bool], len(tc.fieldValues))
145-
for i, val := range tc.fieldValues {
146-
extractors[i] = func(_ *testMember) bool { return val }
147-
}
166+
got := DiscriminatedUnion(context.Background(), operation.Operation{}, nil, &testMember{}, nil,
167+
NewDiscriminatedUnionMembership(tc.discriminatorField, members...), discriminatorExtractor, extractors...)
168+
if !reflect.DeepEqual(got, tc.expected) {
169+
t.Errorf("got %v want %v", got.ToAggregate(), tc.expected.ToAggregate())
170+
}
171+
})
172+
t.Run("value", func(t *testing.T) {
173+
discriminatorExtractor := func(_ testMember) string { return tc.discriminatorValue }
148174

149-
got := DiscriminatedUnion(context.Background(), operation.Operation{}, nil, &testMember{}, nil,
150-
NewDiscriminatedUnionMembership(tc.discriminatorField, members...), discriminatorExtractor, extractors...)
151-
if !reflect.DeepEqual(got, tc.expected) {
152-
t.Errorf("got %v want %v", got.ToAggregate(), tc.expected.ToAggregate())
153-
}
175+
// Create mock extractors that return predefined values instead of
176+
// actually extracting from the object.
177+
extractors := make([]ExtractorFn[testMember, bool], len(tc.fieldValues))
178+
for i, val := range tc.fieldValues {
179+
extractors[i] = func(_ testMember) bool { return val }
180+
}
181+
182+
got := DiscriminatedUnion(context.Background(), operation.Operation{}, nil, testMember{}, testMember{},
183+
NewDiscriminatedUnionMembership(tc.discriminatorField, members...), discriminatorExtractor, extractors...)
184+
if !reflect.DeepEqual(got, tc.expected) {
185+
t.Errorf("got %v want %v", got.ToAggregate(), tc.expected.ToAggregate())
186+
}
187+
})
154188
})
155189
}
156190
}

0 commit comments

Comments
 (0)