15
15
*/
16
16
package org .springframework .data .jpa .repository ;
17
17
18
- import static org .assertj .core .api .Assertions .* ;
18
+ import static org .assertj .core .api .Assertions .assertThat ;
19
19
20
20
import jakarta .persistence .Column ;
21
21
import jakarta .persistence .Entity ;
36
36
import org .junit .jupiter .api .Test ;
37
37
import org .junit .jupiter .params .ParameterizedTest ;
38
38
import org .junit .jupiter .params .provider .MethodSource ;
39
-
40
39
import org .springframework .beans .factory .annotation .Autowired ;
41
40
import org .springframework .data .domain .Range ;
42
41
import org .springframework .data .domain .Score ;
53
52
* Testcase to verify Vector Search work with Hibernate.
54
53
*
55
54
* @author Mark Paluch
55
+ * @author Christoph Strobl
56
56
*/
57
57
@ Transactional
58
58
@ Rollback (value = false )
@@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests {
65
65
@ BeforeEach
66
66
void setUp () {
67
67
68
- WithVector w1 = new WithVector ("de" , "one" , new float [] { 0.1001f , 0.22345f , 0.33456f , 0.44567f , 0.55678f });
69
- WithVector w2 = new WithVector ("de" , "two" , new float [] { 0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f });
70
- WithVector w3 = new WithVector ("en" , "three" , new float [] { 0.9001f , 0.82345f , 0.73456f , 0.64567f , 0.55678f });
71
- WithVector w4 = new WithVector ("de" , "four" , new float [] { 0.9001f , 0.92345f , 0.93456f , 0.94567f , 0.95678f });
68
+ WithVector w1 = new WithVector ("de" , "one" , "d1" , new float [] { 0.1001f , 0.22345f , 0.33456f , 0.44567f , 0.55678f });
69
+ WithVector w2 = new WithVector ("de" , "two" , "d2" , new float [] { 0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f });
70
+ WithVector w3 = new WithVector ("en" , "three" , "d3" ,
71
+ new float [] { 0.9001f , 0.82345f , 0.73456f , 0.64567f , 0.55678f });
72
+ WithVector w4 = new WithVector ("de" , "four" , "d4" , new float [] { 0.9001f , 0.92345f , 0.93456f , 0.94567f , 0.95678f });
72
73
73
74
repository .deleteAllInBatch ();
74
75
repository .saveAllAndFlush (Arrays .asList (w1 , w2 , w3 , w4 ));
@@ -93,7 +94,7 @@ static Set<VectorScoringFunctions> scoringFunctions() {
93
94
VectorScoringFunctions .EUCLIDEAN );
94
95
}
95
96
96
- @ Test
97
+ @ Test // GH-3868
97
98
void shouldNormalizeEuclideanSimilarity () {
98
99
99
100
SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -108,7 +109,16 @@ void shouldNormalizeEuclideanSimilarity() {
108
109
assertThat (two .getScore ().getValue ()).isGreaterThan (0.99 );
109
110
}
110
111
111
- @ Test
112
+ @ Test // GH-3868
113
+ void orderTargetsProperty () {
114
+
115
+ SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithinOrderByDistance ("de" , VECTOR ,
116
+ Similarity .of (0 , VectorScoringFunctions .EUCLIDEAN ));
117
+
118
+ assertThat (results .getContent ()).extracting (it -> it .getContent ().getDistance ()).containsExactly ("d1" , "d2" , "d4" );
119
+ }
120
+
121
+ @ Test // GH-3868
112
122
void shouldNormalizeCosineSimilarity () {
113
123
114
124
SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -123,7 +133,7 @@ void shouldNormalizeCosineSimilarity() {
123
133
assertThat (two .getScore ().getValue ()).isGreaterThan (0.99 );
124
134
}
125
135
126
- @ Test
136
+ @ Test // GH-3868
127
137
void shouldRunStringQuery () {
128
138
129
139
List <WithVector > results = repository .findAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -133,7 +143,7 @@ void shouldRunStringQuery() {
133
143
assertThat (results ).extracting (WithVector ::getDescription ).containsSequence ("two" , "one" , "four" );
134
144
}
135
145
136
- @ Test
146
+ @ Test // GH-3868
137
147
void shouldRunStringQueryWithDistance () {
138
148
139
149
SearchResults <WithVector > results = repository .searchAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -149,7 +159,7 @@ void shouldRunStringQueryWithDistance() {
149
159
assertThat (result .getScore ().getFunction ()).isEqualTo (VectorScoringFunctions .COSINE );
150
160
}
151
161
152
- @ Test
162
+ @ Test // GH-3868
153
163
void shouldRunStringQueryWithFloatDistance () {
154
164
155
165
SearchResults <WithVector > results = repository .searchAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR , 2 );
@@ -164,7 +174,7 @@ void shouldRunStringQueryWithFloatDistance() {
164
174
assertThat (result .getScore ().getFunction ()).isEqualTo (ScoringFunction .unspecified ());
165
175
}
166
176
167
- @ Test
177
+ @ Test // GH-3868
168
178
void shouldApplyVectorSearchWithRange () {
169
179
170
180
SearchResults <WithVector > results = repository .searchAllByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -176,7 +186,7 @@ void shouldApplyVectorSearchWithRange() {
176
186
.containsSequence ("two" , "one" , "four" );
177
187
}
178
188
179
- @ Test
189
+ @ Test // GH-3868
180
190
void shouldApplyVectorSearchAndReturnList () {
181
191
182
192
List <WithVector > results = repository .findAllByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -186,7 +196,7 @@ void shouldApplyVectorSearchAndReturnList() {
186
196
assertThat (results ).extracting (WithVector ::getDescription ).containsSequence ("one" , "two" , "four" );
187
197
}
188
198
189
- @ Test
199
+ @ Test // GH-3868
190
200
void shouldProjectVectorSearchAsInterface () {
191
201
192
202
SearchResults <WithDescription > results = repository .searchInterfaceProjectionByCountryAndEmbeddingWithin ("de" ,
@@ -196,7 +206,7 @@ void shouldProjectVectorSearchAsInterface() {
196
206
.containsSequence ("two" , "one" , "four" );
197
207
}
198
208
199
- @ Test
209
+ @ Test // GH-3868
200
210
void shouldProjectVectorSearchAsDto () {
201
211
202
212
SearchResults <DescriptionDto > results = repository .searchDtoByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -206,7 +216,7 @@ void shouldProjectVectorSearchAsDto() {
206
216
.containsSequence ("two" , "one" , "four" );
207
217
}
208
218
209
- @ Test
219
+ @ Test // GH-3868
210
220
void shouldProjectVectorSearchDynamically () {
211
221
212
222
SearchResults <DescriptionDto > dtos = repository .searchDynamicByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -233,16 +243,19 @@ public static class WithVector {
233
243
private String country ;
234
244
private String description ;
235
245
246
+ private String distance ;
247
+
236
248
@ Column (name = "the_embedding" )
237
249
@ JdbcTypeCode (SqlTypes .VECTOR )
238
250
@ Array (length = 5 ) private float [] embedding ;
239
251
240
252
public WithVector () {}
241
253
242
- public WithVector (String country , String description , float [] embedding ) {
254
+ public WithVector (String country , String description , String distance , float [] embedding ) {
243
255
this .country = country ;
244
256
this .description = description ;
245
257
this .embedding = embedding ;
258
+ this .distance = distance ;
246
259
}
247
260
248
261
public Integer getId () {
@@ -273,9 +286,22 @@ public void setEmbedding(float[] embedding) {
273
286
this .embedding = embedding ;
274
287
}
275
288
289
+ public void setDescription (String description ) {
290
+ this .description = description ;
291
+ }
292
+
293
+ public String getDistance () {
294
+ return distance ;
295
+ }
296
+
297
+ public void setDistance (String distance ) {
298
+ this .distance = distance ;
299
+ }
300
+
276
301
@ Override
277
302
public String toString () {
278
- return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}' ;
303
+ return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\''
304
+ + ", distance='" + distance + '\'' + ", embedding=" + Arrays .toString (embedding ) + '}' ;
279
305
}
280
306
}
281
307
@@ -328,6 +354,9 @@ SearchResults<WithVector> searchAllByCountryAndEmbeddingWithin(String country, V
328
354
329
355
SearchResults <WithVector > searchTop5ByCountryAndEmbeddingWithin (String country , Vector embedding , Score distance );
330
356
357
+ SearchResults <WithVector > searchTop5ByCountryAndEmbeddingWithinOrderByDistance (String country , Vector embedding ,
358
+ Score distance );
359
+
331
360
SearchResults <WithDescription > searchInterfaceProjectionByCountryAndEmbeddingWithin (String country ,
332
361
Vector embedding , Score distance );
333
362
0 commit comments