68
68
import org .springframework .data .mongodb .core .aggregation .AggregationOperationContext ;
69
69
import org .springframework .data .mongodb .core .aggregation .AggregationOptions ;
70
70
import org .springframework .data .mongodb .core .aggregation .AggregationResults ;
71
+ import org .springframework .data .mongodb .core .aggregation .AggregationUpdate ;
71
72
import org .springframework .data .mongodb .core .aggregation .Fields ;
73
+ import org .springframework .data .mongodb .core .aggregation .RelaxedTypeBasedAggregationOperationContext ;
72
74
import org .springframework .data .mongodb .core .aggregation .TypeBasedAggregationOperationContext ;
73
75
import org .springframework .data .mongodb .core .aggregation .TypedAggregation ;
74
76
import org .springframework .data .mongodb .core .convert .DbRefResolver ;
108
110
import org .springframework .data .mongodb .core .query .Meta .CursorOption ;
109
111
import org .springframework .data .mongodb .core .query .NearQuery ;
110
112
import org .springframework .data .mongodb .core .query .Query ;
111
- import org .springframework .data .mongodb .core .query .Update ;
112
113
import org .springframework .data .mongodb .core .query .UpdateDefinition ;
113
114
import org .springframework .data .mongodb .core .query .UpdateDefinition .ArrayFilter ;
114
115
import org .springframework .data .mongodb .core .validation .Validator ;
@@ -1043,25 +1044,25 @@ public <T> GeoResults<T> geoNear(NearQuery near, Class<?> domainType, String col
1043
1044
1044
1045
@ Nullable
1045
1046
@ Override
1046
- public <T > T findAndModify (Query query , Update update , Class <T > entityClass ) {
1047
+ public <T > T findAndModify (Query query , UpdateDefinition update , Class <T > entityClass ) {
1047
1048
return findAndModify (query , update , new FindAndModifyOptions (), entityClass , getCollectionName (entityClass ));
1048
1049
}
1049
1050
1050
1051
@ Nullable
1051
1052
@ Override
1052
- public <T > T findAndModify (Query query , Update update , Class <T > entityClass , String collectionName ) {
1053
+ public <T > T findAndModify (Query query , UpdateDefinition update , Class <T > entityClass , String collectionName ) {
1053
1054
return findAndModify (query , update , new FindAndModifyOptions (), entityClass , collectionName );
1054
1055
}
1055
1056
1056
1057
@ Nullable
1057
1058
@ Override
1058
- public <T > T findAndModify (Query query , Update update , FindAndModifyOptions options , Class <T > entityClass ) {
1059
+ public <T > T findAndModify (Query query , UpdateDefinition update , FindAndModifyOptions options , Class <T > entityClass ) {
1059
1060
return findAndModify (query , update , options , entityClass , getCollectionName (entityClass ));
1060
1061
}
1061
1062
1062
1063
@ Nullable
1063
1064
@ Override
1064
- public <T > T findAndModify (Query query , Update update , FindAndModifyOptions options , Class <T > entityClass ,
1065
+ public <T > T findAndModify (Query query , UpdateDefinition update , FindAndModifyOptions options , Class <T > entityClass ,
1065
1066
String collectionName ) {
1066
1067
1067
1068
Assert .notNull (query , "Query must not be null!" );
@@ -1561,53 +1562,54 @@ public Object doInCollection(MongoCollection<Document> collection) throws MongoE
1561
1562
}
1562
1563
1563
1564
@ Override
1564
- public UpdateResult upsert (Query query , Update update , Class <?> entityClass ) {
1565
+ public UpdateResult upsert (Query query , UpdateDefinition update , Class <?> entityClass ) {
1565
1566
return doUpdate (getCollectionName (entityClass ), query , update , entityClass , true , false );
1566
1567
}
1567
1568
1568
1569
@ Override
1569
- public UpdateResult upsert (Query query , Update update , String collectionName ) {
1570
+ public UpdateResult upsert (Query query , UpdateDefinition update , String collectionName ) {
1570
1571
return doUpdate (collectionName , query , update , null , true , false );
1571
1572
}
1572
1573
1573
1574
@ Override
1574
- public UpdateResult upsert (Query query , Update update , Class <?> entityClass , String collectionName ) {
1575
+ public UpdateResult upsert (Query query , UpdateDefinition update , Class <?> entityClass , String collectionName ) {
1575
1576
1576
1577
Assert .notNull (entityClass , "EntityClass must not be null!" );
1577
1578
1578
1579
return doUpdate (collectionName , query , update , entityClass , true , false );
1579
1580
}
1580
1581
1581
1582
@ Override
1582
- public UpdateResult updateFirst (Query query , Update update , Class <?> entityClass ) {
1583
+ public UpdateResult updateFirst (Query query , UpdateDefinition update , Class <?> entityClass ) {
1583
1584
return doUpdate (getCollectionName (entityClass ), query , update , entityClass , false , false );
1584
1585
}
1585
1586
1586
1587
@ Override
1587
- public UpdateResult updateFirst (final Query query , final Update update , final String collectionName ) {
1588
+ public UpdateResult updateFirst (final Query query , final UpdateDefinition update , final String collectionName ) {
1588
1589
return doUpdate (collectionName , query , update , null , false , false );
1589
1590
}
1590
1591
1591
1592
@ Override
1592
- public UpdateResult updateFirst (Query query , Update update , Class <?> entityClass , String collectionName ) {
1593
+ public UpdateResult updateFirst (Query query , UpdateDefinition update , Class <?> entityClass , String collectionName ) {
1593
1594
1594
1595
Assert .notNull (entityClass , "EntityClass must not be null!" );
1595
1596
1596
1597
return doUpdate (collectionName , query , update , entityClass , false , false );
1597
1598
}
1598
1599
1599
1600
@ Override
1600
- public UpdateResult updateMulti (Query query , Update update , Class <?> entityClass ) {
1601
+ public UpdateResult updateMulti (Query query , UpdateDefinition update , Class <?> entityClass ) {
1601
1602
return doUpdate (getCollectionName (entityClass ), query , update , entityClass , false , true );
1602
1603
}
1603
1604
1604
1605
@ Override
1605
- public UpdateResult updateMulti (final Query query , final Update update , String collectionName ) {
1606
+ public UpdateResult updateMulti (final Query query , final UpdateDefinition update , String collectionName ) {
1606
1607
return doUpdate (collectionName , query , update , null , false , true );
1607
1608
}
1608
1609
1609
1610
@ Override
1610
- public UpdateResult updateMulti (final Query query , final Update update , Class <?> entityClass , String collectionName ) {
1611
+ public UpdateResult updateMulti (final Query query , final UpdateDefinition update , Class <?> entityClass ,
1612
+ String collectionName ) {
1611
1613
1612
1614
Assert .notNull (entityClass , "EntityClass must not be null!" );
1613
1615
@@ -1622,24 +1624,52 @@ protected UpdateResult doUpdate(final String collectionName, final Query query,
1622
1624
Assert .notNull (query , "Query must not be null!" );
1623
1625
Assert .notNull (update , "Update must not be null!" );
1624
1626
1625
- return execute (collectionName , collection -> {
1627
+ MongoPersistentEntity <?> entity = entityClass == null ? null : getPersistentEntity (entityClass );
1628
+ increaseVersionForUpdateIfNecessary (entity , update );
1626
1629
1627
- MongoPersistentEntity <?> entity = entityClass == null ? null : getPersistentEntity (entityClass );
1630
+ UpdateOptions opts = new UpdateOptions ();
1631
+ opts .upsert (upsert );
1628
1632
1629
- increaseVersionForUpdateIfNecessary (entity , update );
1633
+ if (update .hasArrayFilters ()) {
1634
+ opts .arrayFilters (update .getArrayFilters ().stream ().map (ArrayFilter ::asDocument ).collect (Collectors .toList ()));
1635
+ }
1630
1636
1631
- UpdateOptions opts = new UpdateOptions ();
1632
- opts .upsert (upsert );
1637
+ Document queryObj = new Document ();
1633
1638
1634
- if (update . hasArrayFilters () ) {
1635
- opts . arrayFilters ( update . getArrayFilters (). stream (). map ( ArrayFilter :: asDocument ). collect ( Collectors . toList () ));
1636
- }
1639
+ if (query != null ) {
1640
+ queryObj . putAll ( queryMapper . getMappedObject ( query . getQueryObject (), entity ));
1641
+ }
1637
1642
1638
- Document queryObj = new Document ();
1643
+ if (multi && update .isIsolated () && !queryObj .containsKey ("$isolated" )) {
1644
+ queryObj .put ("$isolated" , 1 );
1645
+ }
1639
1646
1640
- if (query != null ) {
1641
- queryObj .putAll (queryMapper .getMappedObject (query .getQueryObject (), entity ));
1642
- }
1647
+ if (update instanceof AggregationUpdate ) {
1648
+
1649
+ AggregationOperationContext context = entityClass != null
1650
+ ? new RelaxedTypeBasedAggregationOperationContext (entityClass , mappingContext , queryMapper )
1651
+ : Aggregation .DEFAULT_CONTEXT ;
1652
+
1653
+ AggregationUpdate aUppdate = ((AggregationUpdate ) update );
1654
+ List <Document > pipeline = new AggregationUtil (queryMapper , mappingContext ).createPipeline (aUppdate , context );
1655
+
1656
+ return execute (collectionName , collection -> {
1657
+
1658
+ MongoAction mongoAction = new MongoAction (writeConcern , MongoActionOperation .UPDATE , collectionName ,
1659
+ entityClass , update .getUpdateObject (), queryObj );
1660
+ WriteConcern writeConcernToUse = prepareWriteConcern (mongoAction );
1661
+
1662
+ collection = writeConcernToUse != null ? collection .withWriteConcern (writeConcernToUse ) : collection ;
1663
+
1664
+ if (multi ) {
1665
+ return collection .updateMany (queryObj , pipeline , opts );
1666
+ }
1667
+
1668
+ return collection .updateOne (queryObj , pipeline , opts );
1669
+ });
1670
+ }
1671
+
1672
+ return execute (collectionName , collection -> {
1643
1673
1644
1674
operations .forType (entityClass ) //
1645
1675
.getCollation (query ) //
@@ -1649,10 +1679,6 @@ protected UpdateResult doUpdate(final String collectionName, final Query query,
1649
1679
Document updateObj = update instanceof MappedUpdate ? update .getUpdateObject ()
1650
1680
: updateMapper .getMappedObject (update .getUpdateObject (), entity );
1651
1681
1652
- if (multi && update .isIsolated () && !queryObj .containsKey ("$isolated" )) {
1653
- queryObj .put ("$isolated" , 1 );
1654
- }
1655
-
1656
1682
if (LOGGER .isDebugEnabled ()) {
1657
1683
LOGGER .debug ("Calling update using query: {} and update: {} in collection: {}" , serializeToJsonSafely (queryObj ),
1658
1684
serializeToJsonSafely (updateObj ), collectionName );
@@ -2640,7 +2666,7 @@ protected <T> T doFindAndRemove(String collectionName, Document query, Document
2640
2666
2641
2667
@ SuppressWarnings ("ConstantConditions" )
2642
2668
protected <T > T doFindAndModify (String collectionName , Document query , Document fields , Document sort ,
2643
- Class <T > entityClass , Update update , @ Nullable FindAndModifyOptions options ) {
2669
+ Class <T > entityClass , UpdateDefinition update , @ Nullable FindAndModifyOptions options ) {
2644
2670
2645
2671
EntityReader <? super T , Bson > readerToUse = this .mongoConverter ;
2646
2672
@@ -2653,7 +2679,18 @@ protected <T> T doFindAndModify(String collectionName, Document query, Document
2653
2679
increaseVersionForUpdateIfNecessary (entity , update );
2654
2680
2655
2681
Document mappedQuery = queryMapper .getMappedObject (query , entity );
2656
- Document mappedUpdate = updateMapper .getMappedObject (update .getUpdateObject (), entity );
2682
+
2683
+ Object mappedUpdate = new Document ();
2684
+ if (update instanceof AggregationUpdate ) {
2685
+
2686
+ AggregationOperationContext context = entityClass != null
2687
+ ? new RelaxedTypeBasedAggregationOperationContext (entityClass , mappingContext , queryMapper )
2688
+ : Aggregation .DEFAULT_CONTEXT ;
2689
+
2690
+ mappedUpdate = new AggregationUtil (queryMapper , mappingContext ).createPipeline ((Aggregation ) update , context );
2691
+ } else {
2692
+ mappedUpdate = updateMapper .getMappedObject (update .getUpdateObject (), entity );
2693
+ }
2657
2694
2658
2695
if (LOGGER .isDebugEnabled ()) {
2659
2696
LOGGER .debug (
@@ -3027,11 +3064,11 @@ private static class FindAndModifyCallback implements CollectionCallback<Documen
3027
3064
private final Document query ;
3028
3065
private final Document fields ;
3029
3066
private final Document sort ;
3030
- private final Document update ;
3067
+ private final Object update ;
3031
3068
private final List <Document > arrayFilters ;
3032
3069
private final FindAndModifyOptions options ;
3033
3070
3034
- public FindAndModifyCallback (Document query , Document fields , Document sort , Document update ,
3071
+ public FindAndModifyCallback (Document query , Document fields , Document sort , Object update ,
3035
3072
List <Document > arrayFilters , FindAndModifyOptions options ) {
3036
3073
this .query = query ;
3037
3074
this .fields = fields ;
@@ -3059,7 +3096,12 @@ public Document doInCollection(MongoCollection<Document> collection) throws Mong
3059
3096
opts .arrayFilters (arrayFilters );
3060
3097
}
3061
3098
3062
- return collection .findOneAndUpdate (query , update , opts );
3099
+ if (update instanceof Document ) {
3100
+ return collection .findOneAndUpdate (query , (Document ) update , opts );
3101
+ } else if (update instanceof List ) {
3102
+ return collection .findOneAndUpdate (query , (List <Document >) update , opts );
3103
+ }
3104
+ throw new IllegalArgumentException ("doh - that does not work" );
3063
3105
}
3064
3106
}
3065
3107
0 commit comments