Skip to content

Commit c327010

Browse files
Polishing.
Original Pull Request: #3868
1 parent a320669 commit c327010

File tree

10 files changed

+71
-32
lines changed

10 files changed

+71
-32
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/EmptyIntrospectedQuery.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ enum EmptyIntrospectedQuery implements EntityQuery {
3434

3535
EmptyIntrospectedQuery() {}
3636

37-
38-
3937
@Override
4038
public boolean hasParameterBindings() {
4139
return false;
@@ -61,6 +59,7 @@ public List<ParameterBinding> getParameterBindings() {
6159
}
6260

6361
@Override
62+
@SuppressWarnings("NullAway")
6463
public <T> T doWithEnhancer(Function<QueryEnhancer, T> function) {
6564
return null;
6665
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
*/
1616
package org.springframework.data.jpa.repository.query;
1717

18-
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*;
19-
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
18+
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount;
19+
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower;
20+
import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression;
2021

2122
import net.sf.jsqlparser.expression.Alias;
2223
import net.sf.jsqlparser.expression.Expression;
@@ -52,7 +53,6 @@
5253
import java.util.function.Supplier;
5354

5455
import org.jspecify.annotations.Nullable;
55-
5656
import org.springframework.data.domain.Sort;
5757
import org.springframework.data.util.Predicates;
5858
import org.springframework.util.Assert;
@@ -356,6 +356,8 @@ private String doApplySorting(Sort sort, @Nullable String alias) {
356356

357357
private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullable String alias) {
358358

359+
Assert.notNull(selectStatement, "SelectStatement must not be null");
360+
359361
if (selectStatement instanceof SetOperationList setOperationList) {
360362
return applySortingToSetOperationList(setOperationList, sort);
361363
}
@@ -381,6 +383,7 @@ private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullab
381383
}
382384

383385
@Override
386+
@SuppressWarnings("NullAway")
384387
public String createCountQueryFor(@Nullable String countProjection) {
385388

386389
if (this.parsedType != ParsedType.SELECT) {

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java

+1
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ private long count(Query resultQuery, AbstractJpaQuery repositoryQuery, JpaParam
295295
return provider.getResultCount(resultQuery, () -> doCount(repositoryQuery, accessor));
296296
}
297297

298+
@SuppressWarnings("NullAway")
298299
long doCount(AbstractJpaQuery repositoryQuery, JpaParametersParameterAccessor accessor) {
299300

300301
List<?> totals = repositoryQuery.createCountQuery(accessor).getResultList();

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java

-4
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,6 @@ public boolean isIsNullParameter() {
323323

324324
return Collections.singleton(value);
325325
}
326-
327-
public String lower() {
328-
return null;
329-
}
330326
}
331327

332328
/**

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private ParameterMetadataProvider(@Nullable Iterator<Object> bindableParameterVa
120120
this.templates = templates;
121121
}
122122

123-
public JpaParameters getParameters() {
123+
JpaParameters getParameters() {
124124
return this.jpaParameters;
125125
}
126126

@@ -216,6 +216,10 @@ private <T> PartTreeParameterBinding next(Part part, Class<T> type, Parameter pa
216216
return binding;
217217
}
218218

219+
/**
220+
* @return the scoring function if available {@link ScoringFunction#unspecified()} by default.
221+
* @since 4.0
222+
*/
219223
ScoringFunction getScoringFunction() {
220224

221225
if (accessor != null) {
@@ -225,6 +229,12 @@ ScoringFunction getScoringFunction() {
225229
return ScoringFunction.unspecified();
226230
}
227231

232+
/**
233+
*
234+
* @return the vector binding identifier.
235+
* @throws IllegalStateException if parameters do not cotain
236+
* @since 4.0
237+
*/
228238
ParameterBinding getVectorBinding() {
229239

230240
if (!getParameters().hasVectorParameter()) {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java

+47-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
*/
1616
package org.springframework.data.jpa.repository;
1717

18-
import static org.assertj.core.api.Assertions.*;
18+
import static org.assertj.core.api.Assertions.assertThat;
1919

2020
import jakarta.persistence.Column;
2121
import jakarta.persistence.Entity;
@@ -36,7 +36,6 @@
3636
import org.junit.jupiter.api.Test;
3737
import org.junit.jupiter.params.ParameterizedTest;
3838
import org.junit.jupiter.params.provider.MethodSource;
39-
4039
import org.springframework.beans.factory.annotation.Autowired;
4140
import org.springframework.data.domain.Range;
4241
import org.springframework.data.domain.Score;
@@ -53,6 +52,7 @@
5352
* Testcase to verify Vector Search work with Hibernate.
5453
*
5554
* @author Mark Paluch
55+
* @author Christoph Strobl
5656
*/
5757
@Transactional
5858
@Rollback(value = false)
@@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests {
6565
@BeforeEach
6666
void setUp() {
6767

68-
WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
69-
WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
70-
WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
71-
WithVector w4 = new WithVector("de", "four", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
68+
WithVector w1 = new WithVector("de", "one", "d1", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
69+
WithVector w2 = new WithVector("de", "two", "d2", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
70+
WithVector w3 = new WithVector("en", "three", "d3",
71+
new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
72+
WithVector w4 = new WithVector("de", "four", "d4", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
7273

7374
repository.deleteAllInBatch();
7475
repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4));
@@ -93,7 +94,7 @@ static Set<VectorScoringFunctions> scoringFunctions() {
9394
VectorScoringFunctions.EUCLIDEAN);
9495
}
9596

96-
@Test
97+
@Test // GH-3868
9798
void shouldNormalizeEuclideanSimilarity() {
9899

99100
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
@@ -108,7 +109,16 @@ void shouldNormalizeEuclideanSimilarity() {
108109
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
109110
}
110111

111-
@Test
112+
@Test // GH-3868
113+
void orderTargetsProperty() {
114+
115+
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithinOrderByDistance("de", VECTOR,
116+
Similarity.of(0, VectorScoringFunctions.EUCLIDEAN));
117+
118+
assertThat(results.getContent()).extracting(it -> it.getContent().getDistance()).containsExactly("d1", "d2", "d4");
119+
}
120+
121+
@Test// GH-3868
112122
void shouldNormalizeCosineSimilarity() {
113123

114124
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
@@ -123,7 +133,7 @@ void shouldNormalizeCosineSimilarity() {
123133
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
124134
}
125135

126-
@Test
136+
@Test // GH-3868
127137
void shouldRunStringQuery() {
128138

129139
List<WithVector> results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
@@ -133,7 +143,7 @@ void shouldRunStringQuery() {
133143
assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four");
134144
}
135145

136-
@Test
146+
@Test // GH-3868
137147
void shouldRunStringQueryWithDistance() {
138148

139149
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
@@ -149,7 +159,7 @@ void shouldRunStringQueryWithDistance() {
149159
assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE);
150160
}
151161

152-
@Test
162+
@Test // GH-3868
153163
void shouldRunStringQueryWithFloatDistance() {
154164

155165
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2);
@@ -164,7 +174,7 @@ void shouldRunStringQueryWithFloatDistance() {
164174
assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified());
165175
}
166176

167-
@Test
177+
@Test // GH-3868
168178
void shouldApplyVectorSearchWithRange() {
169179

170180
SearchResults<WithVector> results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR,
@@ -176,7 +186,7 @@ void shouldApplyVectorSearchWithRange() {
176186
.containsSequence("two", "one", "four");
177187
}
178188

179-
@Test
189+
@Test // GH-3868
180190
void shouldApplyVectorSearchAndReturnList() {
181191

182192
List<WithVector> results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR,
@@ -186,7 +196,7 @@ void shouldApplyVectorSearchAndReturnList() {
186196
assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four");
187197
}
188198

189-
@Test
199+
@Test // GH-3868
190200
void shouldProjectVectorSearchAsInterface() {
191201

192202
SearchResults<WithDescription> results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de",
@@ -196,7 +206,7 @@ void shouldProjectVectorSearchAsInterface() {
196206
.containsSequence("two", "one", "four");
197207
}
198208

199-
@Test
209+
@Test // GH-3868
200210
void shouldProjectVectorSearchAsDto() {
201211

202212
SearchResults<DescriptionDto> results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR,
@@ -206,7 +216,7 @@ void shouldProjectVectorSearchAsDto() {
206216
.containsSequence("two", "one", "four");
207217
}
208218

209-
@Test
219+
@Test // GH-3868
210220
void shouldProjectVectorSearchDynamically() {
211221

212222
SearchResults<DescriptionDto> dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR,
@@ -233,16 +243,19 @@ public static class WithVector {
233243
private String country;
234244
private String description;
235245

246+
private String distance;
247+
236248
@Column(name = "the_embedding")
237249
@JdbcTypeCode(SqlTypes.VECTOR)
238250
@Array(length = 5) private float[] embedding;
239251

240252
public WithVector() {}
241253

242-
public WithVector(String country, String description, float[] embedding) {
254+
public WithVector(String country, String description, String distance, float[] embedding) {
243255
this.country = country;
244256
this.description = description;
245257
this.embedding = embedding;
258+
this.distance = distance;
246259
}
247260

248261
public Integer getId() {
@@ -273,9 +286,22 @@ public void setEmbedding(float[] embedding) {
273286
this.embedding = embedding;
274287
}
275288

289+
public void setDescription(String description) {
290+
this.description = description;
291+
}
292+
293+
public String getDistance() {
294+
return distance;
295+
}
296+
297+
public void setDistance(String distance) {
298+
this.distance = distance;
299+
}
300+
276301
@Override
277302
public String toString() {
278-
return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}';
303+
return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\''
304+
+ ", distance='" + distance + '\'' + ", embedding=" + Arrays.toString(embedding) + '}';
279305
}
280306
}
281307

@@ -328,6 +354,9 @@ SearchResults<WithVector> searchAllByCountryAndEmbeddingWithin(String country, V
328354

329355
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
330356

357+
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithinOrderByDistance(String country, Vector embedding,
358+
Score distance);
359+
331360
SearchResults<WithDescription> searchInterfaceProjectionByCountryAndEmbeddingWithin(String country,
332361
Vector embedding, Score distance);
333362

spring-data-jpa/src/test/resources/scripts/oracle-vector.sql

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS with_vector
55
id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY,
66
country varchar2(10),
77
description varchar2(10),
8+
distance varchar2(10),
89
the_embedding vector(5, FLOAT32) annotations(Distance 'COSINE', IndexType 'IVF')
910
);;
1011

spring-data-jpa/src/test/resources/scripts/pgvector.sql

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ CREATE EXTENSION IF NOT EXISTS vector;
22

33
DROP TABLE IF EXISTS with_vector;
44

5-
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10),the_embedding vector(5));
5+
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10), distance varchar(10), the_embedding vector(5));
66

77
CREATE INDEX ON with_vector USING hnsw (the_embedding vector_l2_ops);

src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ interface CommentRepository extends Repository<Comment, String> {
1111
WHERE c.country = ?1
1212
AND cosine_distance(c.embedding, :embedding) <= :distance
1313
ORDER BY distance asc""")
14-
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
14+
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
1515
Score distance);
1616
1717
@Query("""
1818
SELECT c FROM Comment c
1919
WHERE c.country = ?1
2020
AND cosine_distance(c.embedding, :embedding) <= :distance
2121
ORDER BY cosine_distance(c.embedding, :embedding) asc""")
22-
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
22+
List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
2323
}
2424
----
2525
====

src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ interface CommentRepository extends Repository<Comment, String> {
1212
WHERE c.country = ?1
1313
AND cosine_distance(c.embedding, :embedding) <= :distance
1414
ORDER BY distance asc""")
15-
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
15+
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
1616
Score distance);
1717
}
1818

0 commit comments

Comments
 (0)