Skip to content

Commit c237136

Browse files
SkipOutput for void methods using declarative Aggregations having $out stage.
We now set the skipOutput flag if an annotated Aggregation defines an $out stage and when the method is declared to return no result (void / Mono<Void>, kotlin.Unit) Closes: #4088
1 parent bc5db18 commit c237136

File tree

7 files changed

+94
-13
lines changed

7 files changed

+94
-13
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ public Object execute(Object[] parameters) {
127127
* @param accessor for providing invocation arguments. Never {@literal null}.
128128
* @param typeToRead the desired component target type. Can be {@literal null}.
129129
*/
130+
@Nullable
130131
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
131132
@Nullable Class<?> typeToRead) {
132133

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.data.mongodb.core.aggregation.Aggregation;
2828
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
2929
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
30+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
3031
import org.springframework.data.mongodb.core.convert.MongoConverter;
3132
import org.springframework.data.mongodb.core.query.Collation;
3233
import org.springframework.data.mongodb.core.query.Meta;
@@ -109,7 +110,7 @@ static AggregationOptions.Builder applyMeta(AggregationOptions.Builder builder,
109110
* @param accessor
110111
* @param targetType
111112
*/
112-
static void appendSortIfPresent(List<AggregationOperation> aggregationPipeline, ConvertingParameterAccessor accessor,
113+
static void appendSortIfPresent(AggregationPipeline aggregationPipeline, ConvertingParameterAccessor accessor,
113114
Class<?> targetType) {
114115

115116
if (accessor.getSort().isUnsorted()) {
@@ -134,7 +135,7 @@ static void appendSortIfPresent(List<AggregationOperation> aggregationPipeline,
134135
* @param aggregationPipeline
135136
* @param accessor
136137
*/
137-
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
138+
static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline,
138139
ConvertingParameterAccessor accessor) {
139140
appendLimitAndOffsetIfPresent(aggregationPipeline, accessor, LongUnaryOperator.identity(),
140141
IntUnaryOperator.identity());
@@ -150,7 +151,7 @@ static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregation
150151
* @param limitOperator
151152
* @since 3.3
152153
*/
153-
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
154+
static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline,
154155
ConvertingParameterAccessor accessor, LongUnaryOperator offsetOperator, IntUnaryOperator limitOperator) {
155156

156157
Pageable pageable = accessor.getPageable();

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.data.mongodb.core.query.UpdateDefinition;
3939
import org.springframework.data.support.PageableExecutionUtils;
4040
import org.springframework.data.util.TypeInformation;
41+
import org.springframework.lang.Nullable;
4142
import org.springframework.util.Assert;
4243
import org.springframework.util.ClassUtils;
4344

@@ -55,6 +56,7 @@
5556
@FunctionalInterface
5657
interface MongoQueryExecution {
5758

59+
@Nullable
5860
Object execute(Query query);
5961

6062
/**

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.query;
1717

18+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
19+
import org.springframework.data.util.ReflectionUtils;
1820
import reactor.core.publisher.Flux;
1921
import reactor.core.publisher.Mono;
2022

@@ -81,7 +83,7 @@ protected Publisher<Object> doExecute(ReactiveMongoQueryMethod method, ResultPro
8183
Class<?> sourceType = method.getDomainClass();
8284
Class<?> targetType = typeToRead;
8385

84-
List<AggregationOperation> pipeline = it;
86+
AggregationPipeline pipeline = new AggregationPipeline(it);
8587

8688
AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);
8789
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor);
@@ -93,10 +95,13 @@ protected Publisher<Object> doExecute(ReactiveMongoQueryMethod method, ResultPro
9395
targetType = Document.class;
9496
}
9597

96-
AggregationOptions options = computeOptions(method, accessor);
97-
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline, options);
98+
AggregationOptions options = computeOptions(method, accessor, pipeline);
99+
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options);
98100

99101
Flux<?> flux = reactiveMongoOperations.aggregate(aggregation, targetType);
102+
if(ReflectionUtils.isVoid(typeToRead)) {
103+
return flux.then();
104+
}
100105

101106
if (isSimpleReturnType && !isRawReturnType) {
102107
flux = flux.handle((item, sink) -> {
@@ -121,13 +126,16 @@ private Mono<List<AggregationOperation>> computePipeline(ConvertingParameterAcce
121126
return parseAggregationPipeline(getQueryMethod().getAnnotatedAggregation(), accessor);
122127
}
123128

124-
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
129+
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) {
125130

126131
AggregationOptions.Builder builder = Aggregation.newAggregationOptions();
127132

128133
AggregationUtils.applyCollation(builder, method.getAnnotatedCollation(), accessor, method.getParameters(),
129134
expressionParser, evaluationContextProvider);
130135
AggregationUtils.applyMeta(builder, method);
136+
if(ReflectionUtils.isVoid(method.getReturnType().getComponentType().getType()) && pipeline.isOutOrMerge()) {
137+
builder.skipOutput();
138+
}
131139

132140
return builder.build();
133141
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java

+19-6
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,20 @@
2929
import org.springframework.data.mongodb.core.aggregation.Aggregation;
3030
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
3131
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
32+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
3233
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
3334
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
3435
import org.springframework.data.mongodb.core.convert.MongoConverter;
3536
import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes;
3637
import org.springframework.data.mongodb.core.query.Query;
3738
import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider;
3839
import org.springframework.data.repository.query.ResultProcessor;
40+
import org.springframework.data.util.ReflectionUtils;
3941
import org.springframework.expression.ExpressionParser;
42+
import org.springframework.lang.Nullable;
4043
import org.springframework.util.ClassUtils;
44+
import org.springframework.util.CollectionUtils;
45+
import org.springframework.util.ObjectUtils;
4146

4247
/**
4348
* {@link AbstractMongoQuery} implementation to run string-based aggregations using
@@ -84,13 +89,14 @@ public StringBasedAggregation(MongoQueryMethod method, MongoOperations mongoOper
8489
* @see org.springframework.data.mongodb.repository.query.AbstractReactiveMongoQuery#doExecute(org.springframework.data.mongodb.repository.query.MongoQueryMethod, org.springframework.data.repository.query.ResultProcessor, org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor, java.lang.Class)
8590
*/
8691
@Override
92+
@Nullable
8793
protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProcessor,
8894
ConvertingParameterAccessor accessor, Class<?> typeToRead) {
8995

9096
Class<?> sourceType = method.getDomainClass();
9197
Class<?> targetType = typeToRead;
9298

93-
List<AggregationOperation> pipeline = computePipeline(method, accessor);
99+
AggregationPipeline pipeline = computePipeline(method, accessor);
94100
AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);
95101

96102
if (method.isSliceQuery()) {
@@ -111,8 +117,8 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
111117
targetType = method.getReturnType().getRequiredActualType().getRequiredComponentType().getType();
112118
}
113119

114-
AggregationOptions options = computeOptions(method, accessor);
115-
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline, options);
120+
AggregationOptions options = computeOptions(method, accessor, pipeline);
121+
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options);
116122

117123
if (method.isStreamQuery()) {
118124

@@ -126,6 +132,9 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
126132
}
127133

128134
AggregationResults<Object> result = (AggregationResults<Object>) mongoOperations.aggregate(aggregation, targetType);
135+
if(ReflectionUtils.isVoid(typeToRead)) {
136+
return null;
137+
}
129138

130139
if (isRawAggregationResult) {
131140
return result;
@@ -167,18 +176,22 @@ private boolean isSimpleReturnType(Class<?> targetType) {
167176
return MongoSimpleTypes.HOLDER.isSimpleType(targetType);
168177
}
169178

170-
List<AggregationOperation> computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
171-
return parseAggregationPipeline(method.getAnnotatedAggregation(), accessor);
179+
AggregationPipeline computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
180+
return new AggregationPipeline(parseAggregationPipeline(method.getAnnotatedAggregation(), accessor));
172181
}
173182

174-
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
183+
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) {
175184

176185
AggregationOptions.Builder builder = Aggregation.newAggregationOptions();
177186

178187
AggregationUtils.applyCollation(builder, method.getAnnotatedCollation(), accessor, method.getParameters(),
179188
expressionParser, evaluationContextProvider);
180189
AggregationUtils.applyMeta(builder, method);
181190

191+
if(ReflectionUtils.isVoid(method.getReturnType().getType()) && pipeline.isOutOrMerge()) {
192+
builder.skipOutput();
193+
}
194+
182195
return builder.build();
183196
}
184197

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java

+28
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ public class ReactiveStringBasedAggregationUnitTests {
7878

7979
private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }";
8080
private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }";
81+
private static final String RAW_OUT = "{ '$out' : 'authors' }";
8182
private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }";
8283
private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }";
8384

@@ -188,6 +189,22 @@ private AggregationInvocation executeAggregation(String name, Object... args) {
188189
return new AggregationInvocation(aggregationCaptor.getValue(), targetTypeCaptor.getValue(), result);
189190
}
190191

192+
@Test // GH-4088
193+
void aggregateWithVoidReturnTypeSkipsResultOnOutStage() {
194+
195+
AggregationInvocation invocation = executeAggregation("outSkipResult");
196+
197+
assertThat(skipResultsOf(invocation)).isTrue();
198+
}
199+
200+
@Test // GH-4088
201+
void aggregateWithOutStageDoesNotSkipResults() {
202+
203+
AggregationInvocation invocation = executeAggregation("outDoNotSkipResult");
204+
205+
assertThat(skipResultsOf(invocation)).isFalse();
206+
}
207+
191208
private ReactiveStringBasedAggregation createAggregationForMethod(String name, Class<?>... parameters) {
192209

193210
Method method = ClassUtils.getMethod(SampleRepository.class, name, parameters);
@@ -216,6 +233,11 @@ private Collation collationOf(AggregationInvocation invocation) {
216233
: null;
217234
}
218235

236+
private Boolean skipResultsOf(AggregationInvocation invocation) {
237+
return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults()
238+
: false;
239+
}
240+
219241
private Class<?> targetTypeOf(AggregationInvocation invocation) {
220242
return invocation.getTargetType();
221243
}
@@ -243,6 +265,12 @@ private interface SampleRepository extends ReactiveCrudRepository<Person, Long>
243265

244266
@Aggregation(pipeline = RAW_GROUP_BY_LASTNAME_STRING, collation = "de_AT")
245267
Mono<PersonAggregate> aggregateWithCollation(Collation collation);
268+
269+
@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
270+
Flux<Person> outDoNotSkipResult();
271+
272+
@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
273+
Mono<Void> outSkipResult();
246274
}
247275

248276
static class PersonAggregate {

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java

+28
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ public class StringBasedAggregationUnitTests {
9191

9292
private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }";
9393
private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }";
94+
private static final String RAW_OUT = "{ '$out' : 'authors' }";
9495
private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }";
9596
private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }";
9697

@@ -260,6 +261,22 @@ void aggregateRaisesErrorOnInvalidReturnType() {
260261
.withMessageContaining("Page");
261262
}
262263

264+
@Test // GH-4088
265+
void aggregateWithVoidReturnTypeSkipsResultOnOutStage() {
266+
267+
AggregationInvocation invocation = executeAggregation("outSkipResult");
268+
269+
assertThat(skipResultsOf(invocation)).isTrue();
270+
}
271+
272+
@Test // GH-4088
273+
void aggregateWithOutStageDoesNotSkipResults() {
274+
275+
AggregationInvocation invocation = executeAggregation("outDoNotSkipResult");
276+
277+
assertThat(skipResultsOf(invocation)).isFalse();
278+
}
279+
263280
private AggregationInvocation executeAggregation(String name, Object... args) {
264281

265282
Class<?>[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new);
@@ -302,6 +319,11 @@ private Collation collationOf(AggregationInvocation invocation) {
302319
: null;
303320
}
304321

322+
private Boolean skipResultsOf(AggregationInvocation invocation) {
323+
return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults()
324+
: false;
325+
}
326+
305327
private Class<?> targetTypeOf(AggregationInvocation invocation) {
306328
return invocation.getTargetType();
307329
}
@@ -350,6 +372,12 @@ private interface SampleRepository extends Repository<Person, Long> {
350372

351373
@Aggregation(RAW_GROUP_BY_LASTNAME_STRING)
352374
String simpleReturnType();
375+
376+
@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
377+
List<Person> outDoNotSkipResult();
378+
379+
@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
380+
void outSkipResult();
353381
}
354382

355383
private interface UnsupportedRepository extends Repository<Person, Long> {

0 commit comments

Comments
 (0)