Skip to content

Commit 3b54b57

Browse files
committed
GODRIVER-3181 Read server responses in the background after op timeout.
1 parent c5205e2 commit 3b54b57

File tree

13 files changed

+779
-24
lines changed

13 files changed

+779
-24
lines changed

internal/integration/csot_test.go

Lines changed: 523 additions & 0 deletions
Large diffs are not rendered by default.

internal/integration/mtest/mongotest.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,15 +208,15 @@ func (t *T) cleanup() {
208208
// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
209209
// given name which is available to the callback through the T.Coll variable and is dropped after the callback
210210
// returns.
211-
func (t *T) Run(name string, callback func(*T)) {
211+
func (t *T) Run(name string, callback func(mt *T)) {
212212
t.RunOpts(name, NewOptions(), callback)
213213
}
214214

215215
// RunOpts creates a new T instance for a sub-test with the given options. If the current environment does not satisfy
216216
// constraints specified in the options, the new sub-test will be skipped automatically. If the test is not skipped,
217217
// the callback will be run with the new T instance. RunOpts creates a new collection with the given name which is
218218
// available to the callback through the T.Coll variable and is dropped after the callback returns.
219-
func (t *T) RunOpts(name string, opts *Options, callback func(*T)) {
219+
func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
220220
t.T.Run(name, func(wrapped *testing.T) {
221221
sub := newT(wrapped, t.baseOpts, opts)
222222

internal/integration/unified/unified_spec_runner.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ var (
6161
"unpin when a new transaction is started": "Implement GODRIVER-3034",
6262
"unpin when a non-transaction write operation uses a session": "Implement GODRIVER-3034",
6363
"unpin when a non-transaction read operation uses a session": "Implement GODRIVER-3034",
64+
65+
// DRIVERS-2722: Setting "maxTimeMS" on a command that creates a cursor
66+
// also limits the lifetime of the cursor. That may be surprising to
67+
// users, so omit "maxTimeMS" from operations that return user-managed
68+
// cursors.
69+
"timeoutMS can be overridden for a find": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
70+
"timeoutMS can be configured for an operation - find on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
71+
"timeoutMS can be configured for an operation - aggregate on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
72+
"timeoutMS can be configured for an operation - aggregate on database": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
73+
"operation is retried multiple times for non-zero timeoutMS - find on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
74+
"operation is retried multiple times for non-zero timeoutMS - aggregate on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
75+
"operation is retried multiple times for non-zero timeoutMS - aggregate on database": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.",
6476
}
6577

6678
logMessageValidatorTimeout = 10 * time.Millisecond

mongo/collection.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,12 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
948948
Crypt(a.client.cryptFLE).
949949
ServerAPI(a.client.serverAPI).
950950
HasOutputStage(hasOutputStage).
951-
Timeout(a.client.timeout)
951+
Timeout(a.client.timeout).
952+
// Omit "maxTimeMS" from operations that return a user-managed cursor to
953+
// prevent confusing "cursor not found" errors.
954+
//
955+
// See DRIVERS-2722 for more detail.
956+
OmitMaxTimeMS(true)
952957

953958
if args.AllowDiskUse != nil {
954959
op.AllowDiskUse(*args.AllowDiskUse)
@@ -1292,11 +1297,20 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
12921297
if err != nil {
12931298
return nil, err
12941299
}
1295-
return coll.find(ctx, filter, args)
1300+
1301+
// Omit "maxTimeMS" from operations that return a user-managed cursor to
1302+
// prevent confusing "cursor not found" errors.
1303+
//
1304+
// See DRIVERS-2722 for more detail.
1305+
return coll.find(ctx, filter, true, args)
12961306
}
12971307

1298-
func (coll *Collection) find(ctx context.Context, filter interface{},
1299-
args *options.FindOptions) (cur *Cursor, err error) {
1308+
func (coll *Collection) find(
1309+
ctx context.Context,
1310+
filter interface{},
1311+
omitMaxTimeMS bool,
1312+
args *options.FindOptions,
1313+
) (cur *Cursor, err error) {
13001314

13011315
if ctx == nil {
13021316
ctx = context.Background()
@@ -1334,7 +1348,7 @@ func (coll *Collection) find(ctx context.Context, filter interface{},
13341348
CommandMonitor(coll.client.monitor).ServerSelector(selector).
13351349
ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name).
13361350
Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
1337-
Timeout(coll.client.timeout).Logger(coll.client.logger)
1351+
Timeout(coll.client.timeout).Logger(coll.client.logger).OmitMaxTimeMS(omitMaxTimeMS)
13381352

13391353
cursorOpts := coll.client.createBaseCursorOptions()
13401354

@@ -1499,7 +1513,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{},
14991513
if err != nil {
15001514
return nil
15011515
}
1502-
cursor, err := coll.find(ctx, filter, newFindArgsFromFindOneArgs(args))
1516+
cursor, err := coll.find(ctx, filter, false, newFindArgsFromFindOneArgs(args))
15031517
return &SingleResult{
15041518
ctx: ctx,
15051519
cur: cursor,

x/mongo/driver/batch_cursor.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,12 @@ func (bc *BatchCursor) getMore(ctx context.Context) {
442442
Crypt: bc.crypt,
443443
ServerAPI: bc.serverAPI,
444444

445+
// Omit the automatically-calculated maxTimeMS because setting maxTimeMS
446+
// on a non-awaitData cursor causes a server error. For awaitData
447+
// cursors, maxTimeMS is set by the above CommandFn when maxAwaitTime is
448+
// specified.
449+
OmitMaxTimeMS: true,
450+
445451
// No read preference is passed to the getMore command,
446452
// resulting in the default read preference: "primaryPreferred".
447453
// Since this could be confusing, and there is no requirement

x/mongo/driver/errors.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,14 +509,28 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
509509
errmsg = "command failed"
510510
}
511511

512-
return Error{
512+
err := Error{
513513
Code: code,
514514
Message: errmsg,
515515
Name: codeName,
516516
Labels: labels,
517517
TopologyVersion: tv,
518518
Raw: doc,
519519
}
520+
521+
// If we get a MaxTimeMSExpired error, assume that the error was caused
522+
// by setting "maxTimeMS" on the command based on the context deadline
523+
// or on "timeoutMS". In that case, make the error wrap
524+
// context.DeadlineExceeded so that users can always check
525+
//
526+
// errors.Is(err, context.DeadlineExceeded)
527+
//
528+
// for either client-side or server-side timeouts.
529+
if err.Code == 50 {
530+
err.Wrapped = context.DeadlineExceeded
531+
}
532+
533+
return err
520534
}
521535

522536
if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {

x/mongo/driver/operation.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ func (op Operation) addBatchArray(dst []byte) []byte {
11861186

11871187
func (op Operation) createLegacyHandshakeWireMessage(
11881188
ctx context.Context,
1189-
maxTimeMS uint64,
1189+
maxTimeMS int64,
11901190
dst []byte,
11911191
desc description.SelectedServer,
11921192
) ([]byte, startedInformation, error) {
@@ -1245,7 +1245,7 @@ func (op Operation) createLegacyHandshakeWireMessage(
12451245
// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
12461246
// specifies the default behavior of no timeout server-side.
12471247
if maxTimeMS > 0 {
1248-
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
1248+
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS)
12491249
}
12501250

12511251
dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
@@ -1266,7 +1266,7 @@ func (op Operation) createLegacyHandshakeWireMessage(
12661266

12671267
func (op Operation) createMsgWireMessage(
12681268
ctx context.Context,
1269-
maxTimeMS uint64,
1269+
maxTimeMS int64,
12701270
dst []byte,
12711271
desc description.SelectedServer,
12721272
conn *mnet.Connection,
@@ -1316,7 +1316,7 @@ func (op Operation) createMsgWireMessage(
13161316
// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
13171317
// specifies the default behavior of no timeout server-side.
13181318
if maxTimeMS > 0 {
1319-
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
1319+
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS)
13201320
}
13211321

13221322
dst = bsoncore.AppendStringElement(dst, "$db", op.Database)
@@ -1362,7 +1362,7 @@ func isLegacyHandshake(op Operation, desc description.SelectedServer) bool {
13621362

13631363
func (op Operation) createWireMessage(
13641364
ctx context.Context,
1365-
maxTimeMS uint64,
1365+
maxTimeMS int64,
13661366
dst []byte,
13671367
desc description.SelectedServer,
13681368
conn *mnet.Connection,
@@ -1620,7 +1620,7 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer)
16201620
// if the ctx is a Timeout context. If the context is not a Timeout context, it uses the
16211621
// operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is
16221622
// not a Timeout context, calculateMaxTimeMS returns 0.
1623-
func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (uint64, error) {
1623+
func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (int64, error) {
16241624
if op.OmitMaxTimeMS {
16251625
return 0, nil
16261626
}
@@ -1637,13 +1637,23 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration
16371637
maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond)
16381638
if maxTimeMS <= 0 {
16391639
return 0, fmt.Errorf(
1640-
"remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v",
1640+
"remaining time %v until context deadline is less than or equal to min network round-trip time %v (%v): %w",
16411641
remainingTimeout,
1642-
ErrDeadlineWouldBeExceeded,
1643-
rttStats)
1642+
rttMin,
1643+
rttStats,
1644+
ErrDeadlineWouldBeExceeded)
16441645
}
16451646

1646-
return uint64(maxTimeMS), nil
1647+
// The server will return a "BadValue" error if maxTimeMS is greater
1648+
// than the maximum positive int32 value (about 24.9 days). If the
1649+
// user specified a timeout value greater than that, omit maxTimeMS
1650+
// and let the client-side timeout handle cancelling the op if the
1651+
// timeout is ever reached.
1652+
if maxTimeMS > math.MaxInt32 {
1653+
return 0, nil
1654+
}
1655+
1656+
return maxTimeMS, nil
16471657
}
16481658

16491659
// updateClusterTimes updates the cluster times for the session and cluster clock attached to this

x/mongo/driver/operation/aggregate.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type Aggregate struct {
4848
hasOutputStage bool
4949
customOptions map[string]bsoncore.Value
5050
timeout *time.Duration
51+
omitMaxTimeMS bool
5152

5253
result driver.CursorResponse
5354
}
@@ -110,6 +111,7 @@ func (a *Aggregate) Execute(ctx context.Context) error {
110111
IsOutputAggregate: a.hasOutputStage,
111112
Timeout: a.timeout,
112113
Name: driverutil.AggregateOp,
114+
OmitMaxTimeMS: a.omitMaxTimeMS,
113115
}.Execute(ctx)
114116

115117
}
@@ -404,3 +406,14 @@ func (a *Aggregate) Timeout(timeout *time.Duration) *Aggregate {
404406
a.timeout = timeout
405407
return a
406408
}
409+
410+
// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the
411+
// command.
412+
func (a *Aggregate) OmitMaxTimeMS(omit bool) *Aggregate {
413+
if a == nil {
414+
a = new(Aggregate)
415+
}
416+
417+
a.omitMaxTimeMS = omit
418+
return a
419+
}

x/mongo/driver/operation/find.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type Find struct {
6161
serverAPI *driver.ServerAPIOptions
6262
timeout *time.Duration
6363
logger *logger.Logger
64+
omitMaxTimeMS bool
6465
}
6566

6667
// NewFind constructs and returns a new Find.
@@ -107,6 +108,7 @@ func (f *Find) Execute(ctx context.Context) error {
107108
Timeout: f.timeout,
108109
Logger: f.logger,
109110
Name: driverutil.FindOp,
111+
OmitMaxTimeMS: f.omitMaxTimeMS,
110112
}.Execute(ctx)
111113
}
112114

@@ -547,3 +549,14 @@ func (f *Find) Logger(logger *logger.Logger) *Find {
547549
f.logger = logger
548550
return f
549551
}
552+
553+
// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the
554+
// command.
555+
func (f *Find) OmitMaxTimeMS(omit bool) *Find {
556+
if f == nil {
557+
f = new(Find)
558+
}
559+
560+
f.omitMaxTimeMS = omit
561+
return f
562+
}

x/mongo/driver/operation_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func TestOperation(t *testing.T) {
286286
rtt RTTMonitor
287287
rttMin time.Duration
288288
rttStats string
289-
want uint64
289+
want int64
290290
err error
291291
}{
292292
{
@@ -644,6 +644,35 @@ func TestOperation(t *testing.T) {
644644
// the TransientTransactionError label.
645645
assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err)
646646
})
647+
t.Run("ErrDeadlineWouldBeExceeded wraps context.DeadlineExceeded", func(t *testing.T) {
648+
// Create a deployment that returns a server that reports a 90th
649+
// percentile RTT of 1 minute.
650+
d := new(mockDeployment)
651+
d.returns.server = mockServer{
652+
conn: mnet.NewConnection(&mockConnection{}),
653+
rttMonitor: mockRTTMonitor{min: 1 * time.Minute},
654+
}
655+
656+
// Create an operation with a Timeout specified to enable CSOT behavior.
657+
var dur time.Duration
658+
op := Operation{
659+
Database: "foobar",
660+
Deployment: d,
661+
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
662+
return dst, nil
663+
},
664+
Timeout: &dur,
665+
}
666+
667+
// Call the operation with a context with a deadline less than the 90th
668+
// percentile RTT configured above.
669+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
670+
defer cancel()
671+
err := op.Execute(ctx)
672+
673+
assert.ErrorIs(t, err, ErrDeadlineWouldBeExceeded)
674+
assert.ErrorIs(t, err, context.DeadlineExceeded)
675+
})
647676
}
648677

649678
func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
@@ -709,6 +738,25 @@ func (m *mockServerSelector) String() string {
709738
panic("not implemented")
710739
}
711740

741+
type mockServer struct {
742+
conn *mnet.Connection
743+
err error
744+
rttMonitor RTTMonitor
745+
}
746+
747+
func (ms mockServer) Connection(context.Context) (*mnet.Connection, error) { return ms.conn, ms.err }
748+
func (ms mockServer) RTTMonitor() RTTMonitor { return ms.rttMonitor }
749+
750+
type mockRTTMonitor struct {
751+
ewma time.Duration
752+
min time.Duration
753+
stats string
754+
}
755+
756+
func (mrm mockRTTMonitor) EWMA() time.Duration { return mrm.ewma }
757+
func (mrm mockRTTMonitor) Min() time.Duration { return mrm.min }
758+
func (mrm mockRTTMonitor) Stats() string { return mrm.stats }
759+
712760
type mockConnection struct {
713761
// parameters
714762
pWriteWM []byte

x/mongo/driver/topology/connection.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ type connection struct {
7676

7777
driverConnectionID int64
7878
generation uint64
79+
80+
// awaitingResponse indicates that the server response was not completely
81+
// read before returning the connection to the pool.
82+
awaitingResponse bool
7983
}
8084

8185
// newConnection handles the creation of a connection. It does not connect the connection.
@@ -370,8 +374,16 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
370374

371375
dst, errMsg, err := c.read(ctx)
372376
if err != nil {
373-
// We closeConnection the connection because we don't know if there are other bytes left to read.
374-
c.close()
377+
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() {
378+
// If the error was a timeout error, instead of closing the
379+
// connection mark it as awaiting response so the pool can read the
380+
// response before making it available to other operations.
381+
c.awaitingResponse = true
382+
} else {
383+
// Otherwise, and close the connection because we don't know what
384+
// the connection state is.
385+
c.close()
386+
}
375387
message := errMsg
376388
if errors.Is(err, io.EOF) {
377389
message = "socket was unexpectedly closed"

0 commit comments

Comments
 (0)