Skip to content

SkipOutput for void methods using declarative Aggregations having $out stage #4341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-GH-4088-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-GH-4088-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-GH-4088-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.1.0-SNAPSHOT</version>
<version>4.1.x-GH-4088-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public Object execute(Object[] parameters) {
* @param accessor for providing invocation arguments. Never {@literal null}.
* @param typeToRead the desired component target type. Can be {@literal null}.
*/
@Nullable
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
@Nullable Class<?> typeToRead) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.core.query.Meta;
Expand Down Expand Up @@ -109,7 +110,7 @@ static AggregationOptions.Builder applyMeta(AggregationOptions.Builder builder,
* @param accessor
* @param targetType
*/
static void appendSortIfPresent(List<AggregationOperation> aggregationPipeline, ConvertingParameterAccessor accessor,
static void appendSortIfPresent(AggregationPipeline aggregationPipeline, ConvertingParameterAccessor accessor,
Class<?> targetType) {

if (accessor.getSort().isUnsorted()) {
Expand All @@ -134,7 +135,7 @@ static void appendSortIfPresent(List<AggregationOperation> aggregationPipeline,
* @param aggregationPipeline
* @param accessor
*/
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline,
ConvertingParameterAccessor accessor) {
appendLimitAndOffsetIfPresent(aggregationPipeline, accessor, LongUnaryOperator.identity(),
IntUnaryOperator.identity());
Expand All @@ -150,7 +151,7 @@ static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregation
* @param limitOperator
* @since 3.3
*/
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline,
ConvertingParameterAccessor accessor, LongUnaryOperator offsetOperator, IntUnaryOperator limitOperator) {

Pageable pageable = accessor.getPageable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.support.PageableExecutionUtils;
import org.springframework.data.util.TypeInformation;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;

Expand All @@ -55,6 +56,7 @@
@FunctionalInterface
interface MongoQueryExecution {

@Nullable
Object execute(Query query);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.springframework.data.mongodb.repository.query;

import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.util.ReflectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -81,7 +83,7 @@ protected Publisher<Object> doExecute(ReactiveMongoQueryMethod method, ResultPro
Class<?> sourceType = method.getDomainClass();
Class<?> targetType = typeToRead;

List<AggregationOperation> pipeline = it;
AggregationPipeline pipeline = new AggregationPipeline(it);

AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor);
Expand All @@ -93,10 +95,13 @@ protected Publisher<Object> doExecute(ReactiveMongoQueryMethod method, ResultPro
targetType = Document.class;
}

AggregationOptions options = computeOptions(method, accessor);
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline, options);
AggregationOptions options = computeOptions(method, accessor, pipeline);
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options);

Flux<?> flux = reactiveMongoOperations.aggregate(aggregation, targetType);
if(ReflectionUtils.isVoid(typeToRead)) {
return flux.then();
}

if (isSimpleReturnType && !isRawReturnType) {
flux = flux.handle((item, sink) -> {
Expand All @@ -121,13 +126,16 @@ private Mono<List<AggregationOperation>> computePipeline(ConvertingParameterAcce
return parseAggregationPipeline(getQueryMethod().getAnnotatedAggregation(), accessor);
}

private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) {

AggregationOptions.Builder builder = Aggregation.newAggregationOptions();

AggregationUtils.applyCollation(builder, method.getAnnotatedCollation(), accessor, method.getParameters(),
expressionParser, evaluationContextProvider);
AggregationUtils.applyMeta(builder, method);
if(ReflectionUtils.isVoid(method.getReturnType().getComponentType().getType()) && pipeline.isOutOrMerge()) {
builder.skipOutput();
}

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider;
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.expression.ExpressionParser;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;

/**
* {@link AbstractMongoQuery} implementation to run string-based aggregations using
Expand Down Expand Up @@ -84,13 +89,14 @@ public StringBasedAggregation(MongoQueryMethod method, MongoOperations mongoOper
* @see org.springframework.data.mongodb.repository.query.AbstractReactiveMongoQuery#doExecute(org.springframework.data.mongodb.repository.query.MongoQueryMethod, org.springframework.data.repository.query.ResultProcessor, org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor, java.lang.Class)
*/
@Override
@Nullable
protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProcessor,
ConvertingParameterAccessor accessor, Class<?> typeToRead) {

Class<?> sourceType = method.getDomainClass();
Class<?> targetType = typeToRead;

List<AggregationOperation> pipeline = computePipeline(method, accessor);
AggregationPipeline pipeline = computePipeline(method, accessor);
AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);

if (method.isSliceQuery()) {
Expand All @@ -111,8 +117,8 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
targetType = method.getReturnType().getRequiredActualType().getRequiredComponentType().getType();
}

AggregationOptions options = computeOptions(method, accessor);
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline, options);
AggregationOptions options = computeOptions(method, accessor, pipeline);
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options);

if (method.isStreamQuery()) {

Expand All @@ -126,6 +132,9 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
}

AggregationResults<Object> result = (AggregationResults<Object>) mongoOperations.aggregate(aggregation, targetType);
if(ReflectionUtils.isVoid(typeToRead)) {
return null;
}

if (isRawAggregationResult) {
return result;
Expand Down Expand Up @@ -167,18 +176,22 @@ private boolean isSimpleReturnType(Class<?> targetType) {
return MongoSimpleTypes.HOLDER.isSimpleType(targetType);
}

List<AggregationOperation> computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
return parseAggregationPipeline(method.getAnnotatedAggregation(), accessor);
AggregationPipeline computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
return new AggregationPipeline(parseAggregationPipeline(method.getAnnotatedAggregation(), accessor));
}

private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) {
private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) {

AggregationOptions.Builder builder = Aggregation.newAggregationOptions();

AggregationUtils.applyCollation(builder, method.getAnnotatedCollation(), accessor, method.getParameters(),
expressionParser, evaluationContextProvider);
AggregationUtils.applyMeta(builder, method);

if(ReflectionUtils.isVoid(method.getReturnType().getType()) && pipeline.isOutOrMerge()) {
builder.skipOutput();
}

return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public class ReactiveStringBasedAggregationUnitTests {

private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }";
private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }";
private static final String RAW_OUT = "{ '$out' : 'authors' }";
private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }";
private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }";

Expand Down Expand Up @@ -188,6 +189,22 @@ private AggregationInvocation executeAggregation(String name, Object... args) {
return new AggregationInvocation(aggregationCaptor.getValue(), targetTypeCaptor.getValue(), result);
}

@Test // GH-4088
void aggregateWithVoidReturnTypeSkipsResultOnOutStage() {

AggregationInvocation invocation = executeAggregation("outSkipResult");

assertThat(skipResultsOf(invocation)).isTrue();
}

@Test // GH-4088
void aggregateWithOutStageDoesNotSkipResults() {

AggregationInvocation invocation = executeAggregation("outDoNotSkipResult");

assertThat(skipResultsOf(invocation)).isFalse();
}

private ReactiveStringBasedAggregation createAggregationForMethod(String name, Class<?>... parameters) {

Method method = ClassUtils.getMethod(SampleRepository.class, name, parameters);
Expand Down Expand Up @@ -216,6 +233,11 @@ private Collation collationOf(AggregationInvocation invocation) {
: null;
}

private Boolean skipResultsOf(AggregationInvocation invocation) {
return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults()
: false;
}

private Class<?> targetTypeOf(AggregationInvocation invocation) {
return invocation.getTargetType();
}
Expand Down Expand Up @@ -243,6 +265,12 @@ private interface SampleRepository extends ReactiveCrudRepository<Person, Long>

@Aggregation(pipeline = RAW_GROUP_BY_LASTNAME_STRING, collation = "de_AT")
Mono<PersonAggregate> aggregateWithCollation(Collation collation);

@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
Flux<Person> outDoNotSkipResult();

@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
Mono<Void> outSkipResult();
}

static class PersonAggregate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class StringBasedAggregationUnitTests {

private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }";
private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }";
private static final String RAW_OUT = "{ '$out' : 'authors' }";
private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }";
private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }";

Expand Down Expand Up @@ -260,6 +261,22 @@ void aggregateRaisesErrorOnInvalidReturnType() {
.withMessageContaining("Page");
}

@Test // GH-4088
void aggregateWithVoidReturnTypeSkipsResultOnOutStage() {

AggregationInvocation invocation = executeAggregation("outSkipResult");

assertThat(skipResultsOf(invocation)).isTrue();
}

@Test // GH-4088
void aggregateWithOutStageDoesNotSkipResults() {

AggregationInvocation invocation = executeAggregation("outDoNotSkipResult");

assertThat(skipResultsOf(invocation)).isFalse();
}

private AggregationInvocation executeAggregation(String name, Object... args) {

Class<?>[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new);
Expand Down Expand Up @@ -302,6 +319,11 @@ private Collation collationOf(AggregationInvocation invocation) {
: null;
}

private Boolean skipResultsOf(AggregationInvocation invocation) {
return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults()
: false;
}

private Class<?> targetTypeOf(AggregationInvocation invocation) {
return invocation.getTargetType();
}
Expand Down Expand Up @@ -350,6 +372,12 @@ private interface SampleRepository extends Repository<Person, Long> {

@Aggregation(RAW_GROUP_BY_LASTNAME_STRING)
String simpleReturnType();

@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
List<Person> outDoNotSkipResult();

@Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT })
void outSkipResult();
}

private interface UnsupportedRepository extends Repository<Person, Long> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ public interface PersonRepository extends CrudRepository<Person, String> {

@Aggregation("{ '$project': { '_id' : '$lastname' } }")
List<String> findAllLastnames(); <9>

@Aggregation(pipeline = {
"{ $group : { _id : '$author', books: { $push: '$title' } } }",
"{ $out : 'authors' }"
})
void groupAndOutSkippingOutput(); <10>
}
----
[source,java]
Expand Down Expand Up @@ -75,6 +81,7 @@ Therefore, the `Sort` properties are mapped against the methods return type `Per
To gain more control, you might consider `AggregationResult` as method return type as shown in <7>.
<8> Obtain the raw `AggregationResults` mapped to the generic target wrapper type `SumValue` or `org.bson.Document`.
<9> Like in <6>, a single value can be directly obtained from multiple result ``Document``s.
<10> Skips the output of the `$out` stage when return type is `void`.
====

In some scenarios, aggregations might require additional options, such as a maximum run time, additional log comments, or the permission to temporarily write data to disk.
Expand Down