Skip to content

Commit 993adc4

Browse files
committed
GODRIVER-3470 Correct BSON unmarshaling logic for null values (mongodb#1924)
1 parent 7825d6d commit 993adc4

File tree

4 files changed

+112
-2
lines changed

4 files changed

+112
-2
lines changed

bson/default_value_decoders.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,13 @@ func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Va
12211221
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
12221222
}
12231223

1224-
if vr.Type() == TypeNull {
1224+
// If BSON value is null and the go value is a pointer, then don't call
1225+
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
1226+
// non-nil), encountering null in BSON will result in the pointer being
1227+
// directly set to nil here. Since the pointer is being replaced with nil,
1228+
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
1229+
// to be called.
1230+
if vr.Type() == TypeNull && val.Kind() == reflect.Ptr {
12251231
val.Set(reflect.Zero(val.Type()))
12261232

12271233
return vr.ReadNull()

bson/unmarshal.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ type ValueUnmarshaler interface {
3636
}
3737

3838
// Unmarshal parses the BSON-encoded data and stores the result in the value
39-
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an error.
39+
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an
40+
// error.
41+
//
42+
// When unmarshaling BSON, if the BSON value is null and the Go value is a
43+
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
4044
func Unmarshal(data []byte, val interface{}) error {
4145
vr := newDocumentReader(bytes.NewReader(data))
4246
if l, err := vr.peekLength(); err != nil {

bson/unmarshal_value_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"go.mongodb.org/mongo-driver/v2/internal/assert"
16+
"go.mongodb.org/mongo-driver/v2/internal/require"
1617
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
1718
)
1819

@@ -33,6 +34,26 @@ func TestUnmarshalValue(t *testing.T) {
3334
}
3435
}
3536

37+
func TestInitializedPointerDataWithBSONNull(t *testing.T) {
38+
// Set up the test case with initialized pointers.
39+
tc := unmarshalBehaviorTestCase{
40+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{},
41+
BSONPtrTracker: &unmarshalBSONCallTracker{},
42+
}
43+
// Create BSON data where the '*_ptr_tracker' fields are explicitly set to
44+
// null.
45+
bytes := docToBytes(D{
46+
{Key: "bv_ptr_tracker", Value: nil},
47+
{Key: "b_ptr_tracker", Value: nil},
48+
})
49+
// Unmarshal the BSON data into the test case struct. This should set the
50+
// pointer fields to nil due to the BSON null value.
51+
err := Unmarshal(bytes, &tc)
52+
require.NoError(t, err)
53+
assert.Nil(t, tc.BSONValuePtrTracker)
54+
assert.Nil(t, tc.BSONPtrTracker)
55+
}
56+
3657
// tests covering GODRIVER-2779
3758
func BenchmarkSliceCodecUnmarshal(b *testing.B) {
3859
benchmarks := []struct {

bson/unmarshaling_cases_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,50 @@ func unmarshalingTestCases() []unmarshalingTestCase {
172172
want: &valNonPtrStruct,
173173
data: docToBytes(valNonPtrStruct),
174174
},
175+
{
176+
name: "nil pointer and non-pointer type with BSON minkey",
177+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
178+
want: &unmarshalBehaviorTestCase{
179+
BSONValueTracker: unmarshalBSONValueCallTracker{
180+
called: true,
181+
},
182+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
183+
called: true,
184+
},
185+
BSONTracker: unmarshalBSONCallTracker{
186+
called: true,
187+
},
188+
BSONPtrTracker: nil,
189+
},
190+
data: docToBytes(D{
191+
{Key: "bv_tracker", Value: MinKey{}},
192+
{Key: "bv_ptr_tracker", Value: MinKey{}},
193+
{Key: "b_tracker", Value: MinKey{}},
194+
{Key: "b_ptr_tracker", Value: MinKey{}},
195+
}),
196+
},
197+
{
198+
name: "nil pointer and non-pointer type with BSON maxkey",
199+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
200+
want: &unmarshalBehaviorTestCase{
201+
BSONValueTracker: unmarshalBSONValueCallTracker{
202+
called: true,
203+
},
204+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
205+
called: true,
206+
},
207+
BSONTracker: unmarshalBSONCallTracker{
208+
called: true,
209+
},
210+
BSONPtrTracker: nil,
211+
},
212+
data: docToBytes(D{
213+
{Key: "bv_tracker", Value: MaxKey{}},
214+
{Key: "bv_ptr_tracker", Value: MaxKey{}},
215+
{Key: "b_tracker", Value: MaxKey{}},
216+
{Key: "b_ptr_tracker", Value: MaxKey{}},
217+
}),
218+
},
175219
}
176220
}
177221

@@ -267,3 +311,38 @@ func (ms *myString) UnmarshalBSON(b []byte) error {
267311
*ms = myString(s)
268312
return nil
269313
}
314+
315+
// unmarshalBSONValueCallTracker is a test struct that tracks whether the
316+
// UnmarshalBSONValue method has been called.
317+
type unmarshalBSONValueCallTracker struct {
318+
called bool // called is set to true when UnmarshalBSONValue is invoked.
319+
}
320+
321+
var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{}
322+
323+
// unmarshalBSONCallTracker is a test struct that tracks whether the
324+
// UnmarshalBSON method has been called.
325+
type unmarshalBSONCallTracker struct {
326+
called bool // called is set to true when UnmarshalBSON is invoked.
327+
}
328+
329+
// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface.
330+
var _ Unmarshaler = &unmarshalBSONCallTracker{}
331+
332+
// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON
333+
// unmarshaling behavior.
334+
type unmarshalBehaviorTestCase struct {
335+
BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value.
336+
BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer.
337+
BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value.
338+
BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer.
339+
}
340+
341+
func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(byte, []byte) error {
342+
tracker.called = true
343+
return nil
344+
}
345+
func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error {
346+
tracker.called = true
347+
return nil
348+
}

0 commit comments

Comments
 (0)