From 4f0f32494deece5b415af9e280b38873dbe6bead Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Fri, 26 Jul 2019 16:50:06 -0700 Subject: [PATCH] all: fix reflect.Value.Interface races The reflect.Value.Interface method shallow copies the underlying value, which may copy mutexes and atomically-accessed fields. Some usages of the Interface method is only to check if the interface value implements an interface. In which case the shallow copy was unnecessary. Change those usages to use the reflect.Value.Implements method instead. Fixes #838 --- go.mod | 2 ++ jsonpb/jsonpb.go | 12 +++++++++--- proto/all_test.go | 32 ++++++++++++++++++++++++++++++++ proto/text.go | 6 ++++-- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index eccf7fd9cc..de28f6f0aa 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/golang/protobuf + +go 1.12 diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go index e9cc202585..f0d66befbb 100644 --- a/jsonpb/jsonpb.go +++ b/jsonpb/jsonpb.go @@ -165,6 +165,11 @@ type wkt interface { XXX_WellKnownType() string } +var ( + wktType = reflect.TypeOf((*wkt)(nil)).Elem() + messageType = reflect.TypeOf((*proto.Message)(nil)).Elem() +) + // marshalObject writes a struct to the Writer. func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error { if jsm, ok := v.(JSONPBMarshaler); ok { @@ -531,7 +536,8 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle // Handle well-known types. // Most are handled up in marshalObject (because 99% are messages). - if wkt, ok := v.Interface().(wkt); ok { + if v.Type().Implements(wktType) { + wkt := v.Interface().(wkt) switch wkt.XXX_WellKnownType() { case "NullValue": out.write("null") @@ -1277,8 +1283,8 @@ func checkRequiredFields(pb proto.Message) error { } func checkRequiredFieldsInValue(v reflect.Value) error { - if pm, ok := v.Interface().(proto.Message); ok { - return checkRequiredFields(pm) + if v.Type().Implements(messageType) { + return checkRequiredFields(v.Interface().(proto.Message)) } return nil } diff --git a/proto/all_test.go b/proto/all_test.go index 1bea4b6e8e..a7ed2521ea 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -45,9 +45,11 @@ import ( "testing" "time" + "github.com/golang/protobuf/jsonpb" . "github.com/golang/protobuf/proto" pb3 "github.com/golang/protobuf/proto/proto3_proto" . "github.com/golang/protobuf/proto/test_proto" + descriptorpb "github.com/golang/protobuf/protoc-gen-go/descriptor" ) var globalO *Buffer @@ -2490,3 +2492,33 @@ func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) { p2.Unmarshal(pbd) } } + +// TestRace tests whether there are races among the different marshalers. +func TestRace(t *testing.T) { + m := &descriptorpb.FileDescriptorProto{ + Options: &descriptorpb.FileOptions{ + GoPackage: String("path/to/my/package"), + }, + } + + wg := &sync.WaitGroup{} + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + Marshal(m) + }() + + wg.Add(1) + go func() { + defer wg.Done() + (&jsonpb.Marshaler{}).MarshalToString(m) + }() + + wg.Add(1) + go func() { + defer wg.Done() + m.String() + }() +} diff --git a/proto/text.go b/proto/text.go index 1aaee725b4..d97f9b3563 100644 --- a/proto/text.go +++ b/proto/text.go @@ -456,6 +456,8 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { return nil } +var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + // writeAny writes an arbitrary field. func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error { v = reflect.Indirect(v) @@ -519,8 +521,8 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert // mutating this value. v = v.Addr() } - if etm, ok := v.Interface().(encoding.TextMarshaler); ok { - text, err := etm.MarshalText() + if v.Type().Implements(textMarshalerType) { + text, err := v.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return err }