Skip to content

GODRIVER-2037 Don't clear the connection pool on client-side connect timeout errors. #688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 30, 2021
Merged
55 changes: 50 additions & 5 deletions mongo/integration/sdam_error_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func TestSDAMErrorHandling(t *testing.T) {
}
baseMtOpts := func() *mtest.Options {
mtOpts := mtest.NewOptions().
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
Topologies(mtest.ReplicaSet, mtest.Single). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
ClientOptions(baseClientOpts())

if mtest.ClusterTopologyKind() == mtest.Sharded {
Expand All @@ -48,7 +48,7 @@ func TestSDAMErrorHandling(t *testing.T) {
// blockConnection and appName.
mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) {
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
mt.Run("pool cleared on network timeout", func(mt *mtest.T) {
mt.Run("pool not cleared on operation-scoped network timeout", func(mt *mtest.T) {
// Assert that the pool is NOT cleared when a connection created by an application operation thread
// encounters an operation-scoped context timeout during handshaking. Unlike the non-timeout test below,
// we only test connections created in the foreground for timeouts because connections created by the pool
Expand Down Expand Up @@ -78,16 +78,60 @@ func TestSDAMErrorHandling(t *testing.T) {
mt.ResetClient(clientOpts)
clearPoolChan()

// The saslContinue blocks for 150ms so run the InsertOne with a 100ms context to cause a network
// timeout during auth and assert that the pool was cleared.
// Use a context with a 100ms timeout so that the saslContinue delay of 150ms causes
// an operation-scoped context timeout (i.e. a timeout not caused by a client timeout
// like connectTimeoutMS or socketTimeoutMS).
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
defer cancel()
_, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
assert.False(mt, isPoolCleared(), "expected pool not to be cleared but was cleared")
})

mt.Run("pool cleared on non-operation-scoped network timeout", func(mt *mtest.T) {
// Assert that the pool IS cleared when a connection created by an application operation
// thread encounters a network timeout during handshaking that is caused by a client-level
// timeout (socketTimeoutMS here) rather than by an operation-scoped context deadline.
// This is the counterpart to the test above, which asserts the pool is NOT cleared for
// operation-scoped context timeouts.

appName := "authNetworkTimeoutTest"
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
// speculative auth.
mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"saslContinue"},
BlockConnection: true,
BlockTimeMS: 150,
AppName: appName,
},
})

// Reset the client with the appName specified in the failpoint so the failpoint only
// affects connections created by this test's client.
clientOpts := options.Client().
SetAppName(appName).
SetRetryWrites(false).
SetPoolMonitor(poolMonitor).
// Set a 100ms socket timeout so that the saslContinue delay of 150ms causes a
// timeout during socket read (i.e. a timeout not caused by the InsertOne context).
SetSocketTimeout(100 * time.Millisecond)
mt.ResetClient(clientOpts)
clearPoolChan()

// Use context.Background() so that the new connection will not time out due to an
// operation-scoped timeout; the only timeout in play is the client's socket timeout.
_, err := mt.Coll.InsertOne(context.Background(), bson.D{{"test", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})

mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
mt.Run("background", func(mt *mtest.T) {
// Assert that the pool is cleared when a connection created by the background pool maintenance
Expand Down Expand Up @@ -116,6 +160,7 @@ func TestSDAMErrorHandling(t *testing.T) {
time.Sleep(200 * time.Millisecond)
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})

mt.Run("foreground", func(mt *mtest.T) {
// Assert that the pool is cleared when a connection created by an application thread connection
// checkout encounters a non-timeout network error during handshaking.
Expand Down
10 changes: 5 additions & 5 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
return c, nil
}

func (c *connection) processInitializationError(err error) {
func (c *connection) processInitializationError(err, ctxErr error) {
atomic.StoreInt32(&c.connected, disconnected)
if c.nc != nil {
_ = c.nc.Close()
}

c.connectErr = ConnectionError{Wrapped: err, init: true}
if c.config.errorHandlingCallback != nil {
c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServiceID)
c.config.errorHandlingCallback(c.connectErr, ctxErr, c.generation, c.desc.ServiceID)
}
}

Expand Down Expand Up @@ -184,7 +184,7 @@ func (c *connection) connect(ctx context.Context) {
var tempNc net.Conn
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
if err != nil {
c.processInitializationError(err)
c.processInitializationError(err, ctx.Err())
return
}
c.nc = tempNc
Expand All @@ -200,7 +200,7 @@ func (c *connection) connect(ctx context.Context) {
}
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
if err != nil {
c.processInitializationError(err)
c.processInitializationError(err, ctx.Err())
return
}
c.nc = tlsNc
Expand Down Expand Up @@ -252,7 +252,7 @@ func (c *connection) connect(ctx context.Context) {

// We have a failed handshake here
if err != nil {
c.processInitializationError(err)
c.processInitializationError(err, ctx.Err())
return
}

Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/topology/connection_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type connectionConfig struct {
zstdLevel *int
ocspCache ocsp.Cache
disableOCSPEndpointCheck bool
errorHandlingCallback func(error, uint64, *primitive.ObjectID)
errorHandlingCallback func(error, error, uint64, *primitive.ObjectID)
tlsConnectionSource tlsConnectionSource
loadBalanced bool
getGenerationFn generationNumberFn
Expand Down Expand Up @@ -92,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C
}
}

func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption {
func withErrorHandlingCallback(fn func(error, error, uint64, *primitive.ObjectID)) ConnectionOption {
return func(c *connectionConfig) error {
c.errorHandlingCallback = fn
return nil
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestConnection(t *testing.T) {
return &net.TCPConn{}, nil
})
}),
withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) {
withErrorHandlingCallback(func(err, ctxErr error, _ uint64, _ *primitive.ObjectID) {
got = err
}),
)
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/sdam_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) {

switch appErr.When {
case "beforeHandshakeCompletes":
server.ProcessHandshakeError(currError, generation, nil)
server.ProcessHandshakeError(currError, nil, generation, nil)
case "afterHandshakeCompletes":
_ = server.ProcessError(currError, &conn)
default:
Expand Down
30 changes: 29 additions & 1 deletion x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
}

// ProcessHandshakeError implements SDAM error handling for errors that occur before a connection finishes handshaking.
func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
// ctxErr is any error caused by the context passed to Server#Connection() and is used to determine whether or not an
// operation-scoped context deadline or cancellation was the cause of the handshake error.
func (s *Server) ProcessHandshakeError(err, ctxErr error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
// Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
// error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
// use for clearing the pool.
Expand All @@ -290,6 +292,32 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6
return
}

// isTimeout walks the wrapped-error chain and reports whether any error in it
// is a network timeout or the unexported "net.errCanceled" sentinel.
isTimeout := func(err error) bool {
	for current := err; current != nil; {
		if netErr, ok := current.(net.Error); ok && netErr.Timeout() {
			return true
		}
		// "net.errCanceled" isn't exported and can't be compared directly, so
		// detect it by comparing the error message text.
		if current.Error() == "operation was canceled" {
			return true
		}
		wrapper, ok := current.(interface{ Unwrap() error })
		if !ok {
			return false
		}
		current = wrapper.Unwrap()
	}
	return false
}

// Ignore errors that indicate a client-side timeout occurred when using an operation-scoped
// deadline (i.e. not using connectTimeoutMS as the connection timeout).
if (ctxErr == context.DeadlineExceeded || ctxErr == context.Canceled) && isTimeout(wrappedConnErr) {
return
}

// Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set
// the description.Server appropriately. The description should not have a TopologyVersion because the staleness
// checking logic above has already determined that this description is not stale.
Expand Down
Loading