Skip to content

Commit a044d53

Browse files
committed
Fix unmarshalling of registered types
This is a faithful backport of #53 --- In 7d3d258 I inadvterantly removed support for json unmarshalling for the case when a type implements a protobuf message but is a type that is registered. For types that are registered through the `Register` function typeurl is supposed to ignore the interfaces that the type implements and just use json. This change restores that behavior. Signed-off-by: Brian Goff <[email protected]>
1 parent 1666bdb commit a044d53

File tree

3 files changed

+42
-18
lines changed

3 files changed

+42
-18
lines changed

types.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type handler interface {
3939
Marshaller(interface{}) func() ([]byte, error)
4040
Unmarshaller(interface{}) func([]byte) error
4141
TypeURL(interface{}) string
42-
GetType(url string) reflect.Type
42+
GetType(url string) (reflect.Type, bool)
4343
}
4444

4545
// Definitions of common error types used throughout typeurl.
@@ -240,7 +240,7 @@ func MarshalAnyToProto(from interface{}) (*anypb.Any, error) {
240240
}
241241

242242
func unmarshal(typeURL string, value []byte, v interface{}) (interface{}, error) {
243-
t, err := getTypeByUrl(typeURL)
243+
t, isProto, err := getTypeByUrl(typeURL)
244244
if err != nil {
245245
return nil, err
246246
}
@@ -258,43 +258,45 @@ func unmarshal(typeURL string, value []byte, v interface{}) (interface{}, error)
258258
}
259259
}
260260

261-
pm, ok := v.(proto.Message)
262-
if ok {
263-
return v, proto.Unmarshal(value, pm)
264-
}
261+
if isProto {
262+
pm, ok := v.(proto.Message)
263+
if ok {
264+
return v, proto.Unmarshal(value, pm)
265+
}
265266

266-
for _, h := range handlers {
267-
if unmarshal := h.Unmarshaller(v); unmarshal != nil {
268-
return v, unmarshal(value)
267+
for _, h := range handlers {
268+
if unmarshal := h.Unmarshaller(v); unmarshal != nil {
269+
return v, unmarshal(value)
270+
}
269271
}
270272
}
271273

272274
// fallback to json unmarshaller
273275
return v, json.Unmarshal(value, v)
274276
}
275277

276-
func getTypeByUrl(url string) (reflect.Type, error) {
278+
func getTypeByUrl(url string) (_ reflect.Type, isProto bool, _ error) {
277279
mu.RLock()
278280
for t, u := range registry {
279281
if u == url {
280282
mu.RUnlock()
281-
return t, nil
283+
return t, false, nil
282284
}
283285
}
284286
mu.RUnlock()
285287
mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
286288
if err != nil {
287289
if errors.Is(err, protoregistry.NotFound) {
288290
for _, h := range handlers {
289-
if t := h.GetType(url); t != nil {
290-
return t, nil
291+
if t, isProto := h.GetType(url); t != nil {
292+
return t, isProto, nil
291293
}
292294
}
293295
}
294-
return nil, fmt.Errorf("type with url %s: %w", url, ErrNotFound)
296+
return nil, false, fmt.Errorf("type with url %s: %w", url, ErrNotFound)
295297
}
296298
empty := mt.New().Interface()
297-
return reflect.TypeOf(empty).Elem(), nil
299+
return reflect.TypeOf(empty).Elem(), true, nil
298300
}
299301

300302
func tryDereference(v interface{}) reflect.Type {

types_gogo.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ func (gogoHandler) TypeURL(v interface{}) string {
5959
return gogoproto.MessageName(pm)
6060
}
6161

62-
func (gogoHandler) GetType(url string) reflect.Type {
62+
func (gogoHandler) GetType(url string) (reflect.Type, bool) {
6363
t := gogoproto.MessageType(url)
6464
if t == nil {
65-
return nil
65+
return nil, false
6666
}
67-
return t.Elem()
67+
return t.Elem(), true
6868
}

types_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package typeurl
1818

1919
import (
2020
"bytes"
21+
"encoding/json"
2122
"errors"
2223
"reflect"
2324
"testing"
@@ -241,3 +242,24 @@ func TestUnmarshalNotFound(t *testing.T) {
241242
t.Fatalf("unexpected error unmarshalling type which does not exist: %v", err)
242243
}
243244
}
245+
246+
func TestUnmarshalJSON(t *testing.T) {
247+
url := t.Name()
248+
Register(&timestamppb.Timestamp{}, url)
249+
250+
expected := timestamppb.Now()
251+
252+
dt, err := json.Marshal(expected)
253+
if err != nil {
254+
t.Fatal(err)
255+
}
256+
257+
var actual timestamppb.Timestamp
258+
if err := UnmarshalToByTypeURL(url, dt, &actual); err != nil {
259+
t.Fatal(err)
260+
}
261+
262+
if !expected.AsTime().Equal(actual.AsTime()) {
263+
t.Fatalf("expected value to be %q, got: %q", expected.AsTime(), actual.AsTime())
264+
}
265+
}

0 commit comments

Comments
 (0)