Skip to content

Commit 0ba857a

Browse files
christophstroblmp911de
authored andcommitted
Add support for AggregationUpdate to BulkOperations.
We now accept `UpdateDefinition` in `BulkOperations` to support custom update definitions and aggregation updates. Closes #3872 Original pull request: #4344
1 parent a94ea17 commit 0ba857a

File tree

3 files changed

+123
-18
lines changed

3 files changed

+123
-18
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/BulkOperations.java

+41-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.springframework.data.mongodb.core.query.Query;
2121
import org.springframework.data.mongodb.core.query.Update;
22+
import org.springframework.data.mongodb.core.query.UpdateDefinition;
2223
import org.springframework.data.util.Pair;
2324

2425
import com.mongodb.bulk.BulkWriteResult;
@@ -75,7 +76,19 @@ enum BulkMode {
7576
* @param update {@link Update} operation to perform, must not be {@literal null}.
7677
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
7778
*/
78-
BulkOperations updateOne(Query query, Update update);
79+
default BulkOperations updateOne(Query query, Update update) {
80+
return updateOne(query, (UpdateDefinition) update);
81+
}
82+
83+
/**
84+
* Add a single update to the bulk operation. For the update request, only the first matching document is updated.
85+
*
86+
* @param query update criteria, must not be {@literal null}.
87+
* @param update {@link Update} operation to perform, must not be {@literal null}.
88+
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
89+
* @since 4.1
90+
*/
91+
BulkOperations updateOne(Query query, UpdateDefinition update);
7992

8093
/**
8194
* Add a list of updates to the bulk operation. For each update request, only the first matching document is updated.
@@ -92,7 +105,19 @@ enum BulkMode {
92105
* @param update Update operation to perform.
93106
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
94107
*/
95-
BulkOperations updateMulti(Query query, Update update);
108+
default BulkOperations updateMulti(Query query, Update update) {
109+
return updateMulti(query, (UpdateDefinition) update);
110+
}
111+
112+
/**
113+
* Add a single update to the bulk operation. For the update request, all matching documents are updated.
114+
*
115+
* @param query Update criteria.
116+
* @param update Update operation to perform.
117+
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
118+
* @since 4.1
119+
*/
120+
BulkOperations updateMulti(Query query, UpdateDefinition update);
96121

97122
/**
98123
* Add a list of updates to the bulk operation. For each update request, all matching documents are updated.
@@ -110,7 +135,20 @@ enum BulkMode {
110135
* @param update Update operation to perform.
111136
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
112137
*/
113-
BulkOperations upsert(Query query, Update update);
138+
default BulkOperations upsert(Query query, Update update) {
139+
return upsert(query, (UpdateDefinition) update);
140+
}
141+
142+
/**
143+
* Add a single upsert to the bulk operation. An upsert is an update if the set of matching documents is not empty,
144+
* else an insert.
145+
*
146+
* @param query Update criteria.
147+
* @param update Update operation to perform.
148+
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
149+
* @since 4.1
150+
*/
151+
BulkOperations upsert(Query query, UpdateDefinition update);
114152

115153
/**
116154
* Add a list of upserts to the bulk operation. An upsert is an update if the set of matching documents is not empty,

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java

+37-9
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
package org.springframework.data.mongodb.core;
1717

1818
import java.util.ArrayList;
19-
import java.util.Collections;
2019
import java.util.List;
2120
import java.util.Optional;
2221
import java.util.stream.Collectors;
@@ -25,8 +24,12 @@
2524
import org.bson.conversions.Bson;
2625
import org.springframework.context.ApplicationEventPublisher;
2726
import org.springframework.dao.DataIntegrityViolationException;
27+
import org.springframework.data.mapping.PersistentEntity;
2828
import org.springframework.data.mapping.callback.EntityCallbacks;
2929
import org.springframework.data.mongodb.BulkOperationException;
30+
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
31+
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
32+
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
3033
import org.springframework.data.mongodb.core.convert.QueryMapper;
3134
import org.springframework.data.mongodb.core.convert.UpdateMapper;
3235
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
@@ -133,12 +136,12 @@ public BulkOperations insert(List<? extends Object> documents) {
133136

134137
@Override
135138
@SuppressWarnings("unchecked")
136-
public BulkOperations updateOne(Query query, Update update) {
139+
public BulkOperations updateOne(Query query, UpdateDefinition update) {
137140

138141
Assert.notNull(query, "Query must not be null");
139142
Assert.notNull(update, "Update must not be null");
140143

141-
return updateOne(Collections.singletonList(Pair.of(query, update)));
144+
return update(query, update, false, false);
142145
}
143146

144147
@Override
@@ -155,12 +158,14 @@ public BulkOperations updateOne(List<Pair<Query, Update>> updates) {
155158

156159
@Override
157160
@SuppressWarnings("unchecked")
158-
public BulkOperations updateMulti(Query query, Update update) {
161+
public BulkOperations updateMulti(Query query, UpdateDefinition update) {
159162

160163
Assert.notNull(query, "Query must not be null");
161164
Assert.notNull(update, "Update must not be null");
162165

163-
return updateMulti(Collections.singletonList(Pair.of(query, update)));
166+
update(query, update, false, true);
167+
168+
return this;
164169
}
165170

166171
@Override
@@ -176,7 +181,7 @@ public BulkOperations updateMulti(List<Pair<Query, Update>> updates) {
176181
}
177182

178183
@Override
179-
public BulkOperations upsert(Query query, Update update) {
184+
public BulkOperations upsert(Query query, UpdateDefinition update) {
180185
return update(query, update, true, true);
181186
}
182187

@@ -294,7 +299,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
294299
maybeInvokeBeforeSaveCallback(it.getSource(), target);
295300
}
296301

297-
return mapWriteModel(it.getModel());
302+
return mapWriteModel(it.getSource(), it.getModel());
298303
}
299304

300305
/**
@@ -306,7 +311,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
306311
* @param multi whether to issue a multi-update.
307312
* @return the {@link BulkOperations} with the update registered.
308313
*/
309-
private BulkOperations update(Query query, Update update, boolean upsert, boolean multi) {
314+
private BulkOperations update(Query query, UpdateDefinition update, boolean upsert, boolean multi) {
310315

311316
Assert.notNull(query, "Query must not be null");
312317
Assert.notNull(update, "Update must not be null");
@@ -322,11 +327,16 @@ private BulkOperations update(Query query, Update update, boolean upsert, boolea
322327
return this;
323328
}
324329

325-
private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
330+
private WriteModel<Document> mapWriteModel(Object source, WriteModel<Document> writeModel) {
326331

327332
if (writeModel instanceof UpdateOneModel) {
328333

329334
UpdateOneModel<Document> model = (UpdateOneModel<Document>) writeModel;
335+
if (source instanceof AggregationUpdate aggregationUpdate) {
336+
337+
List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
338+
return new UpdateOneModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
339+
}
330340

331341
return new UpdateOneModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
332342
model.getOptions());
@@ -335,6 +345,11 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
335345
if (writeModel instanceof UpdateManyModel) {
336346

337347
UpdateManyModel<Document> model = (UpdateManyModel<Document>) writeModel;
348+
if (source instanceof AggregationUpdate aggregationUpdate) {
349+
350+
List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
351+
return new UpdateManyModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
352+
}
338353

339354
return new UpdateManyModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
340355
model.getOptions());
@@ -357,6 +372,19 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
357372
return writeModel;
358373
}
359374

375+
private List<Document> mapUpdatePipeline(AggregationUpdate source) {
376+
Class<?> type = bulkOperationContext.getEntity().isPresent()
377+
? bulkOperationContext.getEntity().map(PersistentEntity::getType).get()
378+
: Object.class;
379+
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type,
380+
bulkOperationContext.getUpdateMapper().getMappingContext(), bulkOperationContext.getQueryMapper());
381+
382+
List<Document> pipeline = new AggregationUtil(bulkOperationContext.getQueryMapper(),
383+
bulkOperationContext.getQueryMapper().getMappingContext()).createPipeline(source,
384+
context);
385+
return pipeline;
386+
}
387+
360388
private Bson getMappedUpdate(Bson update) {
361389
return bulkOperationContext.getUpdateMapper().getMappedObject(update, bulkOperationContext.getEntity());
362390
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java

+45-6
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,27 @@
2121
import java.util.Arrays;
2222
import java.util.List;
2323
import java.util.Optional;
24+
import java.util.stream.Stream;
2425

26+
import com.mongodb.bulk.BulkWriteResult;
2527
import org.bson.Document;
2628
import org.junit.jupiter.api.BeforeEach;
2729
import org.junit.jupiter.api.Test;
2830
import org.junit.jupiter.api.extension.ExtendWith;
31+
import org.junit.jupiter.params.ParameterizedTest;
32+
import org.junit.jupiter.params.provider.Arguments;
33+
import org.junit.jupiter.params.provider.MethodSource;
2934
import org.springframework.data.mongodb.BulkOperationException;
3035
import org.springframework.data.mongodb.core.BulkOperations.BulkMode;
3136
import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext;
37+
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
3238
import org.springframework.data.mongodb.core.convert.QueryMapper;
3339
import org.springframework.data.mongodb.core.convert.UpdateMapper;
3440
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
3541
import org.springframework.data.mongodb.core.query.Criteria;
3642
import org.springframework.data.mongodb.core.query.Query;
3743
import org.springframework.data.mongodb.core.query.Update;
44+
import org.springframework.data.mongodb.core.query.UpdateDefinition;
3845
import org.springframework.data.mongodb.test.util.MongoTemplateExtension;
3946
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
4047
import org.springframework.data.mongodb.test.util.Template;
@@ -135,13 +142,14 @@ public void insertUnOrderedContinuesOnError() {
135142
});
136143
}
137144

138-
@Test // DATAMONGO-934
139-
public void upsertDoesUpdate() {
145+
@ParameterizedTest // DATAMONGO-934, GH-3872
146+
@MethodSource("upsertArguments")
147+
void upsertDoesUpdate(UpdateDefinition update) {
140148

141149
insertSomeDocuments();
142150

143151
com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
144-
upsert(where("value", "value1"), set("value", "value2")).//
152+
upsert(where("value", "value1"), update).//
145153
execute();
146154

147155
assertThat(result).isNotNull();
@@ -152,11 +160,12 @@ public void upsertDoesUpdate() {
152160
assertThat(result.getUpserts().size()).isZero();
153161
}
154162

155-
@Test // DATAMONGO-934
156-
public void upsertDoesInsert() {
163+
@ParameterizedTest // DATAMONGO-934, GH-3872
164+
@MethodSource("upsertArguments")
165+
void upsertDoesInsert(UpdateDefinition update) {
157166

158167
com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
159-
upsert(where("_id", "1"), set("value", "v1")).//
168+
upsert(where("_id", "1"), update).//
160169
execute();
161170

162171
assertThat(result).isNotNull();
@@ -171,11 +180,37 @@ public void updateOneOrdered() {
171180
testUpdate(BulkMode.ORDERED, false, 2);
172181
}
173182

183+
@Test // GH-3872
184+
public void updateOneWithAggregation() {
185+
186+
insertSomeDocuments();
187+
188+
BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
189+
bulkOps.updateOne(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
190+
BulkWriteResult result = bulkOps.execute();
191+
192+
assertThat(result.getModifiedCount()).isEqualTo(1);
193+
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isOne();
194+
}
195+
174196
@Test // DATAMONGO-934
175197
public void updateMultiOrdered() {
176198
testUpdate(BulkMode.ORDERED, true, 4);
177199
}
178200

201+
@Test // GH-3872
202+
public void updateMultiWithAggregation() {
203+
204+
insertSomeDocuments();
205+
206+
BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
207+
bulkOps.updateMulti(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
208+
BulkWriteResult result = bulkOps.execute();
209+
210+
assertThat(result.getModifiedCount()).isEqualTo(2);
211+
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isEqualTo(2);
212+
}
213+
179214
@Test // DATAMONGO-934
180215
public void updateOneUnOrdered() {
181216
testUpdate(BulkMode.UNORDERED, false, 2);
@@ -355,6 +390,10 @@ private void insertSomeDocuments() {
355390
coll.insertOne(rawDoc("4", "value2"));
356391
}
357392

393+
private static Stream<Arguments> upsertArguments() {
394+
return Stream.of(Arguments.of(set("value", "value2")), Arguments.of(AggregationUpdate.update().set("value").toValue("value2")));
395+
}
396+
358397
private static BaseDoc newDoc(String id) {
359398

360399
BaseDoc doc = new BaseDoc();

0 commit comments

Comments
 (0)