Skip to content

Commit b067477

Browse files
committed
Fix race condition between net deadline and context timeout in handshake error handling, improve SDAM error handling tests, add handshake cancellation error test.
1 parent ddacfd0 commit b067477

File tree

7 files changed

+183
-67
lines changed

7 files changed

+183
-67
lines changed

mongo/integration/primary_stepdown_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package integration
88

99
import (
10+
"sync"
1011
"testing"
1112

1213
"go.mongodb.org/mongo-driver/bson"
@@ -23,7 +24,55 @@ const (
2324
errorInterruptedAtShutdown int32 = 11600
2425
)
2526

27+
// testPoolMonitor exposes an *event.PoolMonitor and collects all events logged to that
28+
// *event.PoolMonitor. It is safe to use from multiple concurrent goroutines.
29+
type testPoolMonitor struct {
30+
*event.PoolMonitor
31+
32+
events []*event.PoolEvent
33+
mu sync.RWMutex
34+
}
35+
36+
func newTestPoolMonitor() *testPoolMonitor {
37+
tpm := &testPoolMonitor{
38+
events: make([]*event.PoolEvent, 0),
39+
}
40+
tpm.PoolMonitor = &event.PoolMonitor{
41+
Event: func(evt *event.PoolEvent) {
42+
tpm.mu.Lock()
43+
defer tpm.mu.Unlock()
44+
tpm.events = append(tpm.events, evt)
45+
},
46+
}
47+
return tpm
48+
}
49+
50+
// Events returns a copy of the events collected by the testPoolMonitor. Filters can optionally be
51+
// applied to the returned events set and are applied using AND logic (i.e. all filters must return
52+
// true to include the event in the result).
53+
func (tpm *testPoolMonitor) Events(filters ...func(*event.PoolEvent) bool) []*event.PoolEvent {
54+
filtered := make([]*event.PoolEvent, 0, len(tpm.events))
55+
tpm.mu.RLock()
56+
defer tpm.mu.RUnlock()
57+
58+
for _, evt := range tpm.events {
59+
keep := true
60+
for _, filter := range filters {
61+
if !filter(evt) {
62+
keep = false
63+
}
64+
}
65+
if keep {
66+
filtered = append(filtered, evt)
67+
}
68+
}
69+
70+
return filtered
71+
}
72+
2673
var poolChan = make(chan *event.PoolEvent, 100)
74+
75+
// TODO(GODRIVER-2068): Replace all uses of poolMonitor with individual instances of testPoolMonitor.
2776
var poolMonitor = &event.PoolMonitor{
2877
Event: func(event *event.PoolEvent) {
2978
poolChan <- event

mongo/integration/sdam_error_handling_test.go

Lines changed: 89 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"go.mongodb.org/mongo-driver/bson"
19+
"go.mongodb.org/mongo-driver/event"
1920
"go.mongodb.org/mongo-driver/internal/testutil/assert"
2021
"go.mongodb.org/mongo-driver/mongo"
2122
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
@@ -28,7 +29,6 @@ func TestSDAMErrorHandling(t *testing.T) {
2829
return options.Client().
2930
ApplyURI(mtest.ClusterURI()).
3031
SetRetryWrites(false).
31-
SetPoolMonitor(poolMonitor).
3232
SetWriteConcern(mtest.MajorityWc)
3333
}
3434
baseMtOpts := func() *mtest.Options {
@@ -71,13 +71,9 @@ func TestSDAMErrorHandling(t *testing.T) {
7171
},
7272
})
7373

74-
// Reset the client with the appName specified in the failpoint.
75-
clientOpts := options.Client().
76-
SetAppName(appName).
77-
SetRetryWrites(false).
78-
SetPoolMonitor(poolMonitor)
79-
mt.ResetClient(clientOpts)
80-
clearPoolChan()
74+
// Reset the client with the appName specified in the failpoint and the pool monitor.
75+
tpm := newTestPoolMonitor()
76+
mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor))
8177

8278
// Use a context with a 100ms timeout so that the saslContinue delay of 150ms causes
8379
// an operation-scoped context timeout (i.e. a timeout not caused by a client timeout
@@ -88,7 +84,11 @@ func TestSDAMErrorHandling(t *testing.T) {
8884
assert.NotNil(mt, err, "expected InsertOne error, got nil")
8985
assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
9086
assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
91-
assert.False(mt, isPoolCleared(), "expected pool not to be cleared but was cleared")
87+
88+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
89+
return evt.Type == event.PoolCleared
90+
})
91+
assert.True(mt, len(poolClearedEvents) == 0, "expected pool not to be cleared but was cleared")
9292
})
9393

9494
mt.Run("pool cleared on non-operation-scoped network timeout", func(mt *mtest.T) {
@@ -112,24 +112,26 @@ func TestSDAMErrorHandling(t *testing.T) {
112112
},
113113
})
114114

115-
// Reset the client with the appName specified in the failpoint.
116-
clientOpts := options.Client().
115+
// Reset the client with the appName specified in the failpoint and the pool monitor.
116+
tpm := newTestPoolMonitor()
117+
mt.ResetClient(baseClientOpts().
117118
SetAppName(appName).
118-
SetRetryWrites(false).
119-
SetPoolMonitor(poolMonitor).
119+
SetPoolMonitor(tpm.PoolMonitor).
120120
// Set a 100ms socket timeout so that the saslContinue delay of 150ms causes a
121121
// timeout during socket read (i.e. a timeout not caused by the InsertOne context).
122-
SetSocketTimeout(100 * time.Millisecond)
123-
mt.ResetClient(clientOpts)
124-
clearPoolChan()
122+
SetSocketTimeout(100 * time.Millisecond))
125123

126124
// Use context.Background() so that the new connection will not time out due to an
127125
// operation-scoped timeout.
128126
_, err := mt.Coll.InsertOne(context.Background(), bson.D{{"test", 1}})
129127
assert.NotNil(mt, err, "expected InsertOne error, got nil")
130128
assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
131129
assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
132-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
130+
131+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
132+
return evt.Type == event.PoolCleared
133+
})
134+
assert.True(mt, len(poolClearedEvents) > 0, "expected pool to be cleared but was not")
133135
})
134136

135137
mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
@@ -150,15 +152,20 @@ func TestSDAMErrorHandling(t *testing.T) {
150152
},
151153
})
152154

153-
clientOpts := options.Client().
155+
// Reset the client with the appName specified in the failpoint.
156+
tpm := newTestPoolMonitor()
157+
mt.ResetClient(baseClientOpts().
154158
SetAppName(appName).
155-
SetMinPoolSize(5).
156-
SetPoolMonitor(poolMonitor)
157-
mt.ResetClient(clientOpts)
158-
clearPoolChan()
159+
SetPoolMonitor(tpm.PoolMonitor).
160+
// Set minPoolSize to enable the background pool maintenance goroutine.
161+
SetMinPoolSize(5))
159162

160163
time.Sleep(200 * time.Millisecond)
161-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
164+
165+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
166+
return evt.Type == event.PoolCleared
167+
})
168+
assert.True(mt, len(poolClearedEvents) > 0, "expected pool to be cleared but was not")
162169
})
163170

164171
mt.Run("foreground", func(mt *mtest.T) {
@@ -178,24 +185,27 @@ func TestSDAMErrorHandling(t *testing.T) {
178185
},
179186
})
180187

181-
clientOpts := options.Client().
182-
SetAppName(appName).
183-
SetPoolMonitor(poolMonitor)
184-
mt.ResetClient(clientOpts)
185-
clearPoolChan()
188+
// Reset the client with the appName specified in the failpoint.
189+
tpm := newTestPoolMonitor()
190+
mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor))
186191

187192
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
188193
assert.NotNil(mt, err, "expected InsertOne error, got nil")
189194
assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err)
190-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
195+
196+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
197+
return evt.Type == event.PoolCleared
198+
})
199+
assert.True(mt, len(poolClearedEvents) > 0, "expected pool to be cleared but was not")
191200
})
192201
})
193202
})
194203
})
195204
mt.RunOpts("after handshake completes", baseMtOpts(), func(mt *mtest.T) {
196205
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
197206
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
198-
clearPoolChan()
207+
appName := "afterHandshakeNetworkError"
208+
199209
mt.SetFailPoint(mtest.FailPoint{
200210
ConfigureFailPoint: "failCommand",
201211
Mode: mtest.FailPointMode{
@@ -204,16 +214,26 @@ func TestSDAMErrorHandling(t *testing.T) {
204214
Data: mtest.FailPointData{
205215
FailCommands: []string{"insert"},
206216
CloseConnection: true,
217+
AppName: appName,
207218
},
208219
})
209220

221+
// Reset the client with the appName specified in the failpoint.
222+
tpm := newTestPoolMonitor()
223+
mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor))
224+
210225
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
211226
assert.NotNil(mt, err, "expected InsertOne error, got nil")
212227
assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err)
213-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
228+
229+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
230+
return evt.Type == event.PoolCleared
231+
})
232+
assert.True(mt, len(poolClearedEvents) > 0, "expected pool to be cleared but was not")
214233
})
215234
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
216-
clearPoolChan()
235+
tpm := newTestPoolMonitor()
236+
mt.ResetClient(baseClientOpts().SetPoolMonitor(tpm.PoolMonitor))
217237

218238
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
219239
assert.Nil(mt, err, "InsertOne error: %v", err)
@@ -227,10 +247,14 @@ func TestSDAMErrorHandling(t *testing.T) {
227247
assert.NotNil(mt, err, "expected Find error, got %v", err)
228248
assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
229249

230-
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
250+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
251+
return evt.Type == event.PoolCleared
252+
})
253+
assert.True(mt, len(poolClearedEvents) == 0, "expected pool to not be cleared but was")
231254
})
232255
mt.Run("pool not cleared on context cancellation", func(mt *mtest.T) {
233-
clearPoolChan()
256+
tpm := newTestPoolMonitor()
257+
mt.ResetClient(baseClientOpts().SetPoolMonitor(tpm.PoolMonitor))
234258

235259
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
236260
assert.Nil(mt, err, "InsertOne error: %v", err)
@@ -250,7 +274,10 @@ func TestSDAMErrorHandling(t *testing.T) {
250274
assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
251275
assert.True(mt, errors.Is(err, context.Canceled), "expected error %v to be context.Canceled", err)
252276

253-
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
277+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
278+
return evt.Type == event.PoolCleared
279+
})
280+
assert.True(mt, len(poolClearedEvents) == 0, "expected pool to not be cleared but was")
254281
})
255282
})
256283
mt.RunOpts("server errors", noClientOpts, func(mt *mtest.T) {
@@ -287,28 +314,32 @@ func TestSDAMErrorHandling(t *testing.T) {
287314
}
288315
for _, tc := range testCases {
289316
mt.RunOpts(fmt.Sprintf("command error - %s", tc.name), serverErrorsMtOpts, func(mt *mtest.T) {
290-
clearPoolChan()
317+
appName := fmt.Sprintf("command_error_%s", tc.name)
291318

292319
// Cause the next insert to fail with an ok:0 response.
293-
fp := mtest.FailPoint{
320+
mt.SetFailPoint(mtest.FailPoint{
294321
ConfigureFailPoint: "failCommand",
295322
Mode: mtest.FailPointMode{
296323
Times: 1,
297324
},
298325
Data: mtest.FailPointData{
299326
FailCommands: []string{"insert"},
300327
ErrorCode: tc.errorCode,
328+
AppName: appName,
301329
},
302-
}
303-
mt.SetFailPoint(fp)
330+
})
331+
332+
// Reset the client with the appName specified in the failpoint.
333+
tpm := newTestPoolMonitor()
334+
mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor))
304335

305-
runServerErrorsTest(mt, tc.isShutdownError)
336+
runServerErrorsTest(mt, tc.isShutdownError, tpm)
306337
})
307338
mt.RunOpts(fmt.Sprintf("write concern error - %s", tc.name), serverErrorsMtOpts, func(mt *mtest.T) {
308-
clearPoolChan()
339+
appName := fmt.Sprintf("write_concern_error_%s", tc.name)
309340

310341
// Cause the next insert to fail with a write concern error.
311-
fp := mtest.FailPoint{
342+
mt.SetFailPoint(mtest.FailPoint{
312343
ConfigureFailPoint: "failCommand",
313344
Mode: mtest.FailPointMode{
314345
Times: 1,
@@ -318,32 +349,40 @@ func TestSDAMErrorHandling(t *testing.T) {
318349
WriteConcernError: &mtest.WriteConcernErrorData{
319350
Code: tc.errorCode,
320351
},
352+
AppName: appName,
321353
},
322-
}
323-
mt.SetFailPoint(fp)
354+
})
324355

325-
runServerErrorsTest(mt, tc.isShutdownError)
356+
// Reset the client with the appName specified in the failpoint.
357+
tpm := newTestPoolMonitor()
358+
mt.ResetClient(baseClientOpts().SetAppName(appName).SetPoolMonitor(tpm.PoolMonitor))
359+
360+
runServerErrorsTest(mt, tc.isShutdownError, tpm)
326361
})
327362
}
328363
})
329364
})
330365
}
331366

332-
func runServerErrorsTest(mt *mtest.T, isShutdownError bool) {
367+
func runServerErrorsTest(mt *mtest.T, isShutdownError bool, tpm *testPoolMonitor) {
333368
mt.Helper()
334369

335370
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
336371
assert.NotNil(mt, err, "expected InsertOne error, got nil")
337372

373+
poolClearedEvents := tpm.Events(func(evt *event.PoolEvent) bool {
374+
return evt.Type == event.PoolCleared
375+
})
376+
isPoolCleared := len(poolClearedEvents) > 0
377+
338378
// The pool should always be cleared for shutdown errors, regardless of server version.
339379
if isShutdownError {
340-
assert.True(mt, isPoolCleared(), "expected pool to be cleared, but was not")
380+
assert.True(mt, isPoolCleared, "expected pool to be cleared, but was not")
341381
return
342382
}
343383

344384
// For non-shutdown errors, the pool is only cleared if the error is from a pre-4.2 server.
345385
wantCleared := mtest.CompareServerVersions(mtest.ServerVersion(), "4.2") < 0
346-
gotCleared := isPoolCleared()
347-
assert.Equal(mt, wantCleared, gotCleared, "expected pool to be cleared: %v; pool was cleared: %v",
348-
wantCleared, gotCleared)
386+
assert.Equal(mt, wantCleared, isPoolCleared, "expected pool to be cleared: %t; pool was cleared: %t",
387+
wantCleared, isPoolCleared)
349388
}

x/mongo/driver/topology/connection.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
101101
return c, nil
102102
}
103103

104-
func (c *connection) processInitializationError(err, ctxErr error) {
104+
func (c *connection) processInitializationError(err error, opCtx context.Context) {
105105
atomic.StoreInt32(&c.connected, disconnected)
106106
if c.nc != nil {
107107
_ = c.nc.Close()
108108
}
109109

110110
c.connectErr = ConnectionError{Wrapped: err, init: true}
111111
if c.config.errorHandlingCallback != nil {
112-
c.config.errorHandlingCallback(c.connectErr, ctxErr, c.generation, c.desc.ServiceID)
112+
c.config.errorHandlingCallback(c.connectErr, opCtx, c.generation, c.desc.ServiceID)
113113
}
114114
}
115115

@@ -184,7 +184,7 @@ func (c *connection) connect(ctx context.Context) {
184184
var tempNc net.Conn
185185
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
186186
if err != nil {
187-
c.processInitializationError(err, ctx.Err())
187+
c.processInitializationError(err, ctx)
188188
return
189189
}
190190
c.nc = tempNc
@@ -200,7 +200,7 @@ func (c *connection) connect(ctx context.Context) {
200200
}
201201
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
202202
if err != nil {
203-
c.processInitializationError(err, ctx.Err())
203+
c.processInitializationError(err, ctx)
204204
return
205205
}
206206
c.nc = tlsNc
@@ -252,7 +252,7 @@ func (c *connection) connect(ctx context.Context) {
252252

253253
// We have a failed handshake here
254254
if err != nil {
255-
c.processInitializationError(err, ctx.Err())
255+
c.processInitializationError(err, ctx)
256256
return
257257
}
258258

0 commit comments

Comments
 (0)