diff --git a/mongo/gridfs/bucket.go b/mongo/gridfs/bucket.go
index acf157ddc2..5c2a5e3e29 100644
--- a/mongo/gridfs/bucket.go
+++ b/mongo/gridfs/bucket.go
@@ -12,7 +12,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"time"
 
 	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/bson/primitive"
@@ -51,9 +50,6 @@ type Bucket struct {
 	firstWriteDone bool
 	readBuf        []byte
 	writeBuf       []byte
-
-	readDeadline  time.Time
-	writeDeadline time.Time
 }
 
 // Upload contains options to upload a file to a bucket.
@@ -120,30 +116,30 @@ func NewBucket(db *mongo.Database, opts ...*options.BucketOptions) (*Bucket, err
 	return b, nil
 }
 
-// SetWriteDeadline sets the write deadline for this bucket.
-func (b *Bucket) SetWriteDeadline(t time.Time) error {
-	b.writeDeadline = t
-	return nil
-}
-
-// SetReadDeadline sets the read deadline for this bucket
-func (b *Bucket) SetReadDeadline(t time.Time) error {
-	b.readDeadline = t
-	return nil
-}
-
-// OpenUploadStream creates a file ID new upload stream for a file given the filename.
-func (b *Bucket) OpenUploadStream(filename string, opts ...*options.UploadOptions) (*UploadStream, error) {
-	return b.OpenUploadStreamWithID(primitive.NewObjectID(), filename, opts...)
-}
-
-// OpenUploadStreamWithID creates a new upload stream for a file given the file ID and filename.
-func (b *Bucket) OpenUploadStreamWithID(fileID interface{}, filename string, opts ...*options.UploadOptions) (*UploadStream, error) {
-	ctx, cancel := deadlineContext(b.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
+// OpenUploadStream creates a new upload stream for a file given the
+// filename.
+//
+// The context provided to this method controls the entire lifetime of an
+// upload stream io.Writer.
+func (b *Bucket) OpenUploadStream(
+	ctx context.Context,
+	filename string,
+	opts ...*options.UploadOptions,
+) (*UploadStream, error) {
+	return b.OpenUploadStreamWithID(ctx, primitive.NewObjectID(), filename, opts...)
+}
+
+// OpenUploadStreamWithID creates a new upload stream for a file given the file
+// ID and filename.
+//
+// The context provided to this method controls the entire lifetime of an
+// upload stream io.Writer.
+func (b *Bucket) OpenUploadStreamWithID(
+	ctx context.Context,
+	fileID interface{},
+	filename string,
+	opts ...*options.UploadOptions,
+) (*UploadStream, error) {
 	if err := b.checkFirstWrite(ctx); err != nil {
 		return nil, err
 	}
@@ -153,32 +149,45 @@ func (b *Bucket) OpenUploadStreamWithID(fileID interface{}, filename string, opt
 		return nil, err
 	}
 
-	return newUploadStream(upload, fileID, filename, b.chunksColl, b.filesColl), nil
+	return newUploadStream(ctx, upload, fileID, filename, b.chunksColl, b.filesColl), nil
 }
 
 // UploadFromStream creates a fileID and uploads a file given a source stream.
 //
-// If this upload requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other
-// write operations operations on this bucket that also require a custom deadline.
-func (b *Bucket) UploadFromStream(filename string, source io.Reader, opts ...*options.UploadOptions) (primitive.ObjectID, error) {
+// The context provided to this method controls the entire lifetime of an
+// upload stream io.Writer.
+func (b *Bucket) UploadFromStream(
+	ctx context.Context,
+	filename string,
+	source io.Reader,
+	opts ...*options.UploadOptions,
+) (primitive.ObjectID, error) {
 	fileID := primitive.NewObjectID()
-	err := b.UploadFromStreamWithID(fileID, filename, source, opts...)
+	err := b.UploadFromStreamWithID(ctx, fileID, filename, source, opts...)
 
 	return fileID, err
 }
 
 // UploadFromStreamWithID uploads a file given a source stream.
 //
-// If this upload requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other
-// write operations operations on this bucket that also require a custom deadline.
-func (b *Bucket) UploadFromStreamWithID(fileID interface{}, filename string, source io.Reader, opts ...*options.UploadOptions) error {
-	us, err := b.OpenUploadStreamWithID(fileID, filename, opts...)
-	if err != nil {
-		return err
-	}
-
-	err = us.SetWriteDeadline(b.writeDeadline)
+// The context provided to this method controls the entire lifetime of an
+// upload stream io.Writer.
+func (b *Bucket) UploadFromStreamWithID(
+	ctx context.Context,
+	fileID interface{},
+	filename string,
+	source io.Reader,
+	opts ...*options.UploadOptions,
+) error {
+	us, err := b.OpenUploadStreamWithID(ctx, fileID, filename, opts...)
 	if err != nil {
-		_ = us.Close()
 		return err
 	}
@@ -204,20 +213,27 @@ func (b *Bucket) UploadFromStreamWithID(fileID interface{}, filename string, sou
 	return us.Close()
 }
 
-// OpenDownloadStream creates a stream from which the contents of the file can be read.
-func (b *Bucket) OpenDownloadStream(fileID interface{}) (*DownloadStream, error) {
-	return b.openDownloadStream(bson.D{
-		{"_id", fileID},
-	})
+// OpenDownloadStream creates a stream from which the contents of the file can
+// be read.
+//
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader.
+func (b *Bucket) OpenDownloadStream(ctx context.Context, fileID interface{}) (*DownloadStream, error) {
+	return b.openDownloadStream(ctx, bson.D{{"_id", fileID}})
 }
 
-// DownloadToStream downloads the file with the specified fileID and writes it to the provided io.Writer.
-// Returns the number of bytes written to the stream and an error, or nil if there was no error.
+// DownloadToStream downloads the file with the specified fileID and writes it
+// to the provided io.Writer. Returns the number of bytes written to the stream
+// and an error, or nil if there was no error.
 //
-// If this download requires a custom read deadline to be set on the bucket, it cannot be done concurrently with other
-// read operations operations on this bucket that also require a custom deadline.
-func (b *Bucket) DownloadToStream(fileID interface{}, stream io.Writer) (int64, error) {
-	ds, err := b.OpenDownloadStream(fileID)
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader.
+func (b *Bucket) DownloadToStream(ctx context.Context, fileID interface{}, stream io.Writer) (int64, error) {
+	ds, err := b.OpenDownloadStream(ctx, fileID)
 	if err != nil {
 		return 0, err
 	}
@@ -225,8 +241,16 @@ func (b *Bucket) DownloadToStream(fileID interface{}, stream io.Writer) (int64,
 	return b.downloadToStream(ds, stream)
 }
 
-// OpenDownloadStreamByName opens a download stream for the file with the given filename.
-func (b *Bucket) OpenDownloadStreamByName(filename string, opts ...*options.NameOptions) (*DownloadStream, error) {
+// OpenDownloadStreamByName opens a download stream for the file with the given
+// filename.
+//
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader.
+func (b *Bucket) OpenDownloadStreamByName(
+	ctx context.Context,
+	filename string,
+	opts ...*options.NameOptions,
+) (*DownloadStream, error) {
 	var numSkip int32 = -1
 	var sortOrder int32 = 1
 
@@ -250,17 +274,27 @@ func (b *Bucket) OpenDownloadStreamByName(filename string, opts ...*options.Name
 		numSkip = (-1 * numSkip) - 1
 	}
 
-	findOpts := options.Find().SetSkip(int64(numSkip)).SetSort(bson.D{{"uploadDate", sortOrder}})
+	findOpts := options.FindOne().SetSkip(int64(numSkip)).SetSort(bson.D{{"uploadDate", sortOrder}})
 
-	return b.openDownloadStream(bson.D{{"filename", filename}}, findOpts)
+	return b.openDownloadStream(ctx, bson.D{{"filename", filename}}, findOpts)
 }
 
-// DownloadToStreamByName downloads the file with the given name to the given io.Writer.
+// DownloadToStreamByName downloads the file with the given name to the given
+// io.Writer.
 //
-// If this download requires a custom read deadline to be set on the bucket, it cannot be done concurrently with other
-// read operations operations on this bucket that also require a custom deadline.
-func (b *Bucket) DownloadToStreamByName(filename string, stream io.Writer, opts ...*options.NameOptions) (int64, error) {
-	ds, err := b.OpenDownloadStreamByName(filename, opts...)
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader.
+func (b *Bucket) DownloadToStreamByName(
+	ctx context.Context,
+	filename string,
+	stream io.Writer,
+	opts ...*options.NameOptions,
+) (int64, error) {
+	ds, err := b.OpenDownloadStreamByName(ctx, filename, opts...)
 	if err != nil {
 		return 0, err
 	}
@@ -268,25 +302,11 @@ func (b *Bucket) DownloadToStreamByName(filename string, stream io.Writer, opts
 	return b.downloadToStream(ds, stream)
 }
 
-// Delete deletes all chunks and metadata associated with the file with the given file ID.
-//
-// If this operation requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other
-// write operations operations on this bucket that also require a custom deadline.
-//
-// Use SetWriteDeadline to set a deadline for the delete operation.
-func (b *Bucket) Delete(fileID interface{}) error {
-	ctx, cancel := deadlineContext(b.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-	return b.DeleteContext(ctx, fileID)
-}
-
-// DeleteContext deletes all chunks and metadata associated with the file with the given file ID and runs the underlying
+// Delete deletes all chunks and metadata associated with the file with the given file ID and runs the underlying
 // delete operations with the provided context.
 //
-// Use the context parameter to time-out or cancel the delete operation. The deadline set by SetWriteDeadline is ignored.
-func (b *Bucket) DeleteContext(ctx context.Context, fileID interface{}) error {
+// Use the context parameter to time-out or cancel the delete operation.
+func (b *Bucket) Delete(ctx context.Context, fileID interface{}) error {
 	// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
 	// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
 	// be shared by both delete operations.
@@ -311,27 +331,16 @@ func (b *Bucket) DeleteContext(ctx context.Context, fileID interface{}) error {
 	return b.deleteChunks(ctx, fileID)
 }
 
-// Find returns the files collection documents that match the given filter.
-//
-// If this download requires a custom read deadline to be set on the bucket, it cannot be done concurrently with other
-// read operations operations on this bucket that also require a custom deadline.
-//
-// Use SetReadDeadline to set a deadline for the find operation.
-func (b *Bucket) Find(filter interface{}, opts ...*options.GridFSFindOptions) (*mongo.Cursor, error) {
-	ctx, cancel := deadlineContext(b.readDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
-	return b.FindContext(ctx, filter, opts...)
-}
-
-// FindContext returns the files collection documents that match the given filter and runs the underlying
+// Find returns the files collection documents that match the given filter and runs the underlying
 // find query with the provided context.
 //
-// Use the context parameter to time-out or cancel the find operation. The deadline set by SetReadDeadline
-// is ignored.
-func (b *Bucket) FindContext(ctx context.Context, filter interface{}, opts ...*options.GridFSFindOptions) (*mongo.Cursor, error) {
+// Use the context parameter to time-out or cancel the find operation.
+func (b *Bucket) Find(
+	ctx context.Context,
+	filter interface{},
+	opts ...*options.GridFSFindOptions,
+) (*mongo.Cursor, error) {
 	gfsOpts := options.GridFSFind()
 	for _, opt := range opts {
 		if opt == nil {
@@ -391,20 +400,7 @@ func (b *Bucket) FindContext(ctx context.Context, filter interface{}, opts ...*o
-// write operations operations on this bucket that also require a custom deadline
+// write operations on this bucket that also require a custom deadline.
 //
-// Use SetWriteDeadline to set a deadline for the rename operation.
-func (b *Bucket) Rename(fileID interface{}, newFilename string) error {
-	ctx, cancel := deadlineContext(b.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
-	return b.RenameContext(ctx, fileID, newFilename)
-}
-
-// RenameContext renames the stored file with the specified file ID and runs the underlying update with the provided
-// context.
-//
-// Use the context parameter to time-out or cancel the rename operation. The deadline set by SetWriteDeadline is ignored.
-func (b *Bucket) RenameContext(ctx context.Context, fileID interface{}, newFilename string) error {
+// Use the context parameter to time-out or cancel the rename operation.
+func (b *Bucket) Rename(ctx context.Context, fileID interface{}, newFilename string) error {
 	res, err := b.filesColl.UpdateOne(ctx,
 		bson.D{{"_id", fileID}},
 		bson.D{{"$set", bson.D{{"filename", newFilename}}}},
@@ -420,26 +416,11 @@ func (b *Bucket) RenameContext(ctx context.Context, fileID interface{}, newFilen
 	return nil
 }
 
-// Drop drops the files and chunks collections associated with this bucket.
-//
-// If this operation requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other
-// write operations operations on this bucket that also require a custom deadline
-//
-// Use SetWriteDeadline to set a deadline for the drop operation.
-func (b *Bucket) Drop() error {
-	ctx, cancel := deadlineContext(b.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
-	return b.DropContext(ctx)
-}
-
-// DropContext drops the files and chunks collections associated with this bucket and runs the drop operations with
+// Drop drops the files and chunks collections associated with this bucket and runs the drop operations with
 // the provided context.
 //
-// Use the context parameter to time-out or cancel the drop operation. The deadline set by SetWriteDeadline is ignored.
-func (b *Bucket) DropContext(ctx context.Context) error {
+// Use the context parameter to time-out or cancel the drop operation.
+func (b *Bucket) Drop(ctx context.Context) error {
 	// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
 	// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
 	// be shared by both drop operations.
@@ -469,33 +450,33 @@ func (b *Bucket) GetChunksCollection() *mongo.Collection {
 	return b.chunksColl
 }
 
-func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOptions) (*DownloadStream, error) {
-	ctx, cancel := deadlineContext(b.readDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
-	cursor, err := b.findFile(ctx, filter, opts...)
-	if err != nil {
-		return nil, err
-	}
+func (b *Bucket) openDownloadStream(
+	ctx context.Context,
+	filter interface{},
+	opts ...*options.FindOneOptions,
+) (*DownloadStream, error) {
+	result := b.filesColl.FindOne(ctx, filter, opts...)
 
 	// Unmarshal the data into a File instance, which can be passed to newDownloadStream. The _id value has to be
 	// parsed out separately because "_id" will not match the File.ID field and we want to avoid exposing BSON tags
 	// in the File type. After parsing it, use RawValue.Unmarshal to ensure File.ID is set to the appropriate value.
 	var resp findFileResponse
-	if err = cursor.Decode(&resp); err != nil {
-		return nil, fmt.Errorf("error decoding files collection document: %v", err)
+	if err := result.Decode(&resp); err != nil {
+		if errors.Is(err, mongo.ErrNoDocuments) {
+			return nil, ErrFileNotFound
+		}
+
+		return nil, fmt.Errorf("error decoding files collection document: %w", err)
 	}
 
 	foundFile := newFileFromResponse(resp)
 
 	if foundFile.Length == 0 {
-		return newDownloadStream(nil, foundFile.ChunkSize, foundFile), nil
+		return newDownloadStream(ctx, nil, foundFile.ChunkSize, foundFile), nil
 	}
 
 	// For a file with non-zero length, chunkSize must exist so we know what size to expect when downloading chunks.
-	if _, err := cursor.Current.LookupErr("chunkSize"); err != nil {
+	if foundFile.ChunkSize == 0 {
 		return nil, ErrMissingChunkSize
 	}
 
@@ -505,24 +486,10 @@ func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOpt
 	}
 
 	// The chunk size can be overridden for individual files, so the expected chunk size should be the "chunkSize"
 	// field from the files collection document, not the bucket's chunk size.
-	return newDownloadStream(chunksCursor, foundFile.ChunkSize, foundFile), nil
-}
-
-func deadlineContext(deadline time.Time) (context.Context, context.CancelFunc) {
-	if deadline.Equal(time.Time{}) {
-		return context.Background(), nil
-	}
-
-	return context.WithDeadline(context.Background(), deadline)
+	return newDownloadStream(ctx, chunksCursor, foundFile.ChunkSize, foundFile), nil
 }
 
 func (b *Bucket) downloadToStream(ds *DownloadStream, stream io.Writer) (int64, error) {
-	err := ds.SetReadDeadline(b.readDeadline)
-	if err != nil {
-		_ = ds.Close()
-		return 0, err
-	}
-
 	copied, err := io.Copy(stream, ds)
 	if err != nil {
 		_ = ds.Close()
@@ -537,20 +504,6 @@ func (b *Bucket) deleteChunks(ctx context.Context, fileID interface{}) error {
 	return err
 }
 
-func (b *Bucket) findFile(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
-	cursor, err := b.filesColl.Find(ctx, filter, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	if !cursor.Next(ctx) {
-		_ = cursor.Close(ctx)
-		return nil, ErrFileNotFound
-	}
-
-	return cursor, nil
-}
-
 func (b *Bucket) findChunks(ctx context.Context, fileID interface{}) (*mongo.Cursor, error) {
 	chunksCursor, err := b.chunksColl.Find(ctx,
 		bson.D{{"files_id", fileID}},
diff --git a/mongo/gridfs/bucket_test.go b/mongo/gridfs/bucket_test.go
new file mode 100644
index 0000000000..0bff0ed871
--- /dev/null
+++ b/mongo/gridfs/bucket_test.go
@@ -0,0 +1,55 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package gridfs
+
+import (
+	"context"
+	"testing"
+
+	"go.mongodb.org/mongo-driver/bson"
+	"go.mongodb.org/mongo-driver/internal/assert"
+	"go.mongodb.org/mongo-driver/internal/integtest"
+	"go.mongodb.org/mongo-driver/mongo"
+	"go.mongodb.org/mongo-driver/mongo/options"
+)
+
+func TestBucket_openDownloadStream(t *testing.T) {
+	tests := []struct {
+		name   string
+		filter interface{}
+		err    error
+	}{
+		{
+			name:   "nil filter",
+			filter: nil,
+			err:    mongo.ErrNilDocument,
+		},
+		{
+			name:   "nonmatching filter",
+			filter: bson.D{{"x", 1}},
+			err:    ErrFileNotFound,
+		},
+	}
+
+	cs := integtest.ConnString(t)
+	clientOpts := options.Client().ApplyURI(cs.Original)
+
+	client, err := mongo.Connect(context.Background(), clientOpts)
+	assert.Nil(t, err, "Connect error: %v", err)
+
+	db := client.Database("bucket")
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			bucket, err := NewBucket(db)
+			assert.NoError(t, err)
+
+			_, err = bucket.openDownloadStream(context.Background(), test.filter)
+			assert.ErrorIs(t, err, test.err)
+		})
+	}
+}
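With the openDownloadStream rewrite above, a missing file now surfaces as the package's ErrFileNotFound sentinel (translated from mongo.ErrNoDocuments returned by FindOne), which the new bucket_test.go asserts with errors.Is. The following caller-side sketch is illustrative only and not part of the patch; the URI and database name are placeholders:

package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/gridfs"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	client, err := mongo.Connect(context.Background(),
		options.Client().ApplyURI("mongodb://localhost:27017")) // placeholder URI
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = client.Disconnect(context.Background()) }()

	bucket, err := gridfs.NewBucket(client.Database("db")) // placeholder database name
	if err != nil {
		log.Fatal(err)
	}

	// OpenDownloadStreamByName funnels through openDownloadStream, so a
	// missing file is reported as gridfs.ErrFileNotFound rather than a raw
	// mongo.ErrNoDocuments.
	_, err = bucket.OpenDownloadStreamByName(context.Background(), "no-such-file")
	if errors.Is(err, gridfs.ErrFileNotFound) {
		fmt.Println("file does not exist")
		return
	}
	if err != nil {
		log.Fatal(err)
	}
}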
diff --git a/mongo/gridfs/download_stream.go b/mongo/gridfs/download_stream.go
index 0a918542b3..41f5fb686c 100644
--- a/mongo/gridfs/download_stream.go
+++ b/mongo/gridfs/download_stream.go
@@ -37,8 +37,8 @@ type DownloadStream struct {
 	bufferStart   int
 	bufferEnd     int
 	expectedChunk int32 // index of next expected chunk
-	readDeadline  time.Time
 	fileLen       int64
+	ctx           context.Context
 
 	// The pointer returned by GetFile. This should not be used in the actual DownloadStream code outside of the
 	// newDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead,
@@ -94,7 +94,7 @@ func newFileFromResponse(resp findFileResponse) *File {
 	}
 }
 
-func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
+func newDownloadStream(ctx context.Context, cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
 	numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))
 
 	return &DownloadStream{
@@ -105,6 +105,7 @@ func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *Downl
 		done:          cursor == nil,
 		fileLen:       file.Length,
 		file:          file,
+		ctx:           ctx,
 	}
 }
 
@@ -121,16 +122,6 @@ func (ds *DownloadStream) Close() error {
 	return nil
 }
 
-// SetReadDeadline sets the read deadline for this download stream.
-func (ds *DownloadStream) SetReadDeadline(t time.Time) error {
-	if ds.closed {
-		return ErrStreamClosed
-	}
-
-	ds.readDeadline = t
-	return nil
-}
-
 // Read reads the file from the server and writes it to a destination byte slice.
 func (ds *DownloadStream) Read(p []byte) (int, error) {
 	if ds.closed {
@@ -141,17 +132,12 @@ func (ds *DownloadStream) Read(p []byte) (int, error) {
 		return 0, io.EOF
 	}
 
-	ctx, cancel := deadlineContext(ds.readDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
 	bytesCopied := 0
 	var err error
 	for bytesCopied < len(p) {
 		if ds.bufferStart >= ds.bufferEnd {
 			// Buffer is empty and can load in data from new chunk.
-			err = ds.fillBuffer(ctx)
+			err = ds.fillBuffer(ds.ctx)
 			if err != nil {
 				if err == errNoMoreChunks {
 					if bytesCopied == 0 {
@@ -183,18 +169,13 @@ func (ds *DownloadStream) Skip(skip int64) (int64, error) {
 		return 0, nil
 	}
 
-	ctx, cancel := deadlineContext(ds.readDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
 	var skipped int64
 	var err error
 
 	for skipped < skip {
 		if ds.bufferStart >= ds.bufferEnd {
 			// Buffer is empty and can load in data from new chunk.
-			err = ds.fillBuffer(ctx)
+			err = ds.fillBuffer(ds.ctx)
 			if err != nil {
 				if err == errNoMoreChunks {
 					return skipped, nil
diff --git a/mongo/gridfs/gridfs_examples_test.go b/mongo/gridfs/gridfs_examples_test.go
index 7203444dd4..7345f2f43e 100644
--- a/mongo/gridfs/gridfs_examples_test.go
+++ b/mongo/gridfs/gridfs_examples_test.go
@@ -28,7 +28,13 @@ func ExampleBucket_OpenUploadStream() {
 	// collection document.
 	uploadOpts := options.GridFSUpload().
 		SetMetadata(bson.D{{"metadata tag", "tag"}})
-	uploadStream, err := bucket.OpenUploadStream("filename", uploadOpts)
+
+	// Use context.WithTimeout to force a timeout if the upload does not
+	// succeed in 2 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	uploadStream, err := bucket.OpenUploadStream(ctx, "filename", uploadOpts)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -38,13 +44,6 @@ func ExampleBucket_OpenUploadStream() {
 		}
 	}()
 
-	// Use SetWriteDeadline to force a timeout if the upload does not succeed in
-	// 2 seconds.
-	err = uploadStream.SetWriteDeadline(time.Now().Add(2 * time.Second))
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	if _, err = uploadStream.Write(fileContent); err != nil {
 		log.Fatal(err)
 	}
@@ -59,6 +58,7 @@ func ExampleBucket_UploadFromStream() {
 	uploadOpts := options.GridFSUpload().
 		SetMetadata(bson.D{{"metadata tag", "tag"}})
 	fileID, err := bucket.UploadFromStream(
+		context.Background(),
 		"filename",
 		bytes.NewBuffer(fileContent),
 		uploadOpts)
@@ -73,7 +73,12 @@ func ExampleBucket_OpenDownloadStream() {
 	var bucket *gridfs.Bucket
 	var fileID primitive.ObjectID
 
-	downloadStream, err := bucket.OpenDownloadStream(fileID)
+	// Use context.WithTimeout to force a timeout if the download does not
+	// succeed in 2 seconds.
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	downloadStream, err := bucket.OpenDownloadStream(ctx, fileID)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -83,13 +88,6 @@ func ExampleBucket_OpenDownloadStream() {
 		}
 	}()
 
-	// Use SetReadDeadline to force a timeout if the download does not succeed
-	// in 2 seconds.
-	err = downloadStream.SetReadDeadline(time.Now().Add(2 * time.Second))
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	fileBuffer := bytes.NewBuffer(nil)
 	if _, err := io.Copy(fileBuffer, downloadStream); err != nil {
 		log.Fatal(err)
@@ -100,8 +98,10 @@ func ExampleBucket_DownloadToStream() {
 	var bucket *gridfs.Bucket
 	var fileID primitive.ObjectID
 
+	ctx := context.Background()
+
 	fileBuffer := bytes.NewBuffer(nil)
-	if _, err := bucket.DownloadToStream(fileID, fileBuffer); err != nil {
+	if _, err := bucket.DownloadToStream(ctx, fileID, fileBuffer); err != nil {
 		log.Fatal(err)
 	}
 }
@@ -110,7 +110,7 @@ func ExampleBucket_Delete() {
 	var bucket *gridfs.Bucket
 	var fileID primitive.ObjectID
 
-	if err := bucket.Delete(fileID); err != nil {
+	if err := bucket.Delete(context.Background(), fileID); err != nil {
 		log.Fatal(err)
 	}
 }
@@ -122,7 +122,7 @@ func ExampleBucket_Find() {
 	filter := bson.D{
 		{"length", bson.D{{"$gt", 1000}}},
 	}
-	cursor, err := bucket.Find(filter)
+	cursor, err := bucket.Find(context.Background(), filter)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -150,7 +150,9 @@ func ExampleBucket_Rename() {
 	var bucket *gridfs.Bucket
 	var fileID primitive.ObjectID
 
-	if err := bucket.Rename(fileID, "new file name"); err != nil {
+	ctx := context.Background()
+
+	if err := bucket.Rename(ctx, fileID, "new file name"); err != nil {
 		log.Fatal(err)
 	}
 }
@@ -158,7 +160,7 @@ func ExampleBucket_Rename() {
 func ExampleBucket_Drop() {
 	var bucket *gridfs.Bucket
 
-	if err := bucket.Drop(); err != nil {
+	if err := bucket.Drop(context.Background()); err != nil {
 		log.Fatal(err)
 	}
 }
diff --git a/mongo/gridfs/gridfs_test.go b/mongo/gridfs/gridfs_test.go
index a0add659f5..ea9d39efe4 100644
--- a/mongo/gridfs/gridfs_test.go
+++ b/mongo/gridfs/gridfs_test.go
@@ -81,7 +81,7 @@ func TestGridFS(t *testing.T) {
 			bucket, err := NewBucket(db, tt.bucketOpts)
 			assert.Nil(t, err, "NewBucket error: %v", err)
 
-			us, err := bucket.OpenUploadStream("filename", tt.uploadOpts)
+			us, err := bucket.OpenUploadStream(context.Background(), "filename", tt.uploadOpts)
 			assert.Nil(t, err, "OpenUploadStream error: %v", err)
 
 			expectedBucketChunkSize := DefaultChunkSize
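Because the upload stream (next file) now stores the context captured at open time in its new ctx field, a single context bounds every subsequent Write, Close, and Abort call on the stream. A hedged sketch in the style of the examples above — the Example name and error handling are illustrative, not part of the patch:

package gridfs_test

import (
	"context"
	"log"
	"time"

	"go.mongodb.org/mongo-driver/mongo/gridfs"
)

func Example_uploadStreamContext() {
	var bucket *gridfs.Bucket
	var fileContent []byte

	// The context passed to OpenUploadStream is captured by the stream and
	// reused by every later Write/Close/Abort call, so one WithTimeout
	// bounds the whole upload, not just the open.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	us, err := bucket.OpenUploadStream(ctx, "filename")
	if err != nil {
		log.Fatal(err)
	}

	if _, err = us.Write(fileContent); err != nil {
		// Abort reuses the captured context to delete already-written
		// chunks; it fails as well once the deadline has expired.
		_ = us.Abort()
		log.Fatal(err)
	}

	if err = us.Close(); err != nil {
		log.Fatal(err)
	}
}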
diff --git a/mongo/gridfs/upload_stream.go b/mongo/gridfs/upload_stream.go
index cf1997db80..fbcd0be7ef 100644
--- a/mongo/gridfs/upload_stream.go
+++ b/mongo/gridfs/upload_stream.go
@@ -33,19 +33,25 @@ type UploadStream struct {
 	*Upload // chunk size and metadata
 	FileID  interface{}
 
-	chunkIndex    int
-	chunksColl    *mongo.Collection // collection to store file chunks
-	filename      string
-	filesColl     *mongo.Collection // collection to store file metadata
-	closed        bool
-	buffer        []byte
-	bufferIndex   int
-	fileLen       int64
-	writeDeadline time.Time
+	chunkIndex  int
+	chunksColl  *mongo.Collection // collection to store file chunks
+	filename    string
+	filesColl   *mongo.Collection // collection to store file metadata
+	closed      bool
+	buffer      []byte
+	bufferIndex int
+	fileLen     int64
+	ctx         context.Context
 }
 
 // NewUploadStream creates a new upload stream.
-func newUploadStream(upload *Upload, fileID interface{}, filename string, chunks, files *mongo.Collection) *UploadStream {
+func newUploadStream(
+	ctx context.Context,
+	upload *Upload,
+	fileID interface{},
+	filename string,
+	chunks, files *mongo.Collection,
+) *UploadStream {
 	return &UploadStream{
 		Upload: upload,
 		FileID: fileID,
@@ -54,6 +60,7 @@ func newUploadStream(upload *Upload, fileID interface{}, filename string, chunks
 		filename:   filename,
 		filesColl:  files,
 		buffer:     make([]byte, UploadBufferSize),
+		ctx:        ctx,
 	}
 }
 
@@ -63,18 +70,13 @@ func (us *UploadStream) Close() error {
 		return ErrStreamClosed
 	}
 
-	ctx, cancel := deadlineContext(us.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
 	if us.bufferIndex != 0 {
-		if err := us.uploadChunks(ctx, true); err != nil {
+		if err := us.uploadChunks(us.ctx, true); err != nil {
 			return err
 		}
 	}
 
-	if err := us.createFilesCollDoc(ctx); err != nil {
+	if err := us.createFilesCollDoc(us.ctx); err != nil {
 		return err
 	}
 
@@ -82,16 +84,6 @@ func (us *UploadStream) Close() error {
 	return nil
 }
 
-// SetWriteDeadline sets the write deadline for this stream.
-func (us *UploadStream) SetWriteDeadline(t time.Time) error {
-	if us.closed {
-		return ErrStreamClosed
-	}
-
-	us.writeDeadline = t
-	return nil
-}
-
 // Write transfers the contents of a byte slice into this upload stream. If the stream's underlying buffer fills up,
 // the buffer will be uploaded as chunks to the server. Implements the io.Writer interface.
 func (us *UploadStream) Write(p []byte) (int, error) {
@@ -99,13 +91,6 @@ func (us *UploadStream) Write(p []byte) (int, error) {
 		return 0, ErrStreamClosed
 	}
 
-	var ctx context.Context
-
-	ctx, cancel := deadlineContext(us.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
 	origLen := len(p)
 	for {
 		if len(p) == 0 {
@@ -117,7 +102,7 @@ func (us *UploadStream) Write(p []byte) (int, error) {
 			us.bufferIndex += n
 
 			if us.bufferIndex == UploadBufferSize {
-				err := us.uploadChunks(ctx, false)
+				err := us.uploadChunks(us.ctx, false)
 				if err != nil {
 					return 0, err
 				}
@@ -132,12 +117,7 @@ func (us *UploadStream) Abort() error {
 		return ErrStreamClosed
 	}
 
-	ctx, cancel := deadlineContext(us.writeDeadline)
-	if cancel != nil {
-		defer cancel()
-	}
-
-	_, err := us.chunksColl.DeleteMany(ctx, bson.D{{"files_id", us.FileID}})
+	_, err := us.chunksColl.DeleteMany(us.ctx, bson.D{{"files_id", us.FileID}})
 	if err != nil {
 		return err
 	}
diff --git a/mongo/integration/crud_helpers_test.go b/mongo/integration/crud_helpers_test.go
index 6f62230425..c6a06e338e 100644
--- a/mongo/integration/crud_helpers_test.go
+++ b/mongo/integration/crud_helpers_test.go
@@ -1233,7 +1233,7 @@ func executeGridFSDownload(mt *mtest.T, bucket *gridfs.Bucket, args bson.Raw) (i
 		}
 	}
 
-	return bucket.DownloadToStream(fileID, new(bytes.Buffer))
+	return bucket.DownloadToStream(context.Background(), fileID, new(bytes.Buffer))
 }
 
 func executeGridFSDownloadByName(mt *mtest.T, bucket *gridfs.Bucket, args bson.Raw) (int64, error) {
@@ -1253,7 +1253,7 @@ func executeGridFSDownloadByName(mt *mtest.T, bucket *gridfs.Bucket, args bson.R
 		}
 	}
 
-	return bucket.DownloadToStreamByName(file, new(bytes.Buffer))
+	return bucket.DownloadToStreamByName(context.Background(), file, new(bytes.Buffer))
 }
 
 func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, error) {
diff --git a/mongo/integration/gridfs_test.go b/mongo/integration/gridfs_test.go
index e5a8d735ca..3401796e11 100644
--- a/mongo/integration/gridfs_test.go
+++ b/mongo/integration/gridfs_test.go
@@ -76,7 +76,7 @@ func TestGridFS(t *testing.T) {
 			bucket, err := gridfs.NewBucket(mt.DB, options.GridFSBucket().SetChunkSizeBytes(chunkSize))
 			assert.Nil(mt, err, "NewBucket error: %v", err)
 
-			ustream, err := bucket.OpenUploadStream("foo")
+			ustream, err := bucket.OpenUploadStream(context.Background(), "foo")
 			assert.Nil(mt, err, "OpenUploadStream error: %v", err)
 			id := ustream.FileID
@@ -85,7 +85,7 @@ func TestGridFS(t *testing.T) {
 			err = ustream.Close()
 			assert.Nil(mt, err, "Close error: %v", err)
 
-			dstream, err := bucket.OpenDownloadStream(id)
+			dstream, err := bucket.OpenDownloadStream(context.Background(), id)
 			assert.Nil(mt, err, "OpenDownloadStream error")
 			dst := make([]byte, tc.read)
 			_, err = dstream.Read(dst)
@@ -110,17 +110,19 @@ func TestGridFS(t *testing.T) {
 		// Unit tests showing that UploadFromStream creates indexes on the chunks and files collections.
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		err = bucket.SetWriteDeadline(time.Now().Add(5 * time.Second))
-		assert.Nil(mt, err, "SetWriteDeadline error: %v", err)
 
 		byteData := []byte("Hello, world!")
 		r := bytes.NewReader(byteData)
 
-		_, err = bucket.UploadFromStream("filename", r)
+		uploadCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		mt.Cleanup(cancel)
+
+		_, err = bucket.UploadFromStream(uploadCtx, "filename", r)
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
 		findCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		defer cancel()
+		mt.Cleanup(cancel)
+
 		findIndex(findCtx, mt, mt.DB.Collection("fs.files"), false, "key", "filename")
 		findIndex(findCtx, mt, mt.DB.Collection("fs.chunks"), true, "key", "files_id")
 	})
@@ -188,10 +190,10 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
 		defer func() {
-			_ = bucket.Drop()
+			_ = bucket.Drop(context.Background())
 		}()
 
-		_, err = bucket.OpenUploadStream("filename")
+		_, err = bucket.OpenUploadStream(context.Background(), "filename")
 		assert.Nil(mt, err, "OpenUploadStream error: %v", err)
 
 		mt.FilterStartedEvents(func(evt *event.CommandStartedEvent) bool {
@@ -235,10 +237,10 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
 		defer func() {
-			_ = bucket.Drop()
+			_ = bucket.Drop(context.Background())
 		}()
 
-		_, err = bucket.UploadFromStream("filename", bytes.NewBuffer(fileContent))
+		_, err = bucket.UploadFromStream(context.Background(), "filename", bytes.NewBuffer(fileContent))
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
 		mt.FilterStartedEvents(func(evt *event.CommandStartedEvent) bool {
@@ -282,15 +284,15 @@ func TestGridFS(t *testing.T) {
 		// Create a new GridFS bucket.
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
 		// Upload the file and store the uploaded file ID.
 			uploadedFileID := tc.fileID
 			dataReader := bytes.NewReader(fileData)
 			if uploadedFileID == nil {
-				uploadedFileID, err = bucket.UploadFromStream(fileName, dataReader, uploadOpts)
+				uploadedFileID, err = bucket.UploadFromStream(context.Background(), fileName, dataReader, uploadOpts)
 			} else {
-				err = bucket.UploadFromStreamWithID(tc.fileID, fileName, dataReader, uploadOpts)
+				err = bucket.UploadFromStreamWithID(context.Background(), tc.fileID, fileName, dataReader, uploadOpts)
 			}
 			assert.Nil(mt, err, "error uploading file: %v", err)
@@ -312,13 +314,13 @@ func TestGridFS(t *testing.T) {
 			// For both methods that create a DownloadStream, open a stream and compare the file given by the
 			// stream to the expected File object.
 			mt.RunOpts("OpenDownloadStream", noClientOpts, func(mt *mtest.T) {
-				downloadStream, err := bucket.OpenDownloadStream(uploadedFileID)
+				downloadStream, err := bucket.OpenDownloadStream(context.Background(), uploadedFileID)
 				assert.Nil(mt, err, "OpenDownloadStream error: %v", err)
 				actualFile := downloadStream.GetFile()
 				assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
 			})
 			mt.RunOpts("OpenDownloadStreamByName", noClientOpts, func(mt *mtest.T) {
-				downloadStream, err := bucket.OpenDownloadStreamByName(fileName)
+				downloadStream, err := bucket.OpenDownloadStreamByName(context.Background(), fileName)
 				assert.Nil(mt, err, "OpenDownloadStream error: %v", err)
 				actualFile := downloadStream.GetFile()
 				assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
@@ -332,17 +334,17 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
 		fileData := []byte("hello world")
 		uploadOpts := options.GridFSUpload().SetChunkSizeBytes(4)
-		fileID, err := bucket.UploadFromStream("file", bytes.NewReader(fileData), uploadOpts)
+		fileID, err := bucket.UploadFromStream(context.Background(), "file", bytes.NewReader(fileData), uploadOpts)
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
 		// If the bucket's chunk size was used, this would error because the actual chunk size is 4 and the bucket
 		// chunk size is 255 KB.
 		var downloadBuffer bytes.Buffer
-		_, err = bucket.DownloadToStream(fileID, &downloadBuffer)
+		_, err = bucket.DownloadToStream(context.Background(), fileID, &downloadBuffer)
 		assert.Nil(mt, err, "DownloadToStream error: %v", err)
 
 		downloadedBytes := downloadBuffer.Bytes()
@@ -363,9 +365,9 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
-		_, err = bucket.OpenDownloadStream(oid)
+		_, err = bucket.OpenDownloadStream(context.Background(), oid)
 		assert.Equal(mt, gridfs.ErrMissingChunkSize, err, "expected error %v, got %v", gridfs.ErrMissingChunkSize, err)
 	})
 	mt.Run("cursor error during read after downloading", func(mt *mtest.T) {
@@ -378,22 +380,23 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
 		dataReader := bytes.NewReader(fileData)
-		_, err = bucket.UploadFromStream(fileName, dataReader)
+		_, err = bucket.UploadFromStream(context.Background(), fileName, dataReader)
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
-		ds, err := bucket.OpenDownloadStreamByName(fileName)
-		assert.Nil(mt, err, "OpenDownloadStreamByName error: %v", err)
+		ctx, cancel := context.WithCancel(context.Background())
+
+		ds, err := bucket.OpenDownloadStreamByName(ctx, fileName)
+		assert.NoError(mt, err, "OpenDownloadStreamByName error: %v", err)
 
-		err = ds.SetReadDeadline(time.Now().Add(-1 * time.Second))
-		assert.Nil(mt, err, "SetReadDeadline error: %v", err)
+		cancel()
 
 		p := make([]byte, len(fileData))
 		_, err = ds.Read(p)
 		assert.NotNil(mt, err, "expected error from Read, got nil")
-		assert.True(mt, mongo.IsTimeout(err), "expected error to be a timeout, got %v", err.Error())
+		assert.ErrorIs(mt, err, context.Canceled)
 	})
 	mt.Run("cursor error during skip after downloading", func(mt *mtest.T) {
 		// To simulate a cursor error we upload a file larger than the 16MB default batch size,
@@ -405,21 +408,22 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
 		dataReader := bytes.NewReader(fileData)
-		_, err = bucket.UploadFromStream(fileName, dataReader)
+		_, err = bucket.UploadFromStream(context.Background(), fileName, dataReader)
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
-		ds, err := bucket.OpenDownloadStreamByName(fileName)
+		ctx, cancel := context.WithCancel(context.Background())
+
+		ds, err := bucket.OpenDownloadStreamByName(ctx, fileName)
 		assert.Nil(mt, err, "OpenDownloadStreamByName error: %v", err)
 
-		err = ds.SetReadDeadline(time.Now().Add(-1 * time.Second))
-		assert.Nil(mt, err, "SetReadDeadline error: %v", err)
+		cancel()
 
 		_, err = ds.Skip(int64(len(fileData)))
 		assert.NotNil(mt, err, "expected error from Skip, got nil")
-		assert.True(mt, mongo.IsTimeout(err), "expected error to be a timeout, got %v", err.Error())
+		assert.ErrorIs(mt, err, context.Canceled)
 	})
 })
@@ -444,9 +448,9 @@ func TestGridFS(t *testing.T) {
 		}
 		bucket, err := gridfs.NewBucket(mt.DB, bucketOpts)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
-		defer func() { _ = bucket.Drop() }()
+		defer func() { _ = bucket.Drop(context.Background()) }()
 
-		_, err = bucket.UploadFromStream("accessors-test-file", bytes.NewReader(fileData))
+		_, err = bucket.UploadFromStream(context.Background(), "accessors-test-file", bytes.NewReader(fileData))
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
 		bucketName := tc.bucketName
@@ -497,9 +501,6 @@ func TestGridFS(t *testing.T) {
 			timeout = 20 * time.Second // race detector causes 2-20x slowdown
 		}
 
-		err = bucket.SetWriteDeadline(time.Now().Add(timeout))
-		assert.Nil(mt, err, "SetWriteDeadline error: %v", err)
-
 		// Test that Upload works when the buffer to write is longer than the upload stream's internal buffer.
 		// This requires multiple calls to uploadChunks.
 		size := test.fileSize
@@ -508,7 +509,10 @@
 			p[i] = byte(rand.Intn(100))
 		}
 
-		_, err = bucket.UploadFromStream("filename", bytes.NewReader(p))
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		mt.Cleanup(cancel)
+
+		_, err = bucket.UploadFromStream(ctx, "filename", bytes.NewReader(p))
 		assert.Nil(mt, err, "UploadFromStream error: %v", err)
 
 		var w *bytes.Buffer
@@ -518,7 +522,7 @@
 			w = bytes.NewBuffer(make([]byte, 0, test.bufSize))
 		}
 
-		_, err = bucket.DownloadToStreamByName("filename", w)
+		_, err = bucket.DownloadToStreamByName(ctx, "filename", w)
 		assert.Nil(mt, err, "DownloadToStreamByName error: %v", err)
 		assert.Equal(mt, p, w.Bytes(), "downloaded file did not match p")
 	})
@@ -530,7 +534,7 @@ func TestGridFS(t *testing.T) {
 		bucket, err := gridfs.NewBucket(mt.DB)
 		assert.Nil(mt, err, "NewBucket error: %v", err)
 		// Find the file back.
-		cursor, err := bucket.Find(bson.D{{"foo", "bar"}})
+		cursor, err := bucket.Find(context.Background(), bson.D{{"foo", "bar"}})
 		defer func() {
 			_ = cursor.Close(context.Background())
 		}()
diff --git a/mongo/integration/unified/gridfs_bucket_operation_execution.go b/mongo/integration/unified/gridfs_bucket_operation_execution.go
index 3be6fded0c..d9f21cbc7b 100644
--- a/mongo/integration/unified/gridfs_bucket_operation_execution.go
+++ b/mongo/integration/unified/gridfs_bucket_operation_execution.go
@@ -50,7 +50,7 @@ func createBucketFindCursor(ctx context.Context, operation *operation) (*cursorR
 		return nil, newMissingArgumentError("filter")
 	}
 
-	cursor, err := bucket.FindContext(ctx, filter, opts)
+	cursor, err := bucket.Find(ctx, filter, opts)
 	res := &cursorResult{
 		cursor: cursor,
 		err:    err,
@@ -85,7 +85,7 @@ func executeBucketDelete(ctx context.Context, operation *operation) (*operationR
 		return nil, newMissingArgumentError("id")
 	}
 
-	return newErrorResult(bucket.DeleteContext(ctx, *id)), nil
+	return newErrorResult(bucket.Delete(ctx, *id)), nil
 }
 
 func executeBucketDownload(ctx context.Context, operation *operation) (*operationResult, error) {
@@ -114,7 +114,7 @@ func executeBucketDownload(ctx context.Context, operation *operatio
 		return nil, newMissingArgumentError("id")
 	}
 
-	stream, err := bucket.OpenDownloadStream(*id)
+	stream, err := bucket.OpenDownloadStream(ctx, *id)
 	if err != nil {
 		return newErrorResult(err), nil
 	}
@@ -158,7 +158,7 @@ func executeBucketDownloadByName(ctx context.Context, operation *operation) (*op
 	}
 
 	var buf bytes.Buffer
-	_, err = bucket.DownloadToStreamByName(filename, &buf, opts)
+	_, err = bucket.DownloadToStreamByName(ctx, filename, &buf, opts)
 	if err != nil {
 		return newErrorResult(err), nil
 	}
@@ -172,7 +172,7 @@ func executeBucketDrop(ctx context.Context, operation *operation) (*operationRes
 		return nil, err
 	}
 
-	return newErrorResult(bucket.DropContext(ctx)), nil
+	return newErrorResult(bucket.Drop(ctx)), nil
 }
 
 func executeBucketRename(ctx context.Context, operation *operation) (*operationResult, error) {
@@ -204,7 +204,7 @@ func executeBucketRename(ctx context.Context, operation *operation) (*operationR
 		return nil, newMissingArgumentError("id")
 	}
 
-	return newErrorResult(bucket.RenameContext(ctx, id, newFilename)), nil
+	return newErrorResult(bucket.Rename(ctx, id, newFilename)), nil
 }
 
 func executeBucketUpload(ctx context.Context, operation *operation) (*operationResult, error) {
@@ -252,7 +252,7 @@ func executeBucketUpload(ctx context.Context, operation *operation) (*operationR
 		return nil, newMissingArgumentError("source")
 	}
 
-	fileID, err := bucket.UploadFromStream(filename, bytes.NewReader(fileBytes), opts)
+	fileID, err := bucket.UploadFromStream(ctx, filename, bytes.NewReader(fileBytes), opts)
 	if err != nil {
 		return newErrorResult(err), nil
 	}
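Taken together, the bucket- and stream-level deadline setters are gone and a caller-supplied context is the only way to bound GridFS I/O. A hedged end-to-end migration sketch under stated assumptions — the URI and database name are placeholders, and error handling is collapsed to log.Fatal:

package main

import (
	"bytes"
	"context"
	"log"
	"time"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/gridfs"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	client, err := mongo.Connect(context.Background(),
		options.Client().ApplyURI("mongodb://localhost:27017")) // placeholder URI
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = client.Disconnect(context.Background()) }()

	bucket, err := gridfs.NewBucket(client.Database("db")) // placeholder database name
	if err != nil {
		log.Fatal(err)
	}

	// Previously: bucket.SetWriteDeadline(time.Now().Add(5 * time.Second)).
	// Now the context passed to UploadFromStream bounds every chunk insert
	// and the files-collection write performed through the stream.
	uploadCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	fileID, err := bucket.UploadFromStream(uploadCtx, "filename", bytes.NewReader([]byte("data")))
	if err != nil {
		log.Fatal(err)
	}

	// Previously: stream.SetReadDeadline(...). Now the same context deadline
	// covers opening the stream and every chunk read that io.Copy triggers
	// inside DownloadToStream.
	downloadCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	var buf bytes.Buffer
	if _, err := bucket.DownloadToStream(downloadCtx, fileID, &buf); err != nil {
		log.Fatal(err)
	}
	log.Printf("downloaded %d bytes for file %v", buf.Len(), fileID)
}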