Skip to content

Commit ddacfd0

Browse files
committed
Simplify tests and fix comments.
1 parent 5db858a commit ddacfd0

File tree

2 files changed

+62
-79
lines changed

2 files changed

+62
-79
lines changed

mongo/integration/sdam_error_handling_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@ func TestSDAMErrorHandling(t *testing.T) {
4949
mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) {
5050
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
5151
mt.Run("pool not cleared on operation-scoped network timeout", func(mt *mtest.T) {
52-
// Assert that the pool is cleared when a connection created by an application operation thread
53-
// encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test
54-
// connections created in the foreground for timeouts because connections created by the pool
55-
// maintenance routine can't be timed out using a context.
52+
// Assert that the pool is not cleared when a connection created by an application
53+
// operation thread encounters an operation timeout during handshaking. Unlike the
54+
// non-timeout test below, we only test connections created in the foreground for
55+
// timeouts because connections created by the pool maintenance routine can't be
56+
// timed out using a context.
5657

57-
appName := "authNetworkTimeoutTest"
58+
appName := "authOperationTimeoutTest"
5859
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
5960
// speculative auth.
6061
mt.SetFailPoint(mtest.FailPoint{
@@ -91,12 +92,11 @@ func TestSDAMErrorHandling(t *testing.T) {
9192
})
9293

9394
mt.Run("pool cleared on non-operation-scoped network timeout", func(mt *mtest.T) {
94-
// Assert that the pool is cleared when a connection created by an application operation thread
95-
// encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test
96-
// connections created in the foreground for timeouts because connections created by the pool
97-
// maintenance routine can't be timed out using a context.
95+
// Assert that the pool is cleared when a connection created by an application
96+
// operation thread encounters a timeout caused by connectTimeoutMS during
97+
// handshaking.
9898

99-
appName := "authNetworkTimeoutTest"
99+
appName := "authConnectTimeoutTest"
100100
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
101101
// speculative auth.
102102
mt.SetFailPoint(mtest.FailPoint{

x/mongo/driver/topology/server_test.go

Lines changed: 52 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -57,109 +57,89 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
5757
func TestServerConnectionTimeout(t *testing.T) {
5858
testCases := []struct {
5959
desc string
60-
ctxFn func() (context.Context, context.CancelFunc)
6160
dialer func(Dialer) Dialer
6261
handshaker func(Handshaker) Handshaker
62+
operationTimeout time.Duration
63+
connectTimeout time.Duration
6364
expectErr bool
6465
expectPoolCleared bool
6566
}{
6667
{
67-
desc: "No errors should not clear the pool",
68-
ctxFn: func() (context.Context, context.CancelFunc) {
69-
return context.Background(), nil
70-
},
68+
desc: "successful connection should not clear the pool",
7169
expectErr: false,
7270
expectPoolCleared: false,
7371
},
7472
{
75-
desc: "Parent context deadline exceeded error during dialing should not clear the pool",
76-
ctxFn: func() (context.Context, context.CancelFunc) {
77-
return context.WithTimeout(context.Background(), 100*time.Millisecond)
78-
},
73+
desc: "operation timeout error during dialing should not clear the pool",
7974
dialer: func(Dialer) Dialer {
8075
var d net.Dialer
8176
return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
82-
// Sleep for at least 150ms and expect the context passed to server.Connection()
83-
// to time out during the sleep. Expect the error returned by DialContext() to
84-
// be treated as a timeout caused by an operation-scoped deadline.
77+
// Wait for the passed in context to time out. Expect the error returned by
78+
// DialContext() to be treated as a timeout caused by an operation-scoped deadline.
8579
// E.g. FindOne(context.WithTimeout(...))
86-
time.Sleep(150 * time.Millisecond)
80+
<-ctx.Done()
8781
return d.DialContext(ctx, network, addr)
8882
})
8983
},
84+
operationTimeout: 100 * time.Millisecond,
85+
connectTimeout: 1 * time.Minute,
9086
expectErr: true,
9187
expectPoolCleared: false,
9288
},
9389
{
94-
desc: "Child context deadline exceeded error during dialing should clear the pool",
95-
ctxFn: func() (context.Context, context.CancelFunc) {
96-
// Return a context with a timeout that will not be reached during the test.
97-
return context.WithTimeout(context.Background(), 1*time.Minute)
98-
},
90+
desc: "connectTimeMS timeout error during dialing should clear the pool",
9991
dialer: func(Dialer) Dialer {
10092
var d net.Dialer
10193
return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
102-
// Wrap the context in a context with an already-exceeded deadline. Expect the
103-
// error returned by DialContext() to be treated as a timeout caused by reaching
104-
// connectTimeoutMS.
105-
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(-1*time.Hour))
106-
defer cancel()
94+
// Wait for the passed in context to time out. Expect the error returned by
95+
// DialContext() to be treated as a timeout caused by reaching connectTimeoutMS.
96+
<-ctx.Done()
10797
return d.DialContext(ctx, network, addr)
10898
})
10999
},
100+
operationTimeout: 1 * time.Minute,
101+
connectTimeout: 100 * time.Millisecond,
110102
expectErr: true,
111103
expectPoolCleared: true,
112104
},
113105
{
114-
desc: "Parent context deadline exceeded error during handshake should not clear the pool",
115-
ctxFn: func() (context.Context, context.CancelFunc) {
116-
return context.WithTimeout(context.Background(), 100*time.Millisecond)
117-
},
118-
handshaker: func(Handshaker) Handshaker {
119-
h := auth.Handshaker(nil, &auth.HandshakeOptions{})
120-
return &testHandshaker{
121-
getHandshakeInformation: func(ctx context.Context, addr address.Address, c driver.Connection) (driver.HandshakeInformation, error) {
122-
// Sleep for at least 150ms and expect the context passed to
123-
// server.Connection() to time out during the sleep. Expect the error
124-
// returned by GetHandshakeInformation() to be treated as a timeout caused
125-
// by an operation-scoped deadline.
126-
// E.g. FindOne(context.WithTimeout(...))
127-
time.Sleep(150 * time.Millisecond)
128-
return h.GetHandshakeInformation(ctx, addr, c)
129-
},
130-
}
106+
desc: "connectTimeMS timeout error during dialing with no operation timeout should clear the pool",
107+
dialer: func(Dialer) Dialer {
108+
var d net.Dialer
109+
return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
110+
// Wait for the passed in context to time out. Expect the error returned by
111+
// DialContext() to be treated as a timeout caused by reaching connectTimeoutMS.
112+
<-ctx.Done()
113+
return d.DialContext(ctx, network, addr)
114+
})
131115
},
116+
operationTimeout: 0, // Uses a context.Background() with no timeout.
117+
connectTimeout: 100 * time.Millisecond,
132118
expectErr: true,
133-
expectPoolCleared: false,
119+
expectPoolCleared: true,
134120
},
135121
{
136-
desc: "Child context deadline exceeded error during handshake should clear the pool",
137-
ctxFn: func() (context.Context, context.CancelFunc) {
138-
// Return a context with a timeout that will not be reached during the test.
139-
return context.WithTimeout(context.Background(), 1*time.Minute)
140-
},
122+
desc: "operation timeout error during handshake should not clear the pool",
141123
handshaker: func(Handshaker) Handshaker {
142124
h := auth.Handshaker(nil, &auth.HandshakeOptions{})
143125
return &testHandshaker{
144126
getHandshakeInformation: func(ctx context.Context, addr address.Address, c driver.Connection) (driver.HandshakeInformation, error) {
145-
// Wrap the context in a context with an already-exceeded deadline. Expect
146-
// the error returned by GetHandshakeInformation() to be treated as a
147-
// timeout caused by connectTimeoutMS.
148-
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(-1*time.Hour))
149-
defer cancel()
127+
// Wait for the passed in context to time out. Expect the error returned by
128+
// GetHandshakeInformation() to be treated as a timeout caused by an
129+
// operation-scoped deadline.
130+
// E.g. FindOne(context.WithTimeout(...))
131+
<-ctx.Done()
150132
return h.GetHandshakeInformation(ctx, addr, c)
151133
},
152134
}
153135
},
136+
operationTimeout: 100 * time.Millisecond,
137+
connectTimeout: 1 * time.Minute,
154138
expectErr: true,
155-
expectPoolCleared: true,
139+
expectPoolCleared: false,
156140
},
157141
{
158-
desc: "Dial errors unrelated to context timeouts should clear the pool",
159-
ctxFn: func() (context.Context, context.CancelFunc) {
160-
// Return a context with a timeout that will not be reached during the test.
161-
return context.WithTimeout(context.Background(), 1*time.Minute)
162-
},
142+
desc: "dial errors unrelated to context timeouts should clear the pool",
163143
dialer: func(Dialer) Dialer {
164144
var d net.Dialer
165145
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
@@ -171,22 +151,19 @@ func TestServerConnectionTimeout(t *testing.T) {
171151
expectPoolCleared: true,
172152
},
173153
{
174-
desc: "Context error with a dial error unrelated to context timeouts should clear the pool",
175-
ctxFn: func() (context.Context, context.CancelFunc) {
176-
return context.WithTimeout(context.Background(), 100*time.Millisecond)
177-
},
154+
desc: "context error with dial errors unrelated to context timeouts should clear the pool",
178155
dialer: func(Dialer) Dialer {
179156
var d net.Dialer
180157
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
181158
// Try to dial an invalid TCP address and expect an error.
182159
c, err := d.DialContext(ctx, "tcp", "300.0.0.0:nope")
183-
// Sleep for at least 150ms and expect the context passed to server.Connection()
184-
// to time out during the sleep. Expect that the context error is ignored
185-
// because the dial error is not a timeout.
186-
time.Sleep(150 * time.Millisecond)
160+
// Wait for the passed in context to time out. Expect that the context error is
161+
// ignored because the dial error is not a timeout.
162+
<-ctx.Done()
187163
return c, err
188164
})
189165
},
166+
operationTimeout: 100 * time.Millisecond,
190167
expectErr: true,
191168
expectPoolCleared: true,
192169
},
@@ -205,10 +182,10 @@ func TestServerConnectionTimeout(t *testing.T) {
205182
var eventsWg sync.WaitGroup
206183
eventsWg.Add(1)
207184
go func() {
185+
defer eventsWg.Done()
208186
for evt := range eventsCh {
209187
events = append(events, evt)
210188
}
211-
eventsWg.Done()
212189
}()
213190

214191
// Create a TCP listener on a random port. The listener will accept connections but not
@@ -234,6 +211,9 @@ func TestServerConnectionTimeout(t *testing.T) {
234211
// Replace the default dialer and handshaker with the test dialer and handshaker, if
235212
// present.
236213
WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
214+
if tc.connectTimeout > 0 {
215+
opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout }))
216+
}
237217
if tc.dialer != nil {
238218
opts = append(opts, WithDialer(tc.dialer))
239219
}
@@ -249,9 +229,11 @@ func TestServerConnectionTimeout(t *testing.T) {
249229
require.NoError(t, err)
250230
require.NoError(t, server.Connect(nil))
251231

252-
// Use the context returned by the test case ctxFn to call Connection.
253-
ctx, cancel := tc.ctxFn()
254-
if cancel != nil {
232+
// Create a context with the operation timeout if one is specified in the test case.
233+
ctx := context.Background()
234+
if tc.operationTimeout > 0 {
235+
var cancel context.CancelFunc
236+
ctx, cancel = context.WithTimeout(ctx, tc.operationTimeout)
255237
defer cancel()
256238
}
257239
_, err = server.Connection(ctx)
@@ -263,6 +245,7 @@ func TestServerConnectionTimeout(t *testing.T) {
263245

264246
// Close the events channel and expect that no more events are sent on the channel. Then
265247
// wait for the events channel loop to return before inspecting the events slice.
248+
server.Disconnect(context.Background())
266249
close(eventsCh)
267250
eventsWg.Wait()
268251
require.NotEmpty(t, events, "expected more than 0 connection pool monitor events")

0 commit comments

Comments
 (0)