diff --git a/pom.xml b/pom.xml index a6dc167a03..fc20e0dd0a 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-RESULT-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index 3b1dbe8927..9a7ac3241f 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -7,6 +7,7 @@ ** xref:repositories/query-methods.adoc[] ** xref:repositories/definition.adoc[] ** xref:repositories/query-methods-details.adoc[] +** xref:repositories/vector-search.adoc[] ** xref:repositories/create-instances.adoc[] ** xref:repositories/custom-implementations.adoc[] ** xref:repositories/core-domain-events.adoc[] diff --git a/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc new file mode 100644 index 0000000000..15e32dccee --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/repositories/vector-search.adoc @@ -0,0 +1,167 @@ +[[vector-search]] += Vector Search + +With the rise of Generative AI, Vector databases have gained strong traction in the world of databases. +These databases enable efficient storage and querying of high-dimensional vectors, making them well-suited for tasks such as semantic search, recommendation systems, and natural language understanding. + +Vector search is a technique that retrieves semantically similar data by comparing vector representations (also known as embeddings) rather than relying on traditional exact-match queries. +This approach enables intelligent, context-aware applications that go beyond keyword-based retrieval. + +In the context of Spring Data, vector search opens new possibilities for building intelligent, context-aware applications, particularly in domains like natural language processing, recommendation systems, and generative AI. +By modelling vector-based querying using familiar repository abstractions, Spring Data allows developers to seamlessly integrate similarity-based vector-capable databases with the simplicity and consistency of the Spring Data programming model. + +ifdef::vector-search-intro-include[] +include::{vector-search-intro-include}[] +endif::[] + +[[vector-search.model]] +== Vector Model + +To support vector search in a type-safe and idiomatic way, Spring Data introduces the following core abstractions: + +* <> +* <` and `SearchResult`>> +* <> + +[[vector-search.model.vector]] +=== `Vector` + +The `Vector` type represents an n-dimensional numerical embedding, typically produced by embedding models. +In Spring Data, it is defined as a lightweight wrapper around an array of floating-point numbers, ensuring immutability and consistency. +This type can be used as an input for search queries or as a property on a domain entity to store the associated vector representation. + +==== +[source,java] +---- +Vector vector = Vector.of(0.23f, 0.11f, 0.77f); +---- +==== + +Using `Vector` in your domain model removes the need to work with raw arrays or lists of numbers, providing a more type-safe and expressive way to handle vector data. +This abstraction also allows for easy integration with various vector databases and libraries. +It also allows for implementing vendor-specific optimizations such as binary or quantized vectors that do not map to a standard floating point (`float` and `double` as of https://en.wikipedia.org/wiki/IEEE_754[IEEE 754]) representation. +A domain object can have a vector property, which can be used for similarity searches. +Consider the following example: + +ifdef::vector-search-model-include[] +include::{vector-search-model-include}[] +endif::[] + +NOTE: Associating a vector with a domain object results in the vector being loaded and stored as part of the entity lifecycle, which may introduce additional overhead on retrieval and persistence operations. + +[[vector-search.model.search-result]] +=== Search Results + +The `SearchResult` type encapsulates the results of a vector similarity query. +It includes both the matched domain object and a relevance score that indicates how closely it matches the query vector. +This abstraction provides a structured way to handle result ranking and enables developers to easily work with both the data and its contextual relevance. + +ifdef::vector-search-repository-include[] +include::{vector-search-repository-include}[] +endif::[] + +In this example, the `searchByCountryAndEmbeddingNear` method returns a `SearchResults` object, which contains a list of `SearchResult` instances. +Each result includes the matched `Comment` entity and its relevance score. + +Relevance score is a numerical value that indicates how closely the matched vector aligns with the query vector. +Depending on whether a score represents distance or similarity a higher score can mean a closer match or a more distant one. + +The scoring function used to calculate this score can vary based on the underlying database, index or input parameters. + +[[vector-search.model.scoring]] +=== Score, Similarity, and Scoring Functions + +The `Score` type holds a numerical value indicating the relevance of a search result. +It can be used to rank results based on their similarity to the query vector. +The `Score` type is typically a floating-point number, and its interpretation (higher is better or lower is better) depends on the specific similarity function used. +Scores are a by-product of vector search and are not required for a successful search operation. +Score values are not part of a domain model and therefore represented best as out-of-band data. + +Generally, a Score is computed by a `ScoringFunction`. +The actual scoring function used to calculate this score can depends on the underlying database and can be obtained from a search index or input parameters. + +Spring Data support declares constants for commonly used functions such as: + +Euclidean Distance:: Calculates the straight-line distance in n-dimensional space involving the square root of the sum of squared differences. +Cosine Similarity:: Measures the angle between two vectors by calculating the Dot product first and then normalizing its result by dividing by the product of their lengths. +Dot Product:: Computes the sum of element-wise multiplications. + +The choice of similarity function can impact both the performance and semantics of the search and is often determined by the underlying database or index being used. +Spring Data adopts to the database's native scoring function capabilities and whether the score can be used to limit results. + +ifdef::vector-search-scoring-include[] +include::{vector-search-scoring-include}[] +endif::[] + +[[vector-search.methods]] +== Vector Search Methods + +Vector search methods are defined in repositories using the same conventions as standard Spring Data query methods. +These methods return `SearchResults` and require a `Vector` parameter to define the query vector. +The actual implementation depends on the actual internals of the underlying data store and its capabilities around vector search. + +NOTE: If you are new to Spring Data repositories, make sure to familiarize yourself with the xref:repositories/core-concepts.adoc[basics of repository definitions and query methods]. + +Generally, you have the choice of declaring a search method using two approaches: + +* Query Derivation +* Declaring a String-based Query + +Vector Search methods must declare a `Vector` parameter to define the query vector. + +[[vector-search.method.derivation]] +=== Derived Search Methods + +A derived search method uses the name of the method to derive the query. +Vector Search supports the following keywords to run a Vector search when declaring a search method: + +.Query predicate keywords +[options="header",cols="1,3"] +|=============== +|Logical keyword|Keyword expressions +|`NEAR`|`Near`, `IsNear` +|`WITHIN`|`Within`, `IsWithin` +|=============== + +ifdef::vector-search-method-derived-include[] +include::{vector-search-method-derived-include}[] +endif::[] + +Derived search methods are typically easier to read and maintain, as they rely on the method name to express the query intent. +However, a derived search method requires either to declare a `Score`, `Range` or `ScoreFunction` as second argument to the `Near`/`Within` keyword to limit search results by their score. + +[[vector-search.method.string]] +=== Annotated Search Methods + +Annotated methods provide full control over the query semantics and parameters. +Unlike derived methods, they do not rely on method name conventions. + +ifdef::vector-search-method-annotated-include[] +include::{vector-search-method-annotated-include}[] +endif::[] + +With more control over the actual query, Spring Data can make fewer assumptions about the query and its parameters. +For example, `Similarity` normalization uses the native score function within the query to normalize the given similarity into a score predicate value and vice versa. +If an annotated query does not define e.g. the score, then the score value in the returned `SearchResult` will be zero. + +[[vector-search.method.sorting]] +=== Sorting + +By default, search results are ordered according to their score. +You can override sorting by using the `Sort` parameter: + +.Using `Sort` in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + SearchResults searchByEmbeddingNearOrderByCountry(Vector vector, Score score); + + SearchResults searchByEmbeddingWithin(Vector vector, Score score, Sort sort); +} +---- +==== + +Please note that custom sorting does not allow expressing the score as a sorting criteria. +You can only refer to domain properties. diff --git a/src/main/java/org/springframework/data/domain/Page.java b/src/main/java/org/springframework/data/domain/Page.java index 54563e3969..e0b74fee9a 100644 --- a/src/main/java/org/springframework/data/domain/Page.java +++ b/src/main/java/org/springframework/data/domain/Page.java @@ -69,4 +69,5 @@ static Page empty(Pageable pageable) { */ @Override Page map(Function converter); + } diff --git a/src/main/java/org/springframework/data/domain/Range.java b/src/main/java/org/springframework/data/domain/Range.java index fb3aa165dc..be53c621f8 100644 --- a/src/main/java/org/springframework/data/domain/Range.java +++ b/src/main/java/org/springframework/data/domain/Range.java @@ -223,7 +223,7 @@ public boolean contains(T value, Comparator comparator) { /** * Apply a mapping {@link Function} to the lower and upper boundary values. * - * @param mapper must not be {@literal null}. If the mapper returns {@code null}, then the corresponding boundary + * @param mapper must not be {@literal null}. If the mapper returns {@literal null}, then the corresponding boundary * value represents an {@link Bound#unbounded()} boundary. * @return a new {@link Range} after applying the value to the mapper. * @param target type of the mapping function. @@ -430,7 +430,7 @@ public boolean isInclusive() { /** * Apply a mapping {@link Function} to the boundary value. * - * @param mapper must not be {@literal null}. If the mapper returns {@code null}, then the boundary value + * @param mapper must not be {@literal null}. If the mapper returns {@literal null}, then the boundary value * corresponds with {@link Bound#unbounded()}. * @return a new {@link Bound} after applying the value to the mapper. * @param diff --git a/src/main/java/org/springframework/data/domain/Score.java b/src/main/java/org/springframework/data/domain/Score.java new file mode 100644 index 0000000000..9f80ff1477 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/Score.java @@ -0,0 +1,118 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.io.Serializable; + +import org.springframework.util.ObjectUtils; + +/** + * Value object representing a search result score computed via a {@link ScoringFunction}. + *

+ * Encapsulates the numeric score and the scoring function used to derive it. Scores are primarily used to rank search + * results. Depending on the used {@link ScoringFunction} higher scores can indicate either a higher distance or a + * higher similarity. Use the {@link Similarity} class to indicate usage of a normalized score across representing + * effectively the similarity. + *

+ * Instances of this class are immutable and suitable for use in comparison, sorting, and range operations. + * + * @author Mark Paluch + * @since 4.0 + * @see Similarity + */ +public sealed class Score implements Serializable permits Similarity { + + private final double value; + private final ScoringFunction function; + + Score(double value, ScoringFunction function) { + this.value = value; + this.function = function; + } + + /** + * Creates a new {@link Score} from a plain {@code score} value using {@link ScoringFunction#unspecified()}. + * + * @param score the score value without a specific {@link ScoringFunction}. + * @return the new {@link Score}. + */ + public static Score of(double score) { + return of(score, ScoringFunction.unspecified()); + } + + /** + * Creates a new {@link Score} from a {@code score} value using the given {@link ScoringFunction}. + * + * @param score the score value. + * @param function the scoring function that has computed the {@code score}. + * @return the new {@link Score}. + */ + public static Score of(double score, ScoringFunction function) { + return new Score(score, function); + } + + /** + * Creates a {@link Range} from the given minimum and maximum {@code Score} values. + * + * @param min the lower score value, must not be {@literal null}. + * @param max the upper score value, must not be {@literal null}. + * @return a {@link Range} over {@link Score} bounds. + */ + public static Range between(Score min, Score max) { + return Range.from(Range.Bound.inclusive(min)).to(Range.Bound.inclusive(max)); + } + + /** + * Returns the raw numeric value of the score. + * + * @return the score value. + */ + public double getValue() { + return value; + } + + /** + * Returns the {@link ScoringFunction} that was used to compute this score. + * + * @return the associated scoring function. + */ + public ScoringFunction getFunction() { + return function; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Score other)) { + return false; + } + if (value != other.value) { + return false; + } + return ObjectUtils.nullSafeEquals(function, other.function); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHash(value, function); + } + + @Override + public String toString() { + return function instanceof UnspecifiedScoringFunction ? Double.toString(value) + : "%s (%s)".formatted(Double.toString(value), function.getName()); + } + +} diff --git a/src/main/java/org/springframework/data/domain/ScoringFunction.java b/src/main/java/org/springframework/data/domain/ScoringFunction.java new file mode 100644 index 0000000000..249565d719 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/ScoringFunction.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +/** + * Strategy interface for scoring functions. + *

+ * Implementations define how score (distance or similarity) between two vectors is computed, allowing control over + * ranking behavior in search queries. + *

+ * Provides commonly used scoring variants via static factory methods. See {@link VectorScoringFunctions} for the + * concrete implementations. + * + * @author Mark Paluch + * @since 4.0 + * @see Score + * @see Similarity + */ +public interface ScoringFunction { + + /** + * Returns the default {@code ScoringFunction} to be used when none is explicitly specified. + *

+ * This is typically used to indicate the absence of a scoring definition. + * + * @return the default {@code ScoringFunction} instance. + */ + static ScoringFunction unspecified() { + return UnspecifiedScoringFunction.INSTANCE; + } + + /** + * Return the Euclidean distance scoring function. + *

+ * Calculates the L2 norm (straight-line distance) between two vectors. + * + * @return the {@code ScoringFunction} based on Euclidean distance. + */ + static ScoringFunction euclidean() { + return VectorScoringFunctions.EUCLIDEAN; + } + + /** + * Return the cosine similarity scoring function. + *

+ * Measures the cosine of the angle between two vectors, independent of magnitude. + * + * @return the {@code ScoringFunction} based on cosine similarity. + */ + static ScoringFunction cosine() { + return VectorScoringFunctions.COSINE; + } + + /** + * Return the dot product (also known as inner product) scoring function. + *

+ * Computes the algebraic product of two vectors, considering both direction and magnitude. + * + * @return the {@code ScoringFunction} based on dot product. + */ + static ScoringFunction dotProduct() { + return VectorScoringFunctions.DOT_PRODUCT; + } + + /** + * Return the name of the scoring function. + *

+ * Typically used for display or configuration purposes. + * + * @return the identifying name of this scoring function. + */ + String getName(); + +} diff --git a/src/main/java/org/springframework/data/domain/SearchResult.java b/src/main/java/org/springframework/data/domain/SearchResult.java new file mode 100644 index 0000000000..4dd8216616 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/SearchResult.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.io.Serial; +import java.io.Serializable; +import java.util.function.Function; + +import org.jspecify.annotations.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +/** + * Immutable value object representing a search result consisting of a content item and an associated {@link Score}. + *

+ * Typically used in the context of similarity-based or vector search operations where each result carries a relevance + * {@link Score}. Provides accessor methods for the content and its score, along with transformation support via + * {@link #map(Function)}. + * + * @param the type of the content object + * @author Mark Paluch + * @since 4.0 + * @see Score + * @see Similarity + */ +public final class SearchResult implements Serializable { + + private static final @Serial long serialVersionUID = 1637452570977581370L; + + private final T content; + private final Score score; + + /** + * Creates a new {@link SearchResult} with the given content and {@link Score}. + * + * @param content the result content, must not be {@literal null}. + * @param score the result score, must not be {@literal null}. + */ + public SearchResult(T content, Score score) { + + Assert.notNull(content, "Content must not be null"); + Assert.notNull(score, "Score must not be null"); + + this.content = content; + this.score = score; + } + + /** + * Create a new {@link SearchResult} with the given content and a raw score value. + * + * @param content the result content, must not be {@literal null}. + * @param score the score value. + */ + public SearchResult(T content, double score) { + this(content, Score.of(score)); + } + + /** + * Returns the content associated with this result. + */ + public T getContent() { + return this.content; + } + + /** + * Returns the {@link Score} associated with this result. + */ + public Score getScore() { + return this.score; + } + + /** + * Creates a new {@link SearchResult} by applying the given mapping {@link Function} to this result's content. + * + * @param converter the mapping function to apply to the content, must not be {@literal null}. + * @return a new {@link SearchResult} instance with converted content. + * @param the target type of the mapped content. + */ + public SearchResult map(Function converter) { + + Assert.notNull(converter, "Function must not be null"); + + return new SearchResult<>(converter.apply(getContent()), getScore()); + } + + @Override + public boolean equals(@Nullable Object o) { + + if (this == o) { + return true; + } + + if (!(o instanceof SearchResult result)) { + return false; + } + + if (!ObjectUtils.nullSafeEquals(content, result.content)) { + return false; + } + + return ObjectUtils.nullSafeEquals(score, result.score); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHash(content, score); + } + + @Override + public String toString() { + return String.format("SearchResult [instance: %s, score: %s]", ClassUtils.getShortName(content.getClass()), score); + } + +} diff --git a/src/main/java/org/springframework/data/domain/SearchResults.java b/src/main/java/org/springframework/data/domain/SearchResults.java new file mode 100644 index 0000000000..54e43db071 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/SearchResults.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.data.util.Streamable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +/** + * Value object encapsulating a collection of {@link SearchResult} instances. + *

+ * Typically used as the result type for search or similarity queries, exposing access to the result content and + * supporting mapping operations to transform the result content type. + * + * @param the type of content contained within each {@link SearchResult}. + * @author Mark Paluch + * @since 4.0 + * @see SearchResult + */ +public class SearchResults implements Iterable>, Serializable { + + private final List> results; + + /** + * Creates a new {@link SearchResults} instance from the given list of {@link SearchResult} items. + * + * @param results the search results to encapsulate, must not be {@code null} + */ + public SearchResults(List> results) { + this.results = results; + } + + /** + * Return the actual content of the {@link SearchResult} items as an unmodifiable list. + */ + public List> getContent() { + return Collections.unmodifiableList(results); + } + + @Override + @SuppressWarnings("unchecked") + public Iterator> iterator() { + return (Iterator>) results.iterator(); + } + + /** + * Returns a sequential {@link Stream} containing {@link SearchResult} items in this {@code SearchResults} instance. + * + * @return a sequential {@link Stream} containing {@link SearchResult} items in this {@code SearchResults} instance. + */ + public Stream> stream() { + return Streamable.of(this).stream(); + } + + /** + * Returns a sequential {@link Stream} containing {@link #getContent() unwrapped content} items in this + * {@code SearchResults} instance. + * + * @return a sequential {@link Stream} containing {@link #getContent() unwrapped content} items in this + * {@code SearchResults} instance. + */ + public Stream contentStream() { + return getContent().stream().map(SearchResult::getContent); + } + + /** + * Creates a new {@code SearchResults} instance with the content of the current results mapped via the given + * {@link Function}. + * + * @param converter the mapping function to apply to the content of each {@link SearchResult}, must not be + * {@literal null}. + * @param the target type of the mapped content. + * @return a new {@code SearchResults} instance containing mapped result content. + */ + public SearchResults map(Function converter) { + + Assert.notNull(converter, "Function must not be null"); + + List> result = results.stream().map(it -> it. map(converter)).collect(Collectors.toList()); + + return new SearchResults<>(result); + } + + @Override + public boolean equals(Object o) { + + if (o == this) { + return true; + } + + if (!(o instanceof SearchResults that)) { + return false; + } + return ObjectUtils.nullSafeEquals(results, that.results); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(results); + } + + @Override + public String toString() { + + if (results.isEmpty()) { + return "SearchResults [empty]"; + } + + return String.format("SearchResults [size: %s]", results.size()); + } + +} diff --git a/src/main/java/org/springframework/data/domain/Similarity.java b/src/main/java/org/springframework/data/domain/Similarity.java new file mode 100644 index 0000000000..ead7180a52 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/Similarity.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import org.springframework.util.Assert; + +/** + * Value object representing a normalized similarity score determined by a {@link ScoringFunction}. + *

+ * Similarity values are constrained to the range {@code [0.0, 1.0]}, where {@code 0.0} denotes the least similarity and + * {@code 1.0} the maximum similarity. This normalization allows for consistent comparison of similarity scores across + * different scoring models and systems. + *

+ * Primarily used in vector search and approximate nearest neighbor arrangements where results are ranked based on + * normalized relevance. Vector searches typically return a collection of results ordered by their similarity to the + * query vector. + *

+ * This class is designed for use in information retrieval contexts, recommendation systems, and other applications + * requiring normalized comparison of results. + *

+ * A {@code Similarity} instance includes both the similarity {@code value} and information about the + * {@link ScoringFunction} used to generate it, providing context for proper interpretation of the score. + *

+ * Instances are immutable and support range-based comparisons, making them suitable for filtering and ranking + * operations. The class extends {@link Score} to inherit common scoring functionality while adding similarity-specific + * semantics. + * + * @author Mark Paluch + * @since 4.0 + * @see Score + */ +public final class Similarity extends Score { + + private Similarity(double value, ScoringFunction function) { + super(value, function); + } + + /** + * Creates a new {@link Similarity} from a plain {@code similarity} value using {@link ScoringFunction#unspecified()}. + * + * @param similarity the similarity value without a specific {@link ScoringFunction}, ranging between {@code 0} and + * {@code 1}. + * @return the new {@link Similarity}. + */ + public static Similarity of(double similarity) { + return of(similarity, ScoringFunction.unspecified()); + } + + /** + * Creates a new {@link Similarity} from a raw value and the associated {@link ScoringFunction}. + * + * @param similarity the similarity value in the {@code [0.0, 1.0]} range. + * @param function the scoring function that produced this similarity. + * @return a new {@link Similarity} instance. + * @throws IllegalArgumentException if the value is outside the allowed range. + */ + public static Similarity of(double similarity, ScoringFunction function) { + + Assert.isTrue(similarity >= 0.0 && similarity <= 1.0, "Similarity must be in [0,1] range."); + + return new Similarity(similarity, function); + } + + /** + * Create a raw {@link Similarity} value without validation. + *

+ * Intended for use when accepting similarity values from trusted sources such as search engines or databases. + * + * @param similarity the similarity value in the {@code [0.0, 1.0]} range. + * @param function the scoring function that produced this similarity. + * @return a new {@link Similarity} instance. + */ + public static Similarity raw(double similarity, ScoringFunction function) { + return new Similarity(similarity, function); + } + + /** + * Creates a {@link Range} between the given {@link Similarity}. + * + * @param min lower value. + * @param max upper value. + * @return the {@link Range} between the given values. + */ + public static Range between(Similarity min, Similarity max) { + return Range.from(Range.Bound.inclusive(min)).to(Range.Bound.inclusive(max)); + } + + /** + * Creates a new {@link Range} by creating minimum and maximum {@link Similarity} from the given values + * {@link ScoringFunction#unspecified() without specifying} a specific scoring function. + * + * @param minValue lower value, ranging between {@code 0} and {@code 1}. + * @param maxValue upper value, ranging between {@code 0} and {@code 1}. + * @return the {@link Range} between the given values. + */ + public static Range between(double minValue, double maxValue) { + return between(minValue, maxValue, ScoringFunction.unspecified()); + } + + /** + * Creates a {@link Range} of {@link Similarity} values using raw values and a specified scoring function. + * + * @param minValue the lower similarity value. + * @param maxValue the upper similarity value. + * @param function the scoring function to associate with the values. + * @return a {@link Range} of {@link Similarity} values. + */ + public static Range between(double minValue, double maxValue, ScoringFunction function) { + return between(Similarity.of(minValue, function), Similarity.of(maxValue, function)); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Similarity other)) { + return false; + } + return super.equals(other); + } + +} diff --git a/src/main/java/org/springframework/data/domain/UnspecifiedScoringFunction.java b/src/main/java/org/springframework/data/domain/UnspecifiedScoringFunction.java new file mode 100644 index 0000000000..986b6e5592 --- /dev/null +++ b/src/main/java/org/springframework/data/domain/UnspecifiedScoringFunction.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import java.io.Serializable; + +class UnspecifiedScoringFunction implements ScoringFunction, Serializable { + + static final UnspecifiedScoringFunction INSTANCE = new UnspecifiedScoringFunction(); + + private UnspecifiedScoringFunction() {} + + @Override + public String getName() { + return "Unspecified"; + } + + @Override + public boolean equals(Object o) { + return o instanceof UnspecifiedScoringFunction; + } + + @Override + public int hashCode() { + return 32; + } + + @Override + public String toString() { + return "UNSPECIFIED"; + } + +} diff --git a/src/main/java/org/springframework/data/domain/VectorScoringFunctions.java b/src/main/java/org/springframework/data/domain/VectorScoringFunctions.java new file mode 100644 index 0000000000..e39356505b --- /dev/null +++ b/src/main/java/org/springframework/data/domain/VectorScoringFunctions.java @@ -0,0 +1,92 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +/** + * Commonly used {@link ScoringFunction} implementations for vector-based similarity computations. + *

+ * Provides a set of standard scoring strategies for comparing vectors in search or matching operations. Includes + * options such as Euclidean distance, cosine similarity, and dot product. + *

+ * These constants are intended for reuse across components requiring vector scoring semantics. Each scoring function + * represents a mathematical approach to quantifying the similarity or distance between vectors in a multidimensional + * space. + *

+ * When selecting a scoring function, consider the specific requirements of your application domain: + *

    + *
  • For spatial distance measurements where magnitude matters, use {@link #EUCLIDEAN}.
  • + *
  • For directional similarity irrespective of magnitude, use {@link #COSINE}.
  • + *
  • For efficient high-dimensional calculations, use {@link #DOT_PRODUCT}.
  • + *
  • For grid-based or axis-aligned problems, use {@link #TAXICAB}.
  • + *
  • For binary vector or string comparisons, use {@link #HAMMING}.
  • + *
+ * The choice of scoring function can significantly impact the relevance of the results returned by a Vector Search + * query. {@code ScoringFunction} and score values are typically subject to fine-tuning during the development to + * achieve optimal performance and accuracy. + * + * @author Mark Paluch + * @since 4.0 + */ +public enum VectorScoringFunctions implements ScoringFunction { + + /** + * Scoring based on the Euclidean distance between two + * vectors. + *

+ * Computes the L2 norm, involving a square root operation. Typically more computationally expensive than + * {@link #COSINE} or {@link #DOT_PRODUCT}, but precise in spatial distance measurement. + */ + EUCLIDEAN, + + /** + * Scoring based on cosine similarity between two vectors. + *

+ * Measures the angle between vectors, independent of their magnitude. Involves a {@link #DOT_PRODUCT} and + * normalization, offering a balance between precision and performance. + */ + COSINE, + + /** + * Scoring based on the dot product (also known as inner + * product) between two vectors. + *

+ * Efficient to compute and particularly useful in high-dimensional vector spaces. + */ + DOT_PRODUCT, + + /** + * Scoring based on taxicab (Manhattan) distance. + *

+ * Computes the sum of absolute differences across dimensions. Useful in contexts where axis-aligned movement or L1 + * norms are preferred. + */ + TAXICAB, + + /** + * Scoring based on the Hamming distance between two + * vectors or strings. + *

+ * Counts the number of differing positions. Suitable for binary (bitwise) vectors or fixed-length character + * sequences. + */ + HAMMING; + + @Override + public String getName() { + return name(); + } + +} diff --git a/src/main/java/org/springframework/data/geo/Distance.java b/src/main/java/org/springframework/data/geo/Distance.java index 612f905219..eb6d6e5673 100644 --- a/src/main/java/org/springframework/data/geo/Distance.java +++ b/src/main/java/org/springframework/data/geo/Distance.java @@ -71,6 +71,27 @@ public Distance(double value, Metric metric) { this.metric = metric; } + /** + * Creates a new {@link Distance} with a neutral metric. This means the provided value needs to be in normalized form. + * + * @param value distance value. + * @since 4.0 + */ + public static Distance of(double value) { + return new Distance(value); + } + + /** + * Creates a new {@link Distance} with the given {@link Metric}. + * + * @param value distance value. + * @param metric must not be {@literal null}. + * @since 4.0 + */ + public static Distance of(double value, Metric metric) { + return new Distance(value, metric); + } + /** * Creates a {@link Range} between the given {@link Distance}. * diff --git a/src/main/java/org/springframework/data/geo/GeoResult.java b/src/main/java/org/springframework/data/geo/GeoResult.java index ae9fa180ac..3c3a4923bf 100644 --- a/src/main/java/org/springframework/data/geo/GeoResult.java +++ b/src/main/java/org/springframework/data/geo/GeoResult.java @@ -19,8 +19,8 @@ import java.io.Serializable; import org.jspecify.annotations.Nullable; - import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; /** @@ -79,6 +79,7 @@ public int hashCode() { @Override public String toString() { - return String.format("GeoResult [content: %s, distance: %s, ]", content.toString(), distance.toString()); + return String.format("GeoResult [instance: %s, distance: %s]", ClassUtils.getShortName(content.getClass()), + distance); } } diff --git a/src/main/java/org/springframework/data/geo/GeoResults.java b/src/main/java/org/springframework/data/geo/GeoResults.java index a22b75ec75..5ff49c720a 100644 --- a/src/main/java/org/springframework/data/geo/GeoResults.java +++ b/src/main/java/org/springframework/data/geo/GeoResults.java @@ -15,7 +15,6 @@ */ package org.springframework.data.geo; - import java.io.Serial; import java.io.Serializable; import java.util.Collections; @@ -23,11 +22,9 @@ import java.util.List; import org.jspecify.annotations.Nullable; - import org.springframework.data.annotation.PersistenceCreator; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; -import org.springframework.util.StringUtils; /** * Value object to capture {@link GeoResult}s as well as the average distance they have. @@ -129,8 +126,7 @@ public int hashCode() { @Override public String toString() { - return String.format("GeoResults: [averageDistance: %s, results: %s]", averageDistance.toString(), - StringUtils.collectionToCommaDelimitedString(results)); + return String.format("GeoResults: [averageDistance: %s, size: %s]", averageDistance, results.size()); } private static Distance calculateAverageDistance(List> results, Metric metric) { diff --git a/src/main/java/org/springframework/data/repository/query/Parameter.java b/src/main/java/org/springframework/data/repository/query/Parameter.java index 0907d0f035..2061b4f242 100644 --- a/src/main/java/org/springframework/data/repository/query/Parameter.java +++ b/src/main/java/org/springframework/data/repository/query/Parameter.java @@ -28,8 +28,11 @@ import org.springframework.core.ResolvableType; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters; import org.springframework.data.util.ClassUtils; @@ -55,6 +58,7 @@ public class Parameter { private final MethodParameter parameter; private final Class parameterType; + private final boolean isScoreRange; private final boolean isDynamicProjectionParameter; private final Lazy> name; @@ -71,6 +75,7 @@ public class Parameter { TYPES = Collections.unmodifiableList(types); } + /** * Creates a new {@link Parameter} for the given {@link MethodParameter} and domain {@link TypeInformation}. * @@ -84,9 +89,11 @@ protected Parameter(MethodParameter parameter, TypeInformation domainType) { Assert.notNull(domainType, "TypeInformation must not be null!"); this.parameter = parameter; + this.isScoreRange = Range.class.isAssignableFrom(parameter.getParameterType()) + && ResolvableType.forMethodParameter(parameter).getGeneric(0).isAssignableFrom(Score.class); this.parameterType = potentiallyUnwrapParameterType(parameter); this.isDynamicProjectionParameter = isDynamicProjectionParameter(parameter, domainType); - this.name = isSpecialParameterType(parameter.getParameterType()) ? Lazy.of(Optional.empty()) : Lazy.of(() -> { + this.name = Lazy.of(() -> { Param annotation = parameter.getParameterAnnotation(Param.class); return Optional.ofNullable(annotation == null ? parameter.getParameterName() : annotation.value()); }); @@ -208,9 +215,31 @@ public String toString() { } /** - * Returns whether the {@link Parameter} is a {@link ScrollPosition} parameter. - * - * @return + * @return {@literal true} if the {@link Parameter} is a {@link Vector} parameter. + * @since 4.0 + */ + boolean isVector() { + return Vector.class.isAssignableFrom(getType()); + } + + /** + * @return {@literal true} if the {@link Parameter} is a {@link Score} parameter. + * @since 4.0 + */ + boolean isScore() { + return Score.class.isAssignableFrom(getType()); + } + + /** + * @return {@literal true} if the {@link Parameter} is a {@link Range} of {@link Score} parameter. + * @since 4.0 + */ + boolean isScoreRange() { + return isScoreRange; + } + + /** + * @return {@literal true} if the {@link Parameter} is a {@link ScrollPosition} parameter. * @since 3.1 */ boolean isScrollPosition() { @@ -218,27 +247,21 @@ boolean isScrollPosition() { } /** - * Returns whether the {@link Parameter} is a {@link Pageable} parameter. - * - * @return + * @return {@literal true} if the {@link Parameter} is a {@link Pageable} parameter. */ boolean isPageable() { return Pageable.class.isAssignableFrom(getType()); } /** - * Returns whether the {@link Parameter} is a {@link Sort} parameter. - * - * @return + * @return {@literal true} if the {@link Parameter} is a {@link Sort} parameter. */ boolean isSort() { return Sort.class.isAssignableFrom(getType()); } /** - * Returns whether the {@link Parameter} is a {@link Limit} parameter. - * - * @return + * @return {@literal true} if the {@link Parameter} is a {@link Limit} parameter. * @since 3.2 */ boolean isLimit() { diff --git a/src/main/java/org/springframework/data/repository/query/ParameterAccessor.java b/src/main/java/org/springframework/data/repository/query/ParameterAccessor.java index d8a406d909..8a69b1a49a 100644 --- a/src/main/java/org/springframework/data/repository/query/ParameterAccessor.java +++ b/src/main/java/org/springframework/data/repository/query/ParameterAccessor.java @@ -21,8 +21,11 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; /** * Interface to access method parameters. Allows dedicated access to parameters of special types @@ -33,35 +36,48 @@ public interface ParameterAccessor extends Iterable { /** - * Returns the {@link ScrollPosition} of the parameters, if available. Returns {@code null} otherwise. - * - * @return + * @return the {@link Vector} of the parameters, if available; {@literal null} otherwise. + * @since 4.0 + */ + @Nullable + Vector getVector(); + + /** + * @return the {@link Score} of the parameters, if available; {@literal null} otherwise. + * @since 4.0 + */ + @Nullable + Score getScore(); + + /** + * @return the {@link Range} of {@link Score} of the parameters, if available; {@literal null} otherwise. + * @since 4.0 + */ + @Nullable + Range getScoreRange(); + + /** + * @return the {@link ScrollPosition} of the parameters, if available; {@literal null} otherwise. */ @Nullable ScrollPosition getScrollPosition(); /** - * Returns the {@link Pageable} of the parameters, if available. Returns {@link Pageable#unpaged()} otherwise. - * - * @return + * @return the {@link Pageable} of the parameters, if available; {@link Pageable#unpaged()} otherwise. */ Pageable getPageable(); /** - * Returns the sort instance to be used for query creation. Will use a {@link Sort} parameter if available or the - * {@link Sort} contained in a {@link Pageable} if available. Returns {@link Sort#unsorted()} if no {@link Sort} can - * be found. - * - * @return + * @return the sort instance to be used for query creation. Will use a {@link Sort} parameter if available or the + * {@link Sort} contained in a {@link Pageable} if available. {@link Sort#unsorted()} if no {@link Sort} can + * be found. */ Sort getSort(); /** - * Returns the {@link Limit} instance to be used for query creation. If no {@link java.lang.reflect.Parameter} - * assignable to {@link Limit} can be found {@link Limit} will be created out of {@link Pageable#getPageSize()} if - * present. - * - * @return + * @return the {@link Limit} instance to be used for query creation. If no {@link java.lang.reflect.Parameter} + * assignable to {@link Limit} can be found {@link Limit} will be created out of + * {@link Pageable#getPageSize()} if present. * @since 3.2 */ default Limit getLimit() { @@ -69,9 +85,7 @@ default Limit getLimit() { } /** - * Returns the dynamic projection type to be used when executing the query or {@literal null} if none is defined. - * - * @return + * @return the dynamic projection type to be used when executing the query or {@literal null} if none is defined. * @since 2.2 */ @Nullable @@ -83,7 +97,7 @@ default Limit getLimit() { * {@link String}, {@code #getBindableParameter(1)} would return the second {@link String} value. * * @param index - * @return + * @return the bindable value with the given index */ @Nullable Object getBindableValue(int index); @@ -91,7 +105,7 @@ default Limit getLimit() { /** * Returns whether one of the bindable parameter values is {@literal null}. * - * @return + * @return {@literal true} if one of the bindable parameter values is {@literal null}. */ boolean hasBindableNullValue(); @@ -99,7 +113,9 @@ default Limit getLimit() { * Returns an iterator over all bindable parameters. This means parameters implementing {@link Pageable} or * {@link Sort} will not be included in this {@link Iterator}. * - * @return + * @return iterator over all bindable parameters. */ + @Override Iterator iterator(); + } diff --git a/src/main/java/org/springframework/data/repository/query/Parameters.java b/src/main/java/org/springframework/data/repository/query/Parameters.java index 1acc732fab..c104bdf610 100644 --- a/src/main/java/org/springframework/data/repository/query/Parameters.java +++ b/src/main/java/org/springframework/data/repository/query/Parameters.java @@ -15,7 +15,7 @@ */ package org.springframework.data.repository.query; -import static java.lang.String.*; +import static java.lang.String.format; import java.lang.reflect.Method; import java.util.ArrayList; @@ -27,10 +27,14 @@ import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.ResolvableType; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.util.Lazy; import org.springframework.data.util.Streamable; import org.springframework.util.Assert; @@ -55,6 +59,9 @@ public abstract class Parameters, T extends Parameter private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer(); + private final int vectorIndex; + private final int scoreIndex; + private final int scoreRangeIndex; private final int scrollPositionIndex; private final int pageableIndex; private final int sortIndex; @@ -72,8 +79,7 @@ public abstract class Parameters, T extends Parameter * @param parameterFactory must not be {@literal null}. * @since 3.2.1 */ - protected Parameters(ParametersSource parametersSource, - Function parameterFactory) { + protected Parameters(ParametersSource parametersSource, Function parameterFactory) { Assert.notNull(parametersSource, "ParametersSource must not be null"); Assert.notNull(parameterFactory, "Parameter factory must not be null"); @@ -84,6 +90,9 @@ protected Parameters(ParametersSource parametersSource, this.parameters = new ArrayList<>(parameterCount); this.dynamicProjectionIndex = -1; + int vectorIndex = -1; + int scoreIndex = -1; + int scoreRangeIndex = -1; int scrollPositionIndex = -1; int pageableIndex = -1; int sortIndex = -1; @@ -106,6 +115,19 @@ protected Parameters(ParametersSource parametersSource, this.dynamicProjectionIndex = parameter.getIndex(); } + if (Vector.class.isAssignableFrom(parameter.getType())) { + vectorIndex = i; + } + + if (Score.class.isAssignableFrom(parameter.getType())) { + scoreIndex = i; + } + + if (Range.class.isAssignableFrom(parameter.getType()) + && Score.class.isAssignableFrom(ResolvableType.forMethodParameter(methodParameter).getGeneric(0).toClass())) { + scoreRangeIndex = i; + } + if (ScrollPosition.class.isAssignableFrom(parameter.getType())) { scrollPositionIndex = i; } @@ -125,6 +147,9 @@ protected Parameters(ParametersSource parametersSource, parameters.add(parameter); } + this.vectorIndex = vectorIndex; + this.scoreIndex = scoreIndex; + this.scoreRangeIndex = scoreRangeIndex; this.scrollPositionIndex = scrollPositionIndex; this.pageableIndex = pageableIndex; this.sortIndex = sortIndex; @@ -143,6 +168,9 @@ protected Parameters(List originals) { this.parameters = new ArrayList<>(originals.size()); + int vectorIndexTemp = -1; + int scoreIndexTemp = -1; + int scoreRangeIndexTemp = -1; int scrollPositionIndexTemp = -1; int pageableIndexTemp = -1; int sortIndexTemp = -1; @@ -154,6 +182,9 @@ protected Parameters(List originals) { T original = originals.get(i); this.parameters.add(original); + vectorIndexTemp = original.isVector() ? i : -1; + scoreIndexTemp = original.isScore() ? i : -1; + scoreRangeIndexTemp = original.isScoreRange() ? i : -1; scrollPositionIndexTemp = original.isScrollPosition() ? i : -1; pageableIndexTemp = original.isPageable() ? i : -1; sortIndexTemp = original.isSort() ? i : -1; @@ -161,6 +192,9 @@ protected Parameters(List originals) { dynamicProjectionTemp = original.isDynamicProjectionParameter() ? i : -1; } + this.vectorIndex = vectorIndexTemp; + this.scoreIndex = scoreIndexTemp; + this.scoreRangeIndex = scoreRangeIndexTemp; this.scrollPositionIndex = scrollPositionIndexTemp; this.pageableIndex = pageableIndexTemp; this.sortIndex = sortIndexTemp; @@ -183,6 +217,67 @@ private S getBindable() { return createFrom(bindables); } + /** + * Returns whether the method the {@link Parameters} was created for contains a {@link Vector} argument. + * + * @return + * @since 4.0 + */ + public boolean hasVectorParameter() { + return vectorIndex != -1; + } + + /** + * Returns the index of the {@link Vector} argument. + * + * @return the argument index or {@literal -1} if none defined. + * @since 4.0 + */ + public int getVectorIndex() { + return vectorIndex; + } + + /** + * Returns whether the method the {@link Parameters} was created for contains a {@link Score} argument. + * + * @return + * @since 4.0 + */ + public boolean hasScoreParameter() { + return scoreIndex != -1; + } + + /** + * Returns the index of the {@link Score} argument. + * + * @return the argument index or {@literal -1} if none defined. + * @since 4.0 + */ + public int getScoreIndex() { + return scoreIndex; + } + + /** + * Returns whether the method, the {@link Parameters} was created for, contains a {@link Range} of {@link Score} + * argument. + * + * @return + * @since 4.0 + */ + public boolean hasScoreRangeParameter() { + return scoreRangeIndex != -1; + } + + /** + * Returns the index of the argument that contains a {@link Range} of {@link Score}. + * + * @return the argument index or {@literal -1} if none defined. + * @since 4.0 + */ + public int getScoreRangeIndex() { + return scoreRangeIndex; + } + /** * Returns whether the method the {@link Parameters} was created for contains a {@link ScrollPosition} argument. * diff --git a/src/main/java/org/springframework/data/repository/query/ParametersParameterAccessor.java b/src/main/java/org/springframework/data/repository/query/ParametersParameterAccessor.java index 49815e4ca0..9acaa80a46 100644 --- a/src/main/java/org/springframework/data/repository/query/ParametersParameterAccessor.java +++ b/src/main/java/org/springframework/data/repository/query/ParametersParameterAccessor.java @@ -22,8 +22,11 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters; import org.springframework.util.Assert; @@ -95,6 +98,36 @@ private static boolean requiresUnwrapping(@Nullable Object[] values) { return this.values; } + @Override + public @Nullable Vector getVector() { + + if (parameters.getVectorIndex() == -1) { + return null; + } + + return (Vector) values[parameters.getVectorIndex()]; + } + + @Override + public @Nullable Score getScore() { + + if (!parameters.hasScoreParameter()) { + return null; + } + + return (Score) values[parameters.getScoreIndex()]; + } + + @Override + public @Nullable Range getScoreRange() { + + if (!parameters.hasScoreRangeParameter()) { + return null; + } + + return (Range) values[parameters.getScoreRangeIndex()]; + } + @Override public @Nullable ScrollPosition getScrollPosition() { diff --git a/src/main/java/org/springframework/data/repository/query/QueryMethod.java b/src/main/java/org/springframework/data/repository/query/QueryMethod.java index a64b7c56a8..25fd18f12a 100644 --- a/src/main/java/org/springframework/data/repository/query/QueryMethod.java +++ b/src/main/java/org/springframework/data/repository/query/QueryMethod.java @@ -26,6 +26,8 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; @@ -41,6 +43,7 @@ import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; /** * Abstraction of a method that is designated to execute a finder query. Enriches the standard {@link Method} interface @@ -280,6 +283,24 @@ public final boolean isPageQuery() { return org.springframework.util.ClassUtils.isAssignable(Page.class, unwrappedReturnType); } + /** + * Returns whether the finder will return a {@link SearchResults} (or collection of {@link SearchResult}) of results. + * + * @return + * @since 4.0 + */ + public boolean isSearchQuery() { + + if (ClassUtils.isAssignable(SearchResults.class, unwrappedReturnType)) { + return true; + } + + TypeInformation returnType = metadata.getReturnType(method); + TypeInformation componentType = returnType.getComponentType(); + + return componentType != null && SearchResult.class.isAssignableFrom(componentType.getType()); + } + /** * Returns whether the query method is a modifying one. * diff --git a/src/main/java/org/springframework/data/repository/query/ResultProcessor.java b/src/main/java/org/springframework/data/repository/query/ResultProcessor.java index 499a5de4b9..c157da21d7 100644 --- a/src/main/java/org/springframework/data/repository/query/ResultProcessor.java +++ b/src/main/java/org/springframework/data/repository/query/ResultProcessor.java @@ -28,6 +28,7 @@ import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Window; import org.springframework.data.projection.ProjectionFactory; @@ -154,6 +155,10 @@ public ReturnedType getReturnedType() { return (T) ((Slice) source).map(converter::convert); } + if (source instanceof SearchResults results && method.isSearchQuery()) { + return (T) results.map(converter::convert); + } + if (source instanceof Collection collection && method.isCollectionQuery()) { Collection target = createCollectionFor(collection); diff --git a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java index a10d45c078..182c05cfa8 100644 --- a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java +++ b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java @@ -38,6 +38,8 @@ import org.springframework.core.convert.support.ConfigurableConversionService; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.domain.Page; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Window; import org.springframework.data.geo.GeoResults; @@ -98,6 +100,7 @@ public abstract class QueryExecutionConverters { ALLOWED_PAGEABLE_TYPES.add(Page.class); ALLOWED_PAGEABLE_TYPES.add(List.class); ALLOWED_PAGEABLE_TYPES.add(Window.class); + ALLOWED_PAGEABLE_TYPES.add(SearchResults.class); WRAPPER_TYPES.add(NullableWrapperToCompletableFutureConverter.getWrapperType()); @@ -253,6 +256,8 @@ public static TypeInformation unwrapWrapperTypes(TypeInformation type, Typ boolean needToUnwrap = type.isCollectionLike() // || Slice.class.isAssignableFrom(rawType) // || GeoResults.class.isAssignableFrom(rawType) // + || SearchResult.class.isAssignableFrom(rawType) // + || SearchResults.class.isAssignableFrom(rawType) // || rawType.isArray() // || supports(rawType) // || Stream.class.isAssignableFrom(rawType); diff --git a/src/test/java/org/springframework/data/domain/SearchResultUnitTests.java b/src/test/java/org/springframework/data/domain/SearchResultUnitTests.java new file mode 100755 index 0000000000..8a8f6b334d --- /dev/null +++ b/src/test/java/org/springframework/data/domain/SearchResultUnitTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2011-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +import org.springframework.util.SerializationUtils; + +/** + * Unit tests for {@link SearchResult}. + * + * @author Mark Paluch + */ +class SearchResultUnitTests { + + SearchResult first = new SearchResult<>("Foo", Score.of(2.5)); + SearchResult second = new SearchResult<>("Foo", Score.of(2.5)); + SearchResult third = new SearchResult<>("Bar", Score.of(2.5)); + SearchResult fourth = new SearchResult<>("Foo", Score.of(5.2)); + + @Test // GH- + void considersSameInstanceEqual() { + assertThat(first.equals(first)).isTrue(); + } + + @Test // GH- + void considersSameValuesAsEqual() { + + assertThat(first.equals(second)).isTrue(); + assertThat(second.equals(first)).isTrue(); + assertThat(first.equals(third)).isFalse(); + assertThat(third.equals(first)).isFalse(); + assertThat(first.equals(fourth)).isFalse(); + assertThat(fourth.equals(first)).isFalse(); + } + + @Test + @SuppressWarnings({ "rawtypes", "unchecked" }) + // GH- + void rejectsNullContent() { + assertThatIllegalArgumentException().isThrownBy(() -> new SearchResult(null, Score.of(2.5))); + } + + @Test // GH- + @SuppressWarnings("unchecked") + void testSerialization() { + + var result = new SearchResult<>("test", Score.of(2d)); + + var serialized = (SearchResult) SerializationUtils.deserialize(SerializationUtils.serialize(result)); + assertThat(serialized).isEqualTo(result); + } + +} diff --git a/src/test/java/org/springframework/data/domain/SearchResultsUnitTests.java b/src/test/java/org/springframework/data/domain/SearchResultsUnitTests.java new file mode 100755 index 0000000000..c368d760ab --- /dev/null +++ b/src/test/java/org/springframework/data/domain/SearchResultsUnitTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.util.SerializationUtils; + +/** + * Unit tests for {@link SearchResults}. + * + * @author Mark Paluch + */ +class SearchResultsUnitTests { + + @SuppressWarnings("unchecked") + @Test // GH- + void testSerialization() { + + var result = new SearchResult<>("test", Score.of(2)); + var searchResults = new SearchResults<>(Collections.singletonList(result)); + + var serialized = (SearchResults) SerializationUtils + .deserialize(SerializationUtils.serialize(searchResults)); + assertThat(serialized).isEqualTo(searchResults); + } + + @SuppressWarnings("unchecked") + @Test // GH- + void testStream() { + + var result = new SearchResult<>("test", Score.of(2)); + var searchResults = new SearchResults<>(Collections.singletonList(result)); + + List> list = searchResults.stream().toList(); + assertThat(list).isEqualTo(searchResults.getContent()); + } + + @SuppressWarnings("unchecked") + @Test // GH- + void testContentStream() { + + var result = new SearchResult<>("test", Score.of(2)); + var searchResults = new SearchResults<>(Collections.singletonList(result)); + + List list = searchResults.contentStream().toList(); + assertThat(list).isEqualTo(Arrays.asList(result.getContent())); + } + +} diff --git a/src/test/java/org/springframework/data/domain/SimilarityUnitTests.java b/src/test/java/org/springframework/data/domain/SimilarityUnitTests.java new file mode 100644 index 0000000000..5d8bffabeb --- /dev/null +++ b/src/test/java/org/springframework/data/domain/SimilarityUnitTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.domain; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link Similarity}. + * + * @author Mark Paluch + */ +class SimilarityUnitTests { + + @Test + void shouldBeBounded() { + + assertThatIllegalArgumentException().isThrownBy(() -> Similarity.of(-1)); + assertThatIllegalArgumentException().isThrownBy(() -> Similarity.of(1.01)); + } + + @Test + void shouldConstructRawSimilarity() { + + Similarity similarity = Similarity.raw(2, ScoringFunction.unspecified()); + + assertThat(similarity.getValue()).isEqualTo(2); + } + + @Test + void shouldConstructGenericSimilarity() { + + Similarity similarity = Similarity.of(1); + + assertThat(similarity).isEqualTo(Similarity.of(1)).isNotEqualTo(Score.of(1)).isNotEqualTo(Similarity.of(0.5)); + assertThat(similarity).hasToString("1.0"); + assertThat(similarity.getFunction()).isEqualTo(ScoringFunction.unspecified()); + } + + @Test + void shouldConstructMeteredSimilarity() { + + Similarity similarity = Similarity.of(1, VectorScoringFunctions.COSINE); + + assertThat(similarity).isEqualTo(Similarity.of(1, VectorScoringFunctions.COSINE)) + .isNotEqualTo(Score.of(1, VectorScoringFunctions.COSINE)).isNotEqualTo(Similarity.of(1)); + assertThat(similarity).hasToString("1.0 (COSINE)"); + assertThat(similarity.getFunction()).isEqualTo(VectorScoringFunctions.COSINE); + } + + @Test + void shouldConstructRange() { + + Range range = Similarity.between(0.5, 1); + + assertThat(range.getLowerBound().getValue()).contains(Similarity.of(0.5)); + assertThat(range.getLowerBound().isInclusive()).isTrue(); + + assertThat(range.getUpperBound().getValue()).contains(Similarity.of(1)); + assertThat(range.getUpperBound().isInclusive()).isTrue(); + } + + @Test + void shouldConstructRangeWithFunction() { + + Range range = Similarity.between(0.5, 1, VectorScoringFunctions.COSINE); + + assertThat(range.getLowerBound().getValue()).contains(Similarity.of(0.5, VectorScoringFunctions.COSINE)); + assertThat(range.getLowerBound().isInclusive()).isTrue(); + + assertThat(range.getUpperBound().getValue()).contains(Similarity.of(1, VectorScoringFunctions.COSINE)); + assertThat(range.getUpperBound().isInclusive()).isTrue(); + } + +} diff --git a/src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java b/src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java index be4a74b8ee..3ed08c275c 100755 --- a/src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java @@ -31,6 +31,9 @@ import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; import org.springframework.data.repository.Repository; @@ -230,6 +233,22 @@ void considersGenericType() throws Exception { assertThat(parameters.getParameter(0).getType()).isEqualTo(Long.class); } + @Test // GH- + void considersScoreRange() throws Exception { + + var parameters = getParametersFor("methodWithScoreRange", Range.class); + + assertThat(parameters.hasScoreRangeParameter()).isTrue(); + } + + @Test // GH- + void considersSimilarityRange() throws Exception { + + var parameters = getParametersFor("methodWithSimilarityRange", Range.class); + + assertThat(parameters.hasScoreRangeParameter()).isTrue(); + } + private Parameters getParametersFor(String methodName, Class... parameterTypes) throws SecurityException, NoSuchMethodException { @@ -268,6 +287,10 @@ interface SampleDao extends Repository { void methodWithSingle(Single single); + void methodWithScoreRange(Range single); + + void methodWithSimilarityRange(Range single); + Page customPageable(SomePageable pageable); Window customScrollPosition(OffsetScrollPosition request); diff --git a/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java b/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java index 56f6b69bb5..ccbc48bd9d 100755 --- a/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java @@ -34,14 +34,16 @@ import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestFactory; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; -import org.springframework.data.domain.Window; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.Repository; @@ -388,6 +390,24 @@ Stream doesNotConsiderQueryMethodReturningAggregateImplementingStre }); } + @Test // GH- + void considersSearchResults() throws NoSuchMethodException { + + var method = SampleRepository.class.getMethod("searchTop5By"); + QueryMethod queryMethod = new QueryMethod(method, metadata, factory); + + assertThat(queryMethod.isSearchQuery()).isTrue(); + } + + @Test // GH- + void considersSearchResult() throws NoSuchMethodException { + + var method = SampleRepository.class.getMethod("searchListTop5By"); + QueryMethod queryMethod = new QueryMethod(method, metadata, factory); + + assertThat(queryMethod.isSearchQuery()).isTrue(); + } + interface SampleRepository extends Repository { String pagingMethodWithInvalidReturnType(Pageable pageable); @@ -460,6 +480,10 @@ interface SampleRepository extends Repository { List findTop5By(Limit limit); List findTop5By(Pageable page); + + SearchResults searchTop5By(); + + List> searchListTop5By(); } class User { diff --git a/src/test/java/org/springframework/data/repository/query/SimpleParameterAccessorUnitTests.java b/src/test/java/org/springframework/data/repository/query/SimpleParameterAccessorUnitTests.java index aec5ed7d4c..3f6c4fc41f 100755 --- a/src/test/java/org/springframework/data/repository/query/SimpleParameterAccessorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/SimpleParameterAccessorUnitTests.java @@ -21,7 +21,10 @@ import org.junit.jupiter.api.Test; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Sort; /** @@ -32,7 +35,7 @@ */ class SimpleParameterAccessorUnitTests { - Parameters parameters, cursorRequestParameters, sortParameters, pageableParameters; + Parameters parameters, cursorRequestParameters, sortParameters, pageableParameters, scoreParameters; @BeforeEach void setUp() throws SecurityException, NoSuchMethodException { @@ -44,6 +47,9 @@ void setUp() throws SecurityException, NoSuchMethodException { ParametersSource.of(Sample.class.getMethod("sample1", String.class, Sort.class))); pageableParameters = new DefaultParameters( ParametersSource.of(Sample.class.getMethod("sample2", String.class, Pageable.class))); + + scoreParameters = new DefaultParameters( + ParametersSource.of(Sample.class.getMethod("sample", String.class, Score.class, Range.class))); } @Test @@ -122,12 +128,32 @@ void returnsSortFromPageableIfAvailable() throws Exception { assertThat(accessor.getSort()).isEqualTo(sort); } + @Test + void returnsScoreIfAvailable() { + + Score score = Score.of(1); + ParameterAccessor accessor = new ParametersParameterAccessor(scoreParameters, new Object[] { "test", score, null }); + + assertThat(accessor.getScore()).isEqualTo(score); + } + + @Test + void returnsScoreRangeIfAvailable() { + + Range range = Similarity.between(0, 1); + ParameterAccessor accessor = new ParametersParameterAccessor(scoreParameters, new Object[] { "test", null, range }); + + assertThat(accessor.getScoreRange()).isEqualTo(range); + } + interface Sample { void sample(String firstname); void sample(ScrollPosition scrollPosition); + void sample(String firstname, Score score, Range range); + void sample1(String firstname, Sort sort); void sample2(String firstname, Pageable pageable);