Skip to content

Commit 0cdec1a

Browse files
authored
feat: add nullable bool defaulting to false (#758)
1 parent ee6571a commit 0cdec1a

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

sqlxx/types.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,56 @@ func (ns *NullBool) UnmarshalJSON(data []byte) error {
163163
return errors.WithStack(json.Unmarshal(data, &ns.Bool))
164164
}
165165

166+
// FalsyNullBool represents a bool that may be null.
167+
// It JSON decodes to false if null.
168+
//
169+
// swagger:type bool
170+
// swagger:model falsyNullBool
171+
type FalsyNullBool struct {
172+
Bool bool
173+
Valid bool // Valid is true if Bool is not NULL
174+
}
175+
176+
// Scan implements the Scanner interface.
177+
func (ns *FalsyNullBool) Scan(value interface{}) error {
178+
var d = sql.NullBool{}
179+
if err := d.Scan(value); err != nil {
180+
return err
181+
}
182+
183+
ns.Bool = d.Bool
184+
ns.Valid = d.Valid
185+
return nil
186+
}
187+
188+
// Value implements the driver Valuer interface.
189+
func (ns FalsyNullBool) Value() (driver.Value, error) {
190+
if !ns.Valid {
191+
return nil, nil
192+
}
193+
return ns.Bool, nil
194+
}
195+
196+
// MarshalJSON returns m as the JSON encoding of m.
197+
func (ns FalsyNullBool) MarshalJSON() ([]byte, error) {
198+
if !ns.Valid {
199+
return []byte("false"), nil
200+
}
201+
return json.Marshal(ns.Bool)
202+
}
203+
204+
// UnmarshalJSON sets *m to a copy of data.
205+
func (ns *FalsyNullBool) UnmarshalJSON(data []byte) error {
206+
if ns == nil {
207+
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
208+
}
209+
if len(data) == 0 || string(data) == "null" {
210+
return nil
211+
}
212+
ns.Valid = true
213+
return errors.WithStack(json.Unmarshal(data, &ns.Bool))
214+
}
215+
166216
// swagger:type string
167217
// swagger:model nullString
168218
type NullString string

sqlxx/types_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,42 @@ func TestNullBoolMarshalJSON(t *testing.T) {
6464
}
6565
}
6666

67+
func TestNullBoolDefaultFalseMarshalJSON(t *testing.T) {
68+
type outer struct {
69+
Bool *FalsyNullBool `json:"null_bool,omitempty"`
70+
}
71+
72+
for k, tc := range []struct {
73+
in *outer
74+
expected string
75+
}{
76+
{in: &outer{&FalsyNullBool{Valid: false, Bool: true}}, expected: "{\"null_bool\":false}"},
77+
{in: &outer{&FalsyNullBool{Valid: false, Bool: false}}, expected: "{\"null_bool\":false}"},
78+
{in: &outer{&FalsyNullBool{Valid: true, Bool: true}}, expected: "{\"null_bool\":true}"},
79+
{in: &outer{&FalsyNullBool{Valid: true, Bool: false}}, expected: "{\"null_bool\":false}"},
80+
{in: &outer{}, expected: "{}"},
81+
} {
82+
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
83+
out, err := json.Marshal(tc.in)
84+
require.NoError(t, err)
85+
assert.EqualValues(t, tc.expected, string(out))
86+
87+
var actual outer
88+
require.NoError(t, json.Unmarshal(out, &actual))
89+
if tc.in.Bool == nil {
90+
assert.Nil(t, actual.Bool)
91+
return
92+
} else if !tc.in.Bool.Valid {
93+
assert.False(t, actual.Bool.Bool)
94+
return
95+
}
96+
97+
assert.EqualValues(t, tc.in.Bool.Bool, actual.Bool.Bool)
98+
assert.EqualValues(t, tc.in.Bool.Valid, actual.Bool.Valid)
99+
})
100+
}
101+
}
102+
67103
func TestNullInt64MarshalJSON(t *testing.T) {
68104
type outer struct {
69105
Int64 *NullInt64 `json:"null_int,omitempty"`

0 commit comments

Comments
 (0)