diff --git a/pom.xml b/pom.xml index b78554647f..d0e776fb89 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.1.0-SNAPSHOT + 4.1.x-GH-4070-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-benchmarks/pom.xml b/spring-data-mongodb-benchmarks/pom.xml index 1b2a1390e6..025821e1d9 100644 --- a/spring-data-mongodb-benchmarks/pom.xml +++ b/spring-data-mongodb-benchmarks/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-mongodb-parent - 4.1.0-SNAPSHOT + 4.1.x-GH-4070-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index 8db8d798fb..1dd9076b9a 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.1.0-SNAPSHOT + 4.1.x-GH-4070-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 9a57f7eb52..24b0d18439 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.1.0-SNAPSHOT + 4.1.x-GH-4070-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationVariable.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationVariable.java new file mode 100644 index 0000000000..5c10b32a86 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationVariable.java @@ -0,0 +1,130 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +/** + * A special field that points to a variable {@code $$} expression. + * + * @author Christoph Strobl + * @since 4.1 + */ +public interface AggregationVariable extends Field { + + String PREFIX = "$$"; + + /** + * @return {@literal true} if the fields {@link #getName() name} does not match the defined {@link #getTarget() + * target}. + */ + default boolean isAliased() { + return !ObjectUtils.nullSafeEquals(getName(), getTarget()); + } + + @Override + default String getName() { + return getTarget(); + } + + default boolean isInternal() { + return false; + } + + /** + * Create a new {@link AggregationVariable} for the given name. + *

+ * Variables start with {@code $$}. If not, the given value gets prefixed with {@code $$}. + * + * @param value must not be {@literal null}. + * @return new instance of {@link AggregationVariable}. + * @throws IllegalArgumentException if given value is {@literal null}. + */ + static AggregationVariable variable(String value) { + + Assert.notNull(value, "Value must not be null"); + return new AggregationVariable() { + + private final String val = AggregationVariable.prefixVariable(value); + + @Override + public String getTarget() { + return val; + } + }; + } + + /** + * Create a new {@link #isInternal() local} {@link AggregationVariable} for the given name. + *

+ * Variables start with {@code $$}. If not, the given value gets prefixed with {@code $$}. + * + * @param value must not be {@literal null}. + * @return new instance of {@link AggregationVariable}. + * @throws IllegalArgumentException if given value is {@literal null}. + */ + static AggregationVariable localVariable(String value) { + + Assert.notNull(value, "Value must not be null"); + return new AggregationVariable() { + + private final String val = AggregationVariable.prefixVariable(value); + + @Override + public String getTarget() { + return val; + } + + @Override + public boolean isInternal() { + return true; + } + }; + } + + /** + * Check if the given field name reference may be variable. + * + * @param fieldRef can be {@literal null}. + * @return true if given value matches the variable identification pattern. + */ + static boolean isVariable(@Nullable String fieldRef) { + return fieldRef != null && fieldRef.stripLeading().matches("^\\$\\$\\w.*"); + } + + /** + * Check if the given field may be variable. + * + * @param field can be {@literal null}. + * @return true if given {@link Field field} is an {@link AggregationVariable} or if its value is a + * {@link #isVariable(String) variable}. + */ + static boolean isVariable(Field field) { + + if (field instanceof AggregationVariable) { + return true; + } + return isVariable(field.getTarget()); + } + + private static String prefixVariable(String variable) { + + var trimmed = variable.stripLeading(); + return trimmed.startsWith(PREFIX) ? trimmed : (PREFIX + trimmed); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 1ea699852f..94f7fc6736 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -1471,24 +1471,15 @@ public interface AsBuilder { } } - public enum Variable implements Field { + public enum Variable implements AggregationVariable { THIS { - @Override - public String getName() { - return "$$this"; - } @Override public String getTarget() { return "$$this"; } - @Override - public boolean isAliased() { - return false; - } - @Override public String toString() { return getName(); @@ -1496,27 +1487,23 @@ public String toString() { }, VALUE { - @Override - public String getName() { - return "$$value"; - } @Override public String getTarget() { return "$$value"; } - @Override - public boolean isAliased() { - return false; - } - @Override public String toString() { return getName(); } }; + @Override + public boolean isInternal() { + return true; + } + /** * Create a {@link Field} reference to a given {@literal property} prefixed with the {@link Variable} identifier. * eg. {@code $$value.product} @@ -1548,6 +1535,16 @@ public String toString() { } }; } + + public static boolean isVariable(Field field) { + + for (Variable var : values()) { + if (field.getTarget().startsWith(var.getTarget())) { + return true; + } + } + return false; + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Fields.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Fields.java index 4dac936871..277b447a9b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Fields.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Fields.java @@ -245,7 +245,7 @@ public AggregationField(String name, @Nullable String target) { private static String cleanUp(String source) { - if (SystemVariable.isReferingToSystemVariable(source)) { + if (AggregationVariable.isVariable(source)) { return source; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/SystemVariable.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/SystemVariable.java index 15c5bf6e90..bb3cc49808 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/SystemVariable.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/SystemVariable.java @@ -24,7 +24,7 @@ * @author Christoph Strobl * @see Aggregation Variables. */ -public enum SystemVariable { +public enum SystemVariable implements AggregationVariable { /** * Variable for the current datetime. @@ -82,8 +82,6 @@ public enum SystemVariable { */ SEARCH_META; - private static final String PREFIX = "$$"; - /** * Return {@literal true} if the given {@code fieldRef} denotes a well-known system variable, {@literal false} * otherwise. @@ -93,13 +91,12 @@ public enum SystemVariable { */ public static boolean isReferingToSystemVariable(@Nullable String fieldRef) { - if (fieldRef == null || !fieldRef.startsWith(PREFIX) || fieldRef.length() <= 2) { + String candidate = variableNameFrom(fieldRef); + if (candidate == null) { return false; } - int indexOfFirstDot = fieldRef.indexOf('.'); - String candidate = fieldRef.substring(2, indexOfFirstDot == -1 ? fieldRef.length() : indexOfFirstDot); - + candidate = candidate.startsWith(PREFIX) ? candidate.substring(2) : candidate; for (SystemVariable value : values()) { if (value.name().equals(candidate)) { return true; @@ -113,4 +110,20 @@ public static boolean isReferingToSystemVariable(@Nullable String fieldRef) { public String toString() { return PREFIX.concat(name()); } + + @Override + public String getTarget() { + return toString(); + } + + @Nullable + static String variableNameFrom(@Nullable String fieldRef) { + + if (fieldRef == null || !fieldRef.startsWith(PREFIX) || fieldRef.length() <= 2) { + return null; + } + + int indexOfFirstDot = fieldRef.indexOf('.'); + return indexOfFirstDot == -1 ? fieldRef : fieldRef.substring(2, indexOfFirstDot); + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java index 7dd07e5940..f179165984 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java @@ -133,7 +133,7 @@ public AggregationOperationContext continueOnMissingFieldReference(Class type protected FieldReference getReferenceFor(Field field) { - if(entity.getNullable() == null) { + if(entity.getNullable() == null || AggregationVariable.isVariable(field)) { return new DirectFieldReference(new ExposedField(field, true)); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationVariableUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationVariableUnitTests.java new file mode 100644 index 0000000000..a4af334013 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationVariableUnitTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +/** + * @author Christoph Strobl + */ +class AggregationVariableUnitTests { + + @Test // GH-4070 + void variableErrorsOnNullValue() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> AggregationVariable.variable(null)); + } + + @Test // GH-4070 + void createsVariable() { + + var variable = AggregationVariable.variable("$$now"); + + assertThat(variable.getTarget()).isEqualTo("$$now"); + assertThat(variable.isInternal()).isFalse(); + } + + @Test // GH-4070 + void prefixesVariableIfNeeded() { + + var variable = AggregationVariable.variable("this"); + + assertThat(variable.getTarget()).isEqualTo("$$this"); + } + + @Test // GH-4070 + void localVariableErrorsOnNullValue() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> AggregationVariable.localVariable(null)); + } + + @Test // GH-4070 + void localVariable() { + + var variable = AggregationVariable.localVariable("$$this"); + + assertThat(variable.getTarget()).isEqualTo("$$this"); + assertThat(variable.isInternal()).isTrue(); + } + + @Test // GH-4070 + void prefixesLocalVariableIfNeeded() { + + var variable = AggregationVariable.localVariable("this"); + + assertThat(variable.getTarget()).isEqualTo("$$this"); + } + + @Test // GH-4070 + void isVariableReturnsTrueForAggregationVariableTypes() { + + var variable = Mockito.mock(AggregationVariable.class); + + assertThat(AggregationVariable.isVariable(variable)).isTrue(); + } + + @Test // GH-4070 + void isVariableReturnsTrueForFieldThatTargetsVariable() { + + var variable = Fields.field("value", "$$this"); + + assertThat(AggregationVariable.isVariable(variable)).isTrue(); + } + + @Test // GH-4070 + void isVariableReturnsFalseForFieldThatDontTargetsVariable() { + + var variable = Fields.field("value", "$this"); + + assertThat(AggregationVariable.isVariable(variable)).isFalse(); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java index c770e22fba..9ac1606559 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java @@ -39,8 +39,11 @@ import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; import org.springframework.data.mapping.MappingException; +import org.springframework.data.mongodb.core.aggregation.ArrayOperators.Reduce; +import org.springframework.data.mongodb.core.aggregation.ArrayOperators.Reduce.Variable; import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; +import org.springframework.data.mongodb.core.aggregation.SetOperators.SetUnion; import org.springframework.data.mongodb.core.convert.DbRefResolver; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; @@ -455,6 +458,30 @@ void rendersProjectOnNestedPrefixedUnwrappedFieldWithAtFieldAnnotationCorrectly( .isEqualTo(new Document("val", "$withUnwrapped.prefix-with-at-field-annotation")); } + @Test // GH-4070 + void rendersLocalVariables() { + + AggregationOperationContext context = getContext(WithLists.class); + + Document agg = newAggregation(WithLists.class, + project() + .and(Reduce.arrayOf("listOfListOfString").withInitialValue(field("listOfString")) + .reduce(SetUnion.arrayAsSet(Variable.VALUE.getTarget()).union(Variable.THIS.getTarget()))) + .as("listOfString")).toDocument("collection", context); + + assertThat(getPipelineElementFromAggregationAt(agg, 0).get("$project")).isEqualTo(Document.parse(""" + { + "listOfString" : { + "$reduce" : { + "in" : { "$setUnion" : ["$$value", "$$this"] }, + "initialValue" : "$listOfString", + "input" : "$listOfListOfString" + } + } + } + """)); + } + @org.springframework.data.mongodb.core.mapping.Document(collection = "person") @AllArgsConstructor public static class FooPerson { @@ -553,4 +580,9 @@ static class UnwrappableType { @org.springframework.data.mongodb.core.mapping.Field("with-at-field-annotation") // String atFieldAnnotatedValue; } + + static class WithLists { + public List listOfString; + public List> listOfListOfString; + } }