From a84f0f865b36d196672c7abdeb43392ef90896da Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Wed, 19 Feb 2025 13:29:39 -0800 Subject: [PATCH] GODRIVER-3470 Correct BSON unmarshaling logic for null values (#1924) [master] (#1945) Co-authored-by: Preston Vasquez (cherry picked from commit 25df82ff2956607022f725f544f10261dd08796d) --- bson/default_value_decoders.go | 8 ++- bson/unmarshal.go | 6 +- bson/unmarshal_value_test.go | 21 +++++++ bson/unmarshaling_cases_test.go | 100 ++++++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 2 deletions(-) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 2f195329ca..f20871041b 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -1166,7 +1166,13 @@ func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Va return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - if vr.Type() == TypeNull { + // If BSON value is null and the go value is a pointer, then don't call + // UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e., + // non-nil), encountering null in BSON will result in the pointer being + // directly set to nil here. Since the pointer is being replaced with nil, + // there is no opportunity (or reason) for the custom UnmarshalBSONValue logic + // to be called. + if vr.Type() == TypeNull && val.Kind() == reflect.Ptr { val.Set(reflect.Zero(val.Type())) return vr.ReadNull() diff --git a/bson/unmarshal.go b/bson/unmarshal.go index b1089fca9a..72870c10ab 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -36,7 +36,11 @@ type ValueUnmarshaler interface { } // Unmarshal parses the BSON-encoded data and stores the result in the value -// pointed to by val. If val is nil or not a pointer, Unmarshal returns an error. +// pointed to by val. If val is nil or not a pointer, Unmarshal returns an +// error. +// +// When unmarshaling BSON, if the BSON value is null and the Go value is a +// pointer, the pointer is set to nil without calling UnmarshalBSONValue. func Unmarshal(data []byte, val interface{}) error { vr := newDocumentReader(bytes.NewReader(data)) if l, err := vr.peekLength(); err != nil { diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index ffaea010c9..ca6ab6d125 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -13,6 +13,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -39,6 +40,26 @@ func TestUnmarshalValue(t *testing.T) { }) } +func TestInitializedPointerDataWithBSONNull(t *testing.T) { + // Set up the test case with initialized pointers. + tc := unmarshalBehaviorTestCase{ + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{}, + BSONPtrTracker: &unmarshalBSONCallTracker{}, + } + // Create BSON data where the '*_ptr_tracker' fields are explicitly set to + // null. + bytes := docToBytes(D{ + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }) + // Unmarshal the BSON data into the test case struct. This should set the + // pointer fields to nil due to the BSON null value. + err := Unmarshal(bytes, &tc) + require.NoError(t, err) + assert.Nil(t, tc.BSONValuePtrTracker) + assert.Nil(t, tc.BSONPtrTracker) +} + // tests covering GODRIVER-2779 func BenchmarkSliceCodecUnmarshal(b *testing.B) { benchmarks := []struct { diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index 71d22f32d6..6e84f80e28 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -172,6 +172,70 @@ func unmarshalingTestCases() []unmarshalingTestCase { want: &valNonPtrStruct, data: docToBytes(valNonPtrStruct), }, + { + name: "nil pointer and non-pointer type with literal null BSON", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: nil, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: nil}, + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }), + }, + { + name: "nil pointer and non-pointer type with BSON minkey", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{ + called: true, + }, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: MinKey{}}, + {Key: "bv_ptr_tracker", Value: MinKey{}}, + {Key: "b_tracker", Value: MinKey{}}, + {Key: "b_ptr_tracker", Value: MinKey{}}, + }), + }, + { + name: "nil pointer and non-pointer type with BSON maxkey", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{ + called: true, + }, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: MaxKey{}}, + {Key: "bv_ptr_tracker", Value: MaxKey{}}, + {Key: "b_tracker", Value: MaxKey{}}, + {Key: "b_ptr_tracker", Value: MaxKey{}}, + }), + }, } } @@ -267,3 +331,39 @@ func (ms *myString) UnmarshalBSON(b []byte) error { *ms = myString(s) return nil } + +// unmarshalBSONValueCallTracker is a test struct that tracks whether the +// UnmarshalBSONValue method has been called. +type unmarshalBSONValueCallTracker struct { + called bool // called is set to true when UnmarshalBSONValue is invoked. +} + +var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{} + +// unmarshalBSONCallTracker is a test struct that tracks whether the +// UnmarshalBSON method has been called. +type unmarshalBSONCallTracker struct { + called bool // called is set to true when UnmarshalBSON is invoked. +} + +// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface. +var _ Unmarshaler = &unmarshalBSONCallTracker{} + +// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON +// unmarshaling behavior. +type unmarshalBehaviorTestCase struct { + BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value. + BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer. + BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value. + BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer. +} + +func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(byte, []byte) error { + tracker.called = true + return nil +} + +func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error { + tracker.called = true + return nil +}