16
16
package org .springframework .data .mongodb .core ;
17
17
18
18
import java .util .ArrayList ;
19
- import java .util .Collections ;
20
19
import java .util .List ;
21
20
import java .util .Optional ;
22
21
import java .util .stream .Collectors ;
25
24
import org .bson .conversions .Bson ;
26
25
import org .springframework .context .ApplicationEventPublisher ;
27
26
import org .springframework .dao .DataIntegrityViolationException ;
27
+ import org .springframework .data .mapping .PersistentEntity ;
28
28
import org .springframework .data .mapping .callback .EntityCallbacks ;
29
29
import org .springframework .data .mongodb .BulkOperationException ;
30
+ import org .springframework .data .mongodb .core .aggregation .AggregationOperationContext ;
31
+ import org .springframework .data .mongodb .core .aggregation .AggregationUpdate ;
32
+ import org .springframework .data .mongodb .core .aggregation .RelaxedTypeBasedAggregationOperationContext ;
30
33
import org .springframework .data .mongodb .core .convert .QueryMapper ;
31
34
import org .springframework .data .mongodb .core .convert .UpdateMapper ;
32
35
import org .springframework .data .mongodb .core .mapping .MongoPersistentEntity ;
@@ -133,12 +136,12 @@ public BulkOperations insert(List<? extends Object> documents) {
133
136
134
137
@ Override
135
138
@ SuppressWarnings ("unchecked" )
136
- public BulkOperations updateOne (Query query , Update update ) {
139
+ public BulkOperations updateOne (Query query , UpdateDefinition update ) {
137
140
138
141
Assert .notNull (query , "Query must not be null" );
139
142
Assert .notNull (update , "Update must not be null" );
140
143
141
- return updateOne ( Collections . singletonList ( Pair . of ( query , update )) );
144
+ return update ( query , update , false , false );
142
145
}
143
146
144
147
@ Override
@@ -155,12 +158,14 @@ public BulkOperations updateOne(List<Pair<Query, Update>> updates) {
155
158
156
159
@ Override
157
160
@ SuppressWarnings ("unchecked" )
158
- public BulkOperations updateMulti (Query query , Update update ) {
161
+ public BulkOperations updateMulti (Query query , UpdateDefinition update ) {
159
162
160
163
Assert .notNull (query , "Query must not be null" );
161
164
Assert .notNull (update , "Update must not be null" );
162
165
163
- return updateMulti (Collections .singletonList (Pair .of (query , update )));
166
+ update (query , update , false , true );
167
+
168
+ return this ;
164
169
}
165
170
166
171
@ Override
@@ -176,7 +181,7 @@ public BulkOperations updateMulti(List<Pair<Query, Update>> updates) {
176
181
}
177
182
178
183
@ Override
179
- public BulkOperations upsert (Query query , Update update ) {
184
+ public BulkOperations upsert (Query query , UpdateDefinition update ) {
180
185
return update (query , update , true , true );
181
186
}
182
187
@@ -294,7 +299,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
294
299
maybeInvokeBeforeSaveCallback (it .getSource (), target );
295
300
}
296
301
297
- return mapWriteModel (it .getModel ());
302
+ return mapWriteModel (it .getSource (), it . getModel ());
298
303
}
299
304
300
305
/**
@@ -306,7 +311,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
306
311
* @param multi whether to issue a multi-update.
307
312
* @return the {@link BulkOperations} with the update registered.
308
313
*/
309
- private BulkOperations update (Query query , Update update , boolean upsert , boolean multi ) {
314
+ private BulkOperations update (Query query , UpdateDefinition update , boolean upsert , boolean multi ) {
310
315
311
316
Assert .notNull (query , "Query must not be null" );
312
317
Assert .notNull (update , "Update must not be null" );
@@ -322,11 +327,16 @@ private BulkOperations update(Query query, Update update, boolean upsert, boolea
322
327
return this ;
323
328
}
324
329
325
- private WriteModel <Document > mapWriteModel (WriteModel <Document > writeModel ) {
330
+ private WriteModel <Document > mapWriteModel (Object source , WriteModel <Document > writeModel ) {
326
331
327
332
if (writeModel instanceof UpdateOneModel ) {
328
333
329
334
UpdateOneModel <Document > model = (UpdateOneModel <Document >) writeModel ;
335
+ if (source instanceof AggregationUpdate aggregationUpdate ) {
336
+
337
+ List <Document > pipeline = mapUpdatePipeline (aggregationUpdate );
338
+ return new UpdateOneModel <>(getMappedQuery (model .getFilter ()), pipeline , model .getOptions ());
339
+ }
330
340
331
341
return new UpdateOneModel <>(getMappedQuery (model .getFilter ()), getMappedUpdate (model .getUpdate ()),
332
342
model .getOptions ());
@@ -335,6 +345,11 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
335
345
if (writeModel instanceof UpdateManyModel ) {
336
346
337
347
UpdateManyModel <Document > model = (UpdateManyModel <Document >) writeModel ;
348
+ if (source instanceof AggregationUpdate aggregationUpdate ) {
349
+
350
+ List <Document > pipeline = mapUpdatePipeline (aggregationUpdate );
351
+ return new UpdateManyModel <>(getMappedQuery (model .getFilter ()), pipeline , model .getOptions ());
352
+ }
338
353
339
354
return new UpdateManyModel <>(getMappedQuery (model .getFilter ()), getMappedUpdate (model .getUpdate ()),
340
355
model .getOptions ());
@@ -357,6 +372,19 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
357
372
return writeModel ;
358
373
}
359
374
375
+ private List <Document > mapUpdatePipeline (AggregationUpdate source ) {
376
+ Class <?> type = bulkOperationContext .getEntity ().isPresent ()
377
+ ? bulkOperationContext .getEntity ().map (PersistentEntity ::getType ).get ()
378
+ : Object .class ;
379
+ AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext (type ,
380
+ bulkOperationContext .getUpdateMapper ().getMappingContext (), bulkOperationContext .getQueryMapper ());
381
+
382
+ List <Document > pipeline = new AggregationUtil (bulkOperationContext .getQueryMapper (),
383
+ bulkOperationContext .getQueryMapper ().getMappingContext ()).createPipeline (source ,
384
+ context );
385
+ return pipeline ;
386
+ }
387
+
360
388
private Bson getMappedUpdate (Bson update ) {
361
389
return bulkOperationContext .getUpdateMapper ().getMappedObject (update , bulkOperationContext .getEntity ());
362
390
}
0 commit comments