Skip to content

Add support for AggregationUpdate to BulkOperations. #4344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-3872-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-3872-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-3872-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-3872-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.util.Pair;

import com.mongodb.bulk.BulkWriteResult;
Expand Down Expand Up @@ -75,7 +76,19 @@ enum BulkMode {
* @param update {@link Update} operation to perform, must not be {@literal null}.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/
BulkOperations updateOne(Query query, Update update);
default BulkOperations updateOne(Query query, Update update) {
return updateOne(query, (UpdateDefinition) update);
}

/**
* Add a single update to the bulk operation. For the update request, only the first matching document is updated.
*
* @param query update criteria, must not be {@literal null}.
* @param update {@link Update} operation to perform, must not be {@literal null}.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations updateOne(Query query, UpdateDefinition update);

/**
* Add a list of updates to the bulk operation. For each update request, only the first matching document is updated.
Expand All @@ -92,7 +105,19 @@ enum BulkMode {
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/
BulkOperations updateMulti(Query query, Update update);
default BulkOperations updateMulti(Query query, Update update) {
return updateMulti(query, (UpdateDefinition) update);
}

/**
* Add a single update to the bulk operation. For the update request, all matching documents are updated.
*
* @param query Update criteria.
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations updateMulti(Query query, UpdateDefinition update);

/**
* Add a list of updates to the bulk operation. For each update request, all matching documents are updated.
Expand All @@ -110,7 +135,20 @@ enum BulkMode {
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/
BulkOperations upsert(Query query, Update update);
default BulkOperations upsert(Query query, Update update) {
return upsert(query, (UpdateDefinition) update);
}

/**
* Add a single upsert to the bulk operation. An upsert is an update if the set of matching documents is not empty,
* else an insert.
*
* @param query Update criteria.
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations upsert(Query query, UpdateDefinition update);

/**
* Add a list of upserts to the bulk operation. An upsert is an update if the set of matching documents is not empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package org.springframework.data.mongodb.core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -25,8 +24,12 @@
import org.bson.conversions.Bson;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.data.mapping.PersistentEntity;
import org.springframework.data.mapping.callback.EntityCallbacks;
import org.springframework.data.mongodb.BulkOperationException;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
Expand Down Expand Up @@ -133,12 +136,12 @@ public BulkOperations insert(List<? extends Object> documents) {

@Override
@SuppressWarnings("unchecked")
public BulkOperations updateOne(Query query, Update update) {
public BulkOperations updateOne(Query query, UpdateDefinition update) {

Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null");

return updateOne(Collections.singletonList(Pair.of(query, update)));
return update(query, update, false, false);
}

@Override
Expand All @@ -155,12 +158,14 @@ public BulkOperations updateOne(List<Pair<Query, Update>> updates) {

@Override
@SuppressWarnings("unchecked")
public BulkOperations updateMulti(Query query, Update update) {
public BulkOperations updateMulti(Query query, UpdateDefinition update) {

Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null");

return updateMulti(Collections.singletonList(Pair.of(query, update)));
update(query, update, false, true);

return this;
}

@Override
Expand All @@ -176,7 +181,7 @@ public BulkOperations updateMulti(List<Pair<Query, Update>> updates) {
}

@Override
public BulkOperations upsert(Query query, Update update) {
public BulkOperations upsert(Query query, UpdateDefinition update) {
return update(query, update, true, true);
}

Expand Down Expand Up @@ -294,7 +299,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
maybeInvokeBeforeSaveCallback(it.getSource(), target);
}

return mapWriteModel(it.getModel());
return mapWriteModel(it.getSource(), it.getModel());
}

/**
Expand All @@ -306,7 +311,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
* @param multi whether to issue a multi-update.
* @return the {@link BulkOperations} with the update registered.
*/
private BulkOperations update(Query query, Update update, boolean upsert, boolean multi) {
private BulkOperations update(Query query, UpdateDefinition update, boolean upsert, boolean multi) {

Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null");
Expand All @@ -322,11 +327,16 @@ private BulkOperations update(Query query, Update update, boolean upsert, boolea
return this;
}

private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
private WriteModel<Document> mapWriteModel(Object source, WriteModel<Document> writeModel) {

if (writeModel instanceof UpdateOneModel) {

UpdateOneModel<Document> model = (UpdateOneModel<Document>) writeModel;
if (source instanceof AggregationUpdate aggregationUpdate) {

List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
return new UpdateOneModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
}

return new UpdateOneModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
model.getOptions());
Expand All @@ -335,6 +345,11 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
if (writeModel instanceof UpdateManyModel) {

UpdateManyModel<Document> model = (UpdateManyModel<Document>) writeModel;
if (source instanceof AggregationUpdate aggregationUpdate) {

List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
return new UpdateManyModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
}

return new UpdateManyModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
model.getOptions());
Expand All @@ -357,6 +372,19 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
return writeModel;
}

private List<Document> mapUpdatePipeline(AggregationUpdate source) {
Class<?> type = bulkOperationContext.getEntity().isPresent()
? bulkOperationContext.getEntity().map(PersistentEntity::getType).get()
: Object.class;
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type,
bulkOperationContext.getUpdateMapper().getMappingContext(), bulkOperationContext.getQueryMapper());

List<Document> pipeline = new AggregationUtil(bulkOperationContext.getQueryMapper(),
bulkOperationContext.getQueryMapper().getMappingContext()).createPipeline(source,
context);
return pipeline;
}

private Bson getMappedUpdate(Bson update) {
return bulkOperationContext.getUpdateMapper().getMappedObject(update, bulkOperationContext.getEntity());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,27 @@
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import com.mongodb.bulk.BulkWriteResult;
import org.bson.Document;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.mongodb.BulkOperationException;
import org.springframework.data.mongodb.core.BulkOperations.BulkMode;
import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.test.util.MongoTemplateExtension;
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.Template;
Expand Down Expand Up @@ -135,13 +142,14 @@ public void insertUnOrderedContinuesOnError() {
});
}

@Test // DATAMONGO-934
public void upsertDoesUpdate() {
@ParameterizedTest // DATAMONGO-934, GH-3872
@MethodSource("upsertArguments")
void upsertDoesUpdate(UpdateDefinition update) {

insertSomeDocuments();

com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
upsert(where("value", "value1"), set("value", "value2")).//
upsert(where("value", "value1"), update).//
execute();

assertThat(result).isNotNull();
Expand All @@ -152,11 +160,12 @@ public void upsertDoesUpdate() {
assertThat(result.getUpserts().size()).isZero();
}

@Test // DATAMONGO-934
public void upsertDoesInsert() {
@ParameterizedTest // DATAMONGO-934, GH-3872
@MethodSource("upsertArguments")
void upsertDoesInsert(UpdateDefinition update) {

com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
upsert(where("_id", "1"), set("value", "v1")).//
upsert(where("_id", "1"), update).//
execute();

assertThat(result).isNotNull();
Expand All @@ -171,11 +180,37 @@ public void updateOneOrdered() {
testUpdate(BulkMode.ORDERED, false, 2);
}

@Test // GH-3872
public void updateOneWithAggregation() {

insertSomeDocuments();

BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
bulkOps.updateOne(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
BulkWriteResult result = bulkOps.execute();

assertThat(result.getModifiedCount()).isEqualTo(1);
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isOne();
}

@Test // DATAMONGO-934
public void updateMultiOrdered() {
testUpdate(BulkMode.ORDERED, true, 4);
}

@Test // GH-3872
public void updateMultiWithAggregation() {

insertSomeDocuments();

BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
bulkOps.updateMulti(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
BulkWriteResult result = bulkOps.execute();

assertThat(result.getModifiedCount()).isEqualTo(2);
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isEqualTo(2);
}

@Test // DATAMONGO-934
public void updateOneUnOrdered() {
testUpdate(BulkMode.UNORDERED, false, 2);
Expand Down Expand Up @@ -355,6 +390,10 @@ private void insertSomeDocuments() {
coll.insertOne(rawDoc("4", "value2"));
}

private static Stream<Arguments> upsertArguments() {
return Stream.of(Arguments.of(set("value", "value2")), Arguments.of(AggregationUpdate.update().set("value").toValue("value2")));
}

private static BaseDoc newDoc(String id) {

BaseDoc doc = new BaseDoc();
Expand Down