From c6500af6c15dc64da52d9feb253b5059fb613d75 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 15 Apr 2025 15:28:30 +0200 Subject: [PATCH 1/9] Prepare issue branch. --- pom.xml | 4 ++-- spring-data-cassandra-distribution/pom.xml | 2 +- spring-data-cassandra/pom.xml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 9d2f70be2..7e3ecaca0 100644 --- a/pom.xml +++ b/pom.xml @@ -11,7 +11,7 @@ org.springframework.data spring-data-cassandra-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT pom Spring Data for Apache Cassandra @@ -97,7 +97,7 @@ 0.5.4 1.01 multi - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-RESULT-SNAPSHOT diff --git a/spring-data-cassandra-distribution/pom.xml b/spring-data-cassandra-distribution/pom.xml index cf545591f..3458d30bc 100644 --- a/spring-data-cassandra-distribution/pom.xml +++ b/spring-data-cassandra-distribution/pom.xml @@ -8,7 +8,7 @@ org.springframework.data spring-data-cassandra-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-cassandra/pom.xml b/spring-data-cassandra/pom.xml index 7f28c7f2b..ffb830e1a 100644 --- a/spring-data-cassandra/pom.xml +++ b/spring-data-cassandra/pom.xml @@ -8,7 +8,7 @@ org.springframework.data spring-data-cassandra-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml From 21a48f49c7d1a0168d3e550581eeb37d76466e71 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 15 Apr 2025 15:31:42 +0200 Subject: [PATCH 2/9] Cleanup. --- pom.xml | 11 ----------- spring-data-cassandra/pom.xml | 13 ++++++++----- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/pom.xml b/pom.xml index 7e3ecaca0..99cdb55c8 100644 --- a/pom.xml +++ b/pom.xml @@ -92,11 +92,7 @@ 5.0.3 4.19.0 spring-data-cassandra - 1.0 - - 0.5.4 1.01 - multi 4.0.0-SEARCH-RESULT-SNAPSHOT @@ -158,13 +154,6 @@ test - - com.carrotsearch - hppc - ${hppc.version} - test - - edu.umd.cs.mtc multithreadedtc diff --git a/spring-data-cassandra/pom.xml b/spring-data-cassandra/pom.xml index ffb830e1a..127b50246 100644 --- a/spring-data-cassandra/pom.xml +++ b/spring-data-cassandra/pom.xml @@ -167,6 +167,14 @@ javax.inject javax.inject + + org.perfkit.sjk.parsers + * + + + com.jrockit.mc + * + @@ -198,11 +206,6 @@ multithreadedtc - - com.carrotsearch - hppc - - org.jetbrains.kotlin From 845453a4fa335eecb805d4fff08813b6ffed21e8 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 16 Apr 2025 10:03:08 +0200 Subject: [PATCH 3/9] Explore returning Search Results. --- .../cassandra/core/CassandraOperations.java | 45 ++-- .../cassandra/core/CassandraTemplate.java | 16 +- .../data/cassandra/core/StatementFactory.java | 5 +- .../data/cassandra/core/query/ColumnName.java | 2 +- .../data/cassandra/core/query/Columns.java | 10 +- .../core/query/SerializationUtils.java | 2 +- .../query/AbstractCassandraQuery.java | 13 +- .../repository/query/BindingContext.java | 62 ++++-- .../query/CassandraParameterAccessor.java | 24 +++ .../repository/query/CassandraParameters.java | 34 ++- .../CassandraParametersParameterAccessor.java | 39 ++++ .../query/CassandraQueryCreator.java | 58 ++++- .../query/CassandraQueryExecution.java | 43 ++++ .../query/ConvertingParameterAccessor.java | 29 +++ .../query/QueryStatementCreator.java | 65 +++++- .../VectorSearchIntegrationTests.java | 202 ++++++++++++++++++ .../query/CassandraParametersUnitTests.java | 14 +- .../query/StubParameterAccessor.java | 29 +++ .../test/util/CassandraDelegate.java | 2 +- 19 files changed, 626 insertions(+), 68 deletions(-) create mode 100644 spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/VectorSearchIntegrationTests.java diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java index 943473ee2..ad4d3289c 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java @@ -17,9 +17,11 @@ import java.util.Iterator; import java.util.List; +import java.util.function.BiFunction; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; + import org.springframework.dao.DataAccessException; import org.springframework.data.cassandra.core.convert.CassandraConverter; import org.springframework.data.cassandra.core.cql.CqlOperations; @@ -33,6 +35,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.Row; import com.datastax.oss.driver.api.core.cql.Statement; /** @@ -92,7 +95,7 @@ default CassandraBatchOperations batchOps() { /** * The table name used for the specified class by this template. * - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the {@link CqlIdentifier} */ CqlIdentifier getTableName(Class entityClass); @@ -105,7 +108,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting items to a {@link List} of entities. * * @param cql must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted results * @throws DataAccessException if there is any problem executing the query. */ @@ -129,7 +132,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting item to an entity. * * @param cql must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted object or {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ @@ -154,18 +157,32 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting items to a {@link List} of entities. * * @param statement must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted results * @throws DataAccessException if there is any problem executing the query. */ List select(Statement statement, Class entityClass) throws DataAccessException; + /** + * Execute a {@code SELECT} query and convert the resulting items to a {@link List} of entities considering the given + * {@link BiFunction mapping function}. + * + * @param statement must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. + * @param mapper mapping function invoked after materializing {@code entityClass} must not be {@literal null}. + * @return the converted results + * @throws DataAccessException if there is any problem executing the query. + * @since 5.0 + */ + List select(Statement statement, Class entityClass, BiFunction mapper) + throws DataAccessException; + /** * Execute a {@code SELECT} query with paging and convert the result set to a {@link Slice} of entities. A sliced * query translates the effective {@link Statement#getFetchSize() fetch size} to the page size. * * @param statement the CQL statement, must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted results * @throws DataAccessException if there is any problem executing the query. * @since 2.0 @@ -190,7 +207,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting item to an entity. * * @param statement must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted object or {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ @@ -204,7 +221,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting items to a {@link List} of entities. * * @param query must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted results * @throws DataAccessException if there is any problem executing the query. * @since 2.0 @@ -215,7 +232,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query with paging and convert the result set to a {@link Slice} of entities. * * @param query the query object used to create a CQL statement, must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted results * @throws DataAccessException if there is any problem executing the query. * @since 2.0 @@ -241,7 +258,7 @@ default CassandraBatchOperations batchOps() { * Execute a {@code SELECT} query and convert the resulting item to an entity. * * @param query must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted object or {@literal null}. * @throws DataAccessException if there is any problem executing the query. * @since 2.0 @@ -253,7 +270,7 @@ default CassandraBatchOperations batchOps() { * * @param query must not be {@literal null}. * @param update must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ boolean update(Query query, Update update, Class entityClass) throws DataAccessException; @@ -262,7 +279,7 @@ default CassandraBatchOperations batchOps() { * Remove entities (rows)/columns from the table by {@link Query}. * * @param query must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ boolean delete(Query query, Class entityClass) throws DataAccessException; @@ -322,7 +339,7 @@ default CassandraBatchOperations batchOps() { * @param id the Id value. For single primary keys it's the plain value. For composite primary keys either the * {@link org.springframework.data.cassandra.core.mapping.PrimaryKeyClass} or * {@link org.springframework.data.cassandra.core.mapping.MapId}. Must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @return the converted object or {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ @@ -407,7 +424,7 @@ default WriteResult delete(Object entity, DeleteOptions options) throws DataAcce * @param id the Id value. For single primary keys it's the plain value. For composite primary keys either the * {@link org.springframework.data.cassandra.core.mapping.PrimaryKeyClass} or * {@link org.springframework.data.cassandra.core.mapping.MapId}. Must not be {@literal null}. - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ boolean deleteById(Object id, Class entityClass) throws DataAccessException; @@ -415,7 +432,7 @@ default WriteResult delete(Object entity, DeleteOptions options) throws DataAcce /** * Execute a {@code TRUNCATE} query to remove all entities of a given class. * - * @param entityClass The entity type must not be {@literal null}. + * @param entityClass the entity type must not be {@literal null}. * @throws DataAccessException if there is any problem executing the query. */ void truncate(Class entityClass) throws DataAccessException; diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java index 9ec3804cb..0b1f614da 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java @@ -16,6 +16,7 @@ package org.springframework.data.cassandra.core; import java.util.List; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -349,14 +350,25 @@ public ResultSet execute(Statement statement) { @Override public List select(Statement statement, Class entityClass) { + return select(statement, entityClass, (t, row) -> t); + } + + @Override + public List select(Statement statement, Class entityClass, BiFunction mapper) + throws DataAccessException { Assert.notNull(statement, "Statement must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); + Assert.notNull(mapper, "Row Mapper function must not be null"); - Function mapper = getMapper(EntityProjection.nonProjecting(entityClass), + Function defaultMapper = getMapper(EntityProjection.nonProjecting(entityClass), EntityQueryUtils.getTableName(statement)); - return doQuery(statement, (row, rowNum) -> mapper.apply(row)); + return doQuery(statement, (row, rowNum) -> { + + S intermediate = defaultMapper.apply(row); + return mapper.apply(intermediate, row); + }); } @Override diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java index 79dbfd6fc..cab05be15 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java @@ -767,12 +767,11 @@ private static com.datastax.oss.driver.api.querybuilder.select.Selector getSelec .stream().map(param -> { if (param instanceof ColumnSelector s) { - - return com.datastax.oss.driver.api.querybuilder.select.Selector.column(s.getExpression()); + return com.datastax.oss.driver.api.querybuilder.select.Selector.column(s.getIdentifier()); } if (param instanceof CqlIdentifier i) { - return com.datastax.oss.driver.api.querybuilder.select.Selector.column(i.toString()); + return com.datastax.oss.driver.api.querybuilder.select.Selector.column(i); } return new SimpleSelector(param.toString()); diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/ColumnName.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/ColumnName.java index 3a3837b31..6fccd44c7 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/ColumnName.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/ColumnName.java @@ -169,7 +169,7 @@ public Optional getCqlIdentifier() { @Override public String toCql() { - return this.cqlIdentifier.toString(); + return this.cqlIdentifier.asInternal(); } @Override diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/Columns.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/Columns.java index a63cc5f28..5fc32110e 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/Columns.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/Columns.java @@ -26,10 +26,11 @@ import java.util.Optional; import java.util.function.Function; +import org.jspecify.annotations.Nullable; + import org.springframework.data.cassandra.core.convert.CassandraVector; import org.springframework.data.cassandra.core.mapping.SimilarityFunction; import org.springframework.data.domain.Vector; -import org.jspecify.annotations.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -340,6 +341,9 @@ default Selector as(String alias) { */ Selector as(CqlIdentifier alias); + /** + * @return the expression that forms this selection. + */ String getExpression(); Optional getAlias(); @@ -410,6 +414,10 @@ public Optional getAlias() { return alias; } + public CqlIdentifier getIdentifier() { + return columnName.getCqlIdentifier().orElseGet(() -> CqlIdentifier.fromCql(columnName.toCql())); + } + @Override public String getExpression() { return columnName.toCql(); diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/SerializationUtils.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/SerializationUtils.java index 7d49a1b79..61373f73e 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/SerializationUtils.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/SerializationUtils.java @@ -57,7 +57,7 @@ private SerializationUtils() {} CriteriaDefinition.Predicate predicate = criteria.getPredicate(); return String.format("%s %s", criteria.getColumnName(), - predicate.getOperator().toCql(serializeToCqlSafely(predicate.getValue()))); + predicate.getOperator().toCql(predicate != null ? serializeToCqlSafely(predicate.getValue()) : "")); } /** diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/AbstractCassandraQuery.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/AbstractCassandraQuery.java index a8d99be73..66adea760 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/AbstractCassandraQuery.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/AbstractCassandraQuery.java @@ -16,17 +16,10 @@ package org.springframework.data.cassandra.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.core.convert.converter.Converter; import org.springframework.data.cassandra.core.CassandraOperations; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.CollectionExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.ExistsExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.ResultProcessingConverter; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.ResultProcessingExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.ResultSetQuery; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.SingleEntityExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.SlicedExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.StreamExecution; -import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.WindowExecution; +import org.springframework.data.cassandra.repository.query.CassandraQueryExecution.*; import org.springframework.data.repository.query.ParameterAccessor; import org.springframework.data.repository.query.RepositoryQuery; import org.springframework.data.repository.query.ResultProcessor; @@ -120,6 +113,8 @@ private CassandraQueryExecution getExecutionToWrap(CassandraParameterAccessor pa return new SlicedExecution(getOperations(), parameterAccessor.getPageable()); } else if (getQueryMethod().isScrollQuery()) { return new WindowExecution(getOperations(), parameterAccessor.getScrollPosition(), parameterAccessor.getLimit()); + } else if (getQueryMethod().isSearchQuery()) { + return new SearchExecution(getOperations(), parameterAccessor); } else if (getQueryMethod().isCollectionQuery()) { return new CollectionExecution(getOperations()); } else if (getQueryMethod().isResultSetQuery()) { diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/BindingContext.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/BindingContext.java index 378795ab7..028f1bd0b 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/BindingContext.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/BindingContext.java @@ -20,9 +20,9 @@ import java.util.List; import org.jspecify.annotations.Nullable; + +import org.springframework.data.domain.Limit; import org.springframework.data.mapping.model.ValueExpressionEvaluator; -import org.springframework.data.repository.query.Parameter; -import org.springframework.data.repository.query.ParameterAccessor; import org.springframework.util.Assert; /** @@ -35,7 +35,7 @@ class BindingContext { private final CassandraParameters parameters; - private final ParameterAccessor parameterAccessor; + private final CassandraParameterAccessor parameterAccessor; private final List bindings; @@ -44,7 +44,7 @@ class BindingContext { /** * Create new {@link BindingContext}. */ - BindingContext(CassandraParameters parameters, ParameterAccessor parameterAccessor, + BindingContext(CassandraParameters parameters, CassandraParameterAccessor parameterAccessor, List bindings, ValueExpressionEvaluator evaluator) { this.parameters = parameters; @@ -75,8 +75,9 @@ public List getBindingValues() { List parameters = new ArrayList<>(bindings.size()); for (ParameterBinding binding : bindings) { + Object parameterValueForBinding = getParameterValueForBinding(binding); - parameters.add(parameterValueForBinding); + parameters.add(binding.prepareValue(parameterValueForBinding)); } return parameters; @@ -95,20 +96,20 @@ public List getBindingValues() { } return binding.isNamed() - ? parameterAccessor.getBindableValue(getParameterIndex(parameters, binding.getRequiredParameterName())) + ? parameterAccessor.getValue(getParameterIndex(parameters, binding.getRequiredParameterName())) : parameterAccessor.getBindableValue(binding.getParameterIndex()); } private int getParameterIndex(CassandraParameters parameters, String parameterName) { - return parameters.stream() // - .filter(cassandraParameter -> cassandraParameter // - .getName().filter(s -> s.equals(parameterName)) // - .isPresent()) // - .mapToInt(Parameter::getIndex) // - .findFirst() // - .orElseThrow(() -> new IllegalArgumentException( - String.format("Invalid parameter name; Cannot resolve parameter [%s]", parameterName))); + for (CassandraParameters.CassandraParameter parameter : parameters) { + if (parameter.getName().filter(s -> s.equals(parameterName)).isPresent()) { + return parameter.getIndex(); + } + } + + throw new IllegalArgumentException( + String.format("Invalid parameter name; Cannot resolve parameter [%s]", parameterName)); } /** @@ -129,31 +130,31 @@ private ParameterBinding(int parameterIndex, @Nullable String expression, @Nulla this.parameterName = parameterName; } - static ParameterBinding expression(String expression, boolean quoted) { + public static ParameterBinding expression(String expression, boolean quoted) { return new ParameterBinding(-1, expression, null); } - static ParameterBinding indexed(int parameterIndex) { + public static ParameterBinding indexed(int parameterIndex) { return new ParameterBinding(parameterIndex, null, null); } - static ParameterBinding named(String name) { + public static ParameterBinding named(String name) { return new ParameterBinding(-1, null, name); } - boolean isNamed() { + public boolean isNamed() { return (parameterName != null); } - int getParameterIndex() { + public int getParameterIndex() { return parameterIndex; } - String getParameter() { + public String getParameter() { return ("?" + (isExpression() ? "expr" : "") + parameterIndex); } - String getRequiredExpression() { + public String getRequiredExpression() { Assert.state(expression != null, "ParameterBinding is not an expression"); return expression; @@ -169,5 +170,24 @@ String getRequiredParameterName() { return parameterName; } + + /** + * Prepare a value before binding it to the query. + * + * @param value + * @return + */ + public @Nullable Object prepareValue(@Nullable Object value) { + + if (value == null) { + return value; + } + + if (value instanceof Limit limit) { + return limit.max(); + } + + return value; + } } } diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameterAccessor.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameterAccessor.java index 022de0a25..8d3c8a77e 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameterAccessor.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameterAccessor.java @@ -16,9 +16,11 @@ package org.springframework.data.cassandra.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.cassandra.core.cql.QueryOptions; import org.springframework.data.cassandra.core.mapping.CassandraType; import org.springframework.data.cassandra.core.query.CassandraScrollPosition; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.repository.query.ParameterAccessor; import com.datastax.oss.driver.api.core.type.DataType; @@ -65,6 +67,17 @@ public interface CassandraParameterAccessor extends ParameterAccessor { */ Class getParameterType(int index); + /** + * Get the value of the parameter at the given index. In contrast to {@link #getBindableValue(int)}, this method has + * access to all parameters. + * + * @param parameterIndex + * @return + * @since 5.0 + */ + @Nullable + Object getValue(int parameterIndex); + /** * Returns the raw parameter values of the underlying query method. * @@ -76,6 +89,16 @@ public interface CassandraParameterAccessor extends ParameterAccessor { @Override CassandraScrollPosition getScrollPosition(); + /** + * Returns the {@link ScoringFunction} from a {@link org.springframework.data.domain.Score} or + * {@link org.springframework.data.domain.Range} of scores if such a parameter is declared. + * + * @return the scoring function or {@literal null} if none is provided. + * @since 5.0 + */ + @Nullable + ScoringFunction getScoringFunction(); + /** * Returns the {@link QueryOptions} associated with the associated Repository query method. * @@ -85,4 +108,5 @@ public interface CassandraParameterAccessor extends ParameterAccessor { @Nullable QueryOptions getQueryOptions(); + } diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameters.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameters.java index bd9fc6c7c..da24b3b10 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameters.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParameters.java @@ -29,6 +29,9 @@ import org.springframework.data.cassandra.core.cql.QueryOptions; import org.springframework.data.cassandra.core.mapping.CassandraType; import org.springframework.data.cassandra.repository.query.CassandraParameters.CassandraParameter; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.ParametersSource; @@ -46,6 +49,7 @@ public class CassandraParameters extends Parameters { private final @Nullable Integer queryOptionsIndex; + private final @Nullable Integer scoringFunctionIndex; /** * Create a new {@link CassandraParameters} instance from the given {@link Method}. @@ -58,18 +62,23 @@ public CassandraParameters(ParametersSource parametersSource) { this.queryOptionsIndex = Arrays.asList(parametersSource.getMethod().getParameterTypes()) .indexOf(QueryOptions.class); + + this.scoringFunctionIndex = Arrays.asList(parametersSource.getMethod().getParameterTypes()) + .indexOf(ScoringFunction.class); } - private CassandraParameters(List originals, @Nullable Integer queryOptionsIndex) { + private CassandraParameters(List originals, @Nullable Integer queryOptionsIndex, + @Nullable Integer scoringFunctionIndex) { super(originals); this.queryOptionsIndex = queryOptionsIndex; + this.scoringFunctionIndex = scoringFunctionIndex; } @Override protected CassandraParameters createFrom(List parameters) { - return new CassandraParameters(parameters, queryOptionsIndex); + return new CassandraParameters(parameters, queryOptionsIndex, scoringFunctionIndex); } /** @@ -82,6 +91,16 @@ public int getQueryOptionsIndex() { return (queryOptionsIndex != null ? queryOptionsIndex : -1); } + /** + * Returns the index of the {@link ScoringFunction} parameter to be applied to queries. + * + * @return + * @since 5.0 + */ + public int getScoringFunctionIndex() { + return (scoringFunctionIndex != null ? scoringFunctionIndex : -1); + } + /** * Custom {@link Parameter} implementation adding {@link CassandraType} support. * @@ -91,6 +110,8 @@ static class CassandraParameter extends Parameter { private final @Nullable CassandraType cassandraType; private final Class parameterType; + private final boolean isScoreRange; + private final boolean isScoringFunction; CassandraParameter(MethodParameter parameter, TypeInformation domainType) { @@ -104,12 +125,17 @@ static class CassandraParameter extends Parameter { this.cassandraType = null; } - parameterType = potentiallyUnwrapParameterType(parameter); + this.parameterType = potentiallyUnwrapParameterType(parameter); + + ResolvableType type = ResolvableType.forMethodParameter(parameter); + this.isScoreRange = Range.class.isAssignableFrom(getType()) && type.getGeneric(0).isAssignableFrom(Score.class); + this.isScoringFunction = ScoringFunction.class.isAssignableFrom(getType()); } @Override public boolean isSpecialParameter() { - return super.isSpecialParameter() || QueryOptions.class.isAssignableFrom(getType()); + return super.isSpecialParameter() || isScoreRange || isScoringFunction || Score.class.isAssignableFrom(getType()) + || QueryOptions.class.isAssignableFrom(getType()); } /** diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParametersParameterAccessor.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParametersParameterAccessor.java index 947010361..c766be770 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParametersParameterAccessor.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraParametersParameterAccessor.java @@ -16,11 +16,15 @@ package org.springframework.data.cassandra.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.cassandra.core.cql.QueryOptions; import org.springframework.data.cassandra.core.mapping.CassandraSimpleTypeHolder; import org.springframework.data.cassandra.core.mapping.CassandraType; import org.springframework.data.cassandra.core.query.CassandraScrollPosition; import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.repository.query.ParameterAccessor; import org.springframework.data.repository.query.ParametersParameterAccessor; @@ -79,6 +83,11 @@ public CassandraParameters getParameters() { return super.getValues(); } + @Override + public @Nullable Object getValue(int parameterIndex) { + return super.getValue(parameterIndex); + } + @Override public CassandraScrollPosition getScrollPosition() { @@ -95,6 +104,36 @@ public CassandraScrollPosition getScrollPosition() { "Unsupported scroll position " + scrollPosition + ". Only CassandraScrollPosition supported."); } + @Override + public @Nullable ScoringFunction getScoringFunction() { + + Score score = getScore(); + + if (score != null) { + return score.getFunction(); + } + + Range range = getScoreRange(); + + if (range != null) { + + if (range.getLowerBound().isBounded()) { + return range.getLowerBound().getValue().get().getFunction(); + } + + if (range.getUpperBound().isBounded()) { + return range.getUpperBound().getValue().get().getFunction(); + } + } + + int scoringFunctionIndex = getParameters().getScoringFunctionIndex(); + if (scoringFunctionIndex != -1) { + return (ScoringFunction) getValue(scoringFunctionIndex); + } + + return null; + } + @Override public Limit getLimit() { diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryCreator.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryCreator.java index e475a285d..eebb31904 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryCreator.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryCreator.java @@ -32,9 +32,12 @@ import org.springframework.data.cassandra.core.query.CriteriaDefinition; import org.springframework.data.cassandra.core.query.Filter; import org.springframework.data.cassandra.core.query.Query; +import org.springframework.data.cassandra.core.query.VectorSort; import org.springframework.data.domain.Range; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.repository.query.parser.AbstractQueryCreator; import org.springframework.data.repository.query.parser.Part; @@ -56,6 +59,8 @@ class CassandraQueryCreator extends AbstractQueryCreator { private final MappingContext mappingContext; private final QueryBuilder queryBuilder = new QueryBuilder(); + private final CassandraParameterAccessor parameterAccessor; + private final PartTree tree; /** * Create a new {@link CassandraQueryCreator} from the given {@link PartTree}, {@link ConvertingParameterAccessor} and @@ -72,6 +77,8 @@ public CassandraQueryCreator(PartTree tree, CassandraParameterAccessor parameter Assert.notNull(mappingContext, "CassandraMappingContext must not be null"); + this.tree = tree; + this.parameterAccessor = parameterAccessor; this.mappingContext = mappingContext; } @@ -102,9 +109,6 @@ protected Filter create(Part part, Iterator iterator) { .getPersistentPropertyPath(part.getProperty()); CassandraPersistentProperty property = path.getLeafProperty(); - - Assert.state(property != null && path.toDotPath() != null, "Leaf property must not be null"); - Object filterOrCriteria = from(part, property, Criteria.where(path.toDotPath()), iterator); if (filterOrCriteria instanceof CriteriaDefinition) { @@ -115,10 +119,12 @@ protected Filter create(Part part, Iterator iterator) { } @Override - protected Filter and(Part part, Filter base, Iterator iterator) { + protected Filter and(Part part, @Nullable Filter base, Iterator iterator) { - for (CriteriaDefinition criterion : base) { - getQueryBuilder().and(criterion); + if (base != null) { + for (CriteriaDefinition criterion : base) { + getQueryBuilder().and(criterion); + } } return create(part, iterator); @@ -139,7 +145,8 @@ protected Query complete(@Nullable Filter criteria, Sort sort) { } } - Query query = getQueryBuilder().create(sort); + Query query = sort.isUnsorted() && parameterAccessor.getVector() != null ? getQueryBuilder().create(getVectorSort()) + : getQueryBuilder().create(sort); if (LOG.isDebugEnabled()) { LOG.debug(String.format("Created query [%s]", query)); @@ -148,10 +155,29 @@ protected Query complete(@Nullable Filter criteria, Sort sort) { return query; } + private Sort getVectorSort() { + return VectorSort.ann(getVectorProperty().toDotPath(), parameterAccessor.getVector()); + } + + PropertyPath getVectorProperty() { + + for (PartTree.OrPart parts : tree) { + for (Part part : parts) { + + if (part.getType() == Type.NEAR || part.getType() == Type.WITHIN) { + return part.getProperty(); + } + } + } + + throw new IllegalArgumentException("No Near/Within property found"); + } + /** * Returns a {@link Filter} or {@link CriteriaDefinition} object representing the criterion for a {@link Part}. */ - private Object from(Part part, CassandraPersistentProperty property, Criteria where, Iterator parameters) { + private @Nullable Object from(Part part, CassandraPersistentProperty property, Criteria where, + Iterator parameters) { Type type = part.getType(); @@ -182,10 +208,24 @@ private Object from(Part part, CassandraPersistentProperty property, Criteria wh return where.is(false); case SIMPLE_PROPERTY: return where.is(parameters.next()); + + case NEAR: + case WITHIN: + + Object next = parameters.next(); + + if (!(next instanceof Vector)) { + + throw new IllegalArgumentException("Expected a Vector for Near/Within keyword but got [%s]" + .formatted(next == null ? "null" : next.getClass())); + } + + return null; default: throw new InvalidDataAccessApiUsageException( String.format("Unsupported keyword [%s] in part [%s]", type, part)); } + } /** @@ -280,7 +320,7 @@ private Object[] nextAsArray(Iterator iterator) { */ static class QueryBuilder { - private List criterias = new ArrayList<>(); + private final List criterias = new ArrayList<>(); CriteriaDefinition and(CriteriaDefinition clause) { criterias.add(clause); diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryExecution.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryExecution.java index 8cb650633..f14384e7d 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryExecution.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryExecution.java @@ -18,6 +18,7 @@ import java.util.List; import org.jspecify.annotations.Nullable; + import org.springframework.core.convert.converter.Converter; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.cassandra.core.CassandraOperations; @@ -28,6 +29,10 @@ import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.data.mapping.context.MappingContext; @@ -164,6 +169,44 @@ public Object execute(Statement statement, Class type) { } + final class SearchExecution implements CassandraQueryExecution { + + private final CassandraOperations operations; + private final CassandraParameterAccessor accessor; + + public SearchExecution(CassandraOperations operations, CassandraParameterAccessor accessor) { + + this.operations = operations; + this.accessor = accessor; + } + + @Override + public Object execute(Statement statement, Class type) { + + ScoringFunction function = accessor.getScoringFunction(); + + List> results = operations.select(statement, type, (o, row) -> { + + if (row.getColumnDefinitions().contains("__score__")) { + return new SearchResult<>(o, getScore(row, "__score__", function)); + } + + if (row.getColumnDefinitions().contains("score")) { + return new SearchResult<>(o, getScore(row, "score", function)); + } + return new SearchResult<>(o, 0); + }); + + return new SearchResults(results); + } + + private Score getScore(Row row, String columnName, ScoringFunction function) { + + Object object = row.getObject(columnName); + return Score.of(((Number) object).doubleValue(), function); + } + } + /** * {@link CassandraQueryExecution} to return a single entity. * diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/ConvertingParameterAccessor.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/ConvertingParameterAccessor.java index 003de8677..53f539826 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/ConvertingParameterAccessor.java @@ -18,6 +18,7 @@ import java.util.Iterator; import org.jspecify.annotations.Nullable; + import org.springframework.data.cassandra.core.convert.CassandraConverter; import org.springframework.data.cassandra.core.cql.QueryOptions; import org.springframework.data.cassandra.core.mapping.CassandraType; @@ -25,7 +26,10 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import com.datastax.oss.driver.api.core.type.DataType; @@ -54,6 +58,26 @@ public CassandraScrollPosition getScrollPosition() { return delegate.getScrollPosition(); } + @Override + public ScoringFunction getScoringFunction() { + return delegate.getScoringFunction(); + } + + @Override + public @Nullable Vector getVector() { + return delegate.getVector(); + } + + @Override + public @Nullable Score getScore() { + return delegate.getScore(); + } + + @Override + public @Nullable Range getScoreRange() { + return delegate.getScoreRange(); + } + @Override public Pageable getPageable() { return this.delegate.getPageable(); @@ -113,6 +137,11 @@ public Object[] getValues() { return this.delegate.getValues(); } + @Override + public @Nullable Object getValue(int parameterIndex) { + return potentiallyConvert(parameterIndex, this.delegate.getValue(parameterIndex)); + } + @Nullable Object potentiallyConvert(int index, @Nullable Object bindableValue) { diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java index e023eea7b..c4b96e0b3 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java @@ -15,11 +15,16 @@ */ package org.springframework.data.cassandra.repository.query; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.Function; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.jspecify.annotations.Nullable; + import org.springframework.data.cassandra.core.StatementFactory; import org.springframework.data.cassandra.core.cql.QueryExtractorDelegate; import org.springframework.data.cassandra.core.cql.QueryOptions; @@ -27,10 +32,15 @@ import org.springframework.data.cassandra.core.cql.QueryOptionsUtil; import org.springframework.data.cassandra.core.mapping.CassandraPersistentEntity; import org.springframework.data.cassandra.core.mapping.CassandraPersistentProperty; +import org.springframework.data.cassandra.core.mapping.SimilarityFunction; import org.springframework.data.cassandra.core.query.Columns; import org.springframework.data.cassandra.core.query.Query; import org.springframework.data.cassandra.repository.Query.Idempotency; import org.springframework.data.domain.Limit; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Vector; +import org.springframework.data.domain.VectorScoringFunctions; +import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.repository.query.QueryCreationException; @@ -50,6 +60,10 @@ */ class QueryStatementCreator { + private static final Map SIMILARITY_FUNCTIONS = Map.of( + VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, VectorScoringFunctions.EUCLIDEAN, + SimilarityFunction.EUCLIDEAN, VectorScoringFunctions.DOT, SimilarityFunction.DOT_PRODUCT); + private static final Log LOG = LogFactory.getLog(QueryStatementCreator.class); private final CassandraQueryMethod queryMethod; @@ -81,9 +95,28 @@ SimpleStatement select(StatementFactory statementFactory, PartTree tree, Cassand ReturnedType returnedType = processor.withDynamicProjection(parameterAccessor).getReturnedType(); + Columns columns = null; if (returnedType.needsCustomConstruction()) { + columns = Columns.from(returnedType.getInputProperties().toArray(new String[0])); + } else if (queryMethod.isSearchQuery()) { + columns = getColumns(returnedType.getReturnedType()); + } + + if (columns != null) { + + if (queryMethod.isSearchQuery()) { + + CassandraQueryCreator queryCreator = new CassandraQueryCreator(tree, parameterAccessor, this.mappingContext); + + PropertyPath vectorProperty = queryCreator.getVectorProperty(); + Vector vector = parameterAccessor.getVector(); + SimilarityFunction similarityFunction = getSimilarityFunction(parameterAccessor.getScoringFunction()); + + columns = columns.select(vectorProperty.toDotPath(), + selectorBuilder -> selectorBuilder.similarity(vector).using(similarityFunction).as("\"__score__\"")); + } + - Columns columns = Columns.from(returnedType.getInputProperties().toArray(new String[0])); query = query.columns(columns); } @@ -93,12 +126,42 @@ SimpleStatement select(StatementFactory statementFactory, PartTree tree, Cassand LOG.debug(String.format("Created query [%s]", statement)); } + System.out.println(statement.getQuery()); + return statement; }; return doWithQuery(parameterAccessor, tree, function); } + private Columns getColumns(Class returnedType) { + + CassandraPersistentEntity entity = mappingContext.getRequiredPersistentEntity(returnedType); + List names = new ArrayList<>(); + for (CassandraPersistentProperty property : entity) { + names.add(property.getName()); + } + + return Columns.from(names.toArray(new String[0])); + } + + private SimilarityFunction getSimilarityFunction(@Nullable ScoringFunction function) { + + if (function == null) { + throw new IllegalStateException( + "Cannot determine ScoringFunction. No Score or bounded Score Range parameters provided."); + } + + SimilarityFunction similarityFunction = SIMILARITY_FUNCTIONS.get(function); + + if (similarityFunction == null) { + throw new IllegalArgumentException( + "Cannot determine SimilarityFunction from ScoreFunction '%s'".formatted(function)); + } + + return similarityFunction; + } + /** * Create a {@literal COUNT} {@link Statement} from a {@link PartTree} and apply query options. * diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/VectorSearchIntegrationTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/VectorSearchIntegrationTests.java new file mode 100644 index 000000000..6ec060b57 --- /dev/null +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/VectorSearchIntegrationTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2016-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.cassandra.repository; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.annotation.Id; +import org.springframework.data.cassandra.config.SchemaAction; +import org.springframework.data.cassandra.core.mapping.Indexed; +import org.springframework.data.cassandra.core.mapping.SaiIndexed; +import org.springframework.data.cassandra.core.mapping.Table; +import org.springframework.data.cassandra.core.mapping.VectorType; +import org.springframework.data.cassandra.repository.config.EnableCassandraRepositories; +import org.springframework.data.cassandra.repository.support.AbstractSpringDataEmbeddedCassandraIntegrationTest; +import org.springframework.data.cassandra.repository.support.IntegrationTestConfig; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.domain.VectorScoringFunctions; +import org.springframework.data.repository.CrudRepository; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +/** + * Integration tests for Vector Search using repositories. + * + * @author Mark Paluch + */ +@SpringJUnitConfig +class VectorSearchIntegrationTests extends AbstractSpringDataEmbeddedCassandraIntegrationTest { + + @Configuration + @EnableCassandraRepositories(basePackageClasses = CommentsRepository.class, considerNestedRepositories = true, + includeFilters = @ComponentScan.Filter(classes = CommentsRepository.class, type = FilterType.ASSIGNABLE_TYPE)) + public static class Config extends IntegrationTestConfig { + + @Override + protected Set> getInitialEntitySet() { + return Collections.singleton(Comments.class); + } + + @Override + public SchemaAction getSchemaAction() { + return SchemaAction.RECREATE_DROP_UNUSED; + } + } + + @Autowired CommentsRepository repository; + + @BeforeEach + void setUp() { + + repository.deleteAll(); + + Comments one = new Comments(); + one.setId(UUID.randomUUID()); + one.setLanguage("en"); + one.setEmbedding(Vector.of(0.45f, 0.09f, 0.01f, 0.2f, 0.11f)); + one.setComment("Raining too hard should have postponed"); + + Comments two = new Comments(); + two.setId(UUID.randomUUID()); + two.setLanguage("en"); + two.setEmbedding(Vector.of(0.99f, 0.5f, -10.99f, -100.1f, 0.34f)); + two.setComment("Second rest stop was out of water"); + + Comments three = new Comments(); + three.setId(UUID.randomUUID()); + three.setLanguage("en"); + three.setEmbedding(Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f)); + three.setComment("LATE RIDERS SHOULD NOT DELAY THE START"); + + repository.saveAll(List.of(one, two, three)); + + } + + @Test // GH- + void shouldConsiderScoringFunction() { + + Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f); + + SearchResults result = repository.searchByEmbeddingNear(vector, + VectorScoringFunctions.COSINE, Limit.of(100)); + + assertThat(result).hasSize(3); + for (SearchResult commentSearch : result) { + assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0d, offset(0.1d)); + } + + result = repository.searchByEmbeddingNear(vector, VectorScoringFunctions.EUCLIDEAN, Limit.of(100)); + + assertThat(result).hasSize(3); + for (SearchResult commentSearch : result) { + assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0.3d, offset(0.1d)); + } + } + + @Test // GH- + void shouldRunAnnotatedSearchByVector() { + + Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f); + + SearchResults result = repository.searchAnnotatedByEmbeddingNear(vector, Limit.of(100)); + + assertThat(result).hasSize(3); + for (SearchResult commentSearch : result) { + assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0d, offset(0.1d)); + } + } + + @Test // GH- + void shouldFindByVector() { + + Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f); + + List result = repository.findByEmbeddingNear(vector, Limit.of(100)); + + assertThat(result).hasSize(3); + } + + @Table + static class Comments { + + @Id UUID id; + String comment; + + @Indexed String language; + + @VectorType(dimensions = 5) + @SaiIndexed Vector embedding; + + public UUID getId() { + return id; + } + + public void setId(UUID id) { + this.id = id; + } + + public String getLanguage() { + return language; + } + + public void setLanguage(String language) { + this.language = language; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + public Vector getEmbedding() { + return embedding; + } + + public void setEmbedding(Vector embedding) { + this.embedding = embedding; + } + } + + interface CommentsRepository extends CrudRepository { + + SearchResults searchByEmbeddingNear(Vector embedding, ScoringFunction function, Limit limit); + + List findByEmbeddingNear(Vector embedding, Limit limit); + + @Query("SELECT id,comment,language,similarity_cosine(embedding,:embedding) AS score FROM comments ORDER BY embedding ANN OF :embedding LIMIT :limit") + SearchResults searchAnnotatedByEmbeddingNear(Vector embedding, Limit limit); + + } + +} diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/CassandraParametersUnitTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/CassandraParametersUnitTests.java index 292de9b81..dfd04f7e5 100755 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/CassandraParametersUnitTests.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/CassandraParametersUnitTests.java @@ -26,10 +26,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.data.cassandra.core.mapping.CassandraType; import org.springframework.data.cassandra.domain.Person; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.repository.Repository; -import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; import org.springframework.data.repository.query.ParametersSource; /** @@ -89,6 +90,15 @@ void shouldReturnTypeForComposedAnnotationType() throws Exception { assertThat(cassandraParameters.getParameter(0).getCassandraType().type()).isEqualTo(Name.BOOLEAN); } + @Test // GH- + void considersScoringFunctionIndex() throws Exception { + + Method method = PersonRepository.class.getMethod("findByObject", ScoringFunction.class); + CassandraParameters cassandraParameters = new CassandraParameters(ParametersSource.of(method)); + + assertThat(cassandraParameters.getScoringFunctionIndex()).isEqualTo(0); + } + interface PersonRepository extends Repository { Person findByFirstname(String firstname); @@ -97,6 +107,8 @@ interface PersonRepository extends Repository { Person findByObject(Object firstname); + Person findByObject(ScoringFunction firstname); + Person findByAnnotatedObject(@CassandraType(type = Name.TIME) Object firstname); Person findByComposedAnnotationObject(@ComposedCassandraTypeAnnotation Object firstname); diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/StubParameterAccessor.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/StubParameterAccessor.java index dafd0a650..3245ade93 100644 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/StubParameterAccessor.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/StubParameterAccessor.java @@ -24,7 +24,11 @@ import org.springframework.data.cassandra.core.mapping.CassandraType; import org.springframework.data.cassandra.core.query.CassandraScrollPosition; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.repository.query.ParameterAccessor; import com.datastax.oss.driver.api.core.type.DataType; @@ -77,6 +81,26 @@ public CassandraScrollPosition getScrollPosition() { return null; } + @Override + public ScoringFunction getScoringFunction() { + return null; + } + + @Override + public @Nullable Vector getVector() { + return null; + } + + @Override + public @Nullable Score getScore() { + return null; + } + + @Override + public @Nullable Range getScoreRange() { + return null; + } + @Override public Pageable getPageable() { return null; @@ -117,4 +141,9 @@ public CassandraType findCassandraType(int index) { public Object[] getValues() { return new Object[0]; } + + @Override + public @Nullable Object getValue(int parameterIndex) { + return null; + } } diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/test/util/CassandraDelegate.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/test/util/CassandraDelegate.java index 2636125d6..477a889b9 100644 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/test/util/CassandraDelegate.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/test/util/CassandraDelegate.java @@ -290,7 +290,7 @@ private void runTestcontainerCassandra() { if (container == null) { container = getCassandraDockerImageName().map(CassandraContainer::new) - .orElseGet(() -> new CassandraContainer("cassandra:5.0.3")); + .orElseGet(() -> new CassandraContainer("cassandra:5.0.3")).withReuse(true); container.withEnv("MAX_HEAP_SIZE", "1500M"); container.withEnv("HEAP_NEWSIZE", "300M"); From 6f57f2f1624a0f03e323971e0bdec6d262a379b1 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 16 Apr 2025 16:35:43 +0200 Subject: [PATCH 4/9] Add support for QueryResultConverter. --- .../core/AsyncCassandraOperations.java | 2 +- .../cassandra/core/CassandraOperations.java | 18 +- .../cassandra/core/CassandraTemplate.java | 123 ++++++++--- .../cassandra/core/EntityResultConverter.java | 33 +++ .../core/ExecutableSelectOperation.java | 119 ++++++++++- .../ExecutableSelectOperationSupport.java | 184 ++++++++++++++++- .../cassandra/core/QueryResultConverter.java | 85 ++++++++ .../core/ReactiveCassandraOperations.java | 2 +- .../core/ReactiveCassandraTemplate.java | 46 ++++- .../core/ReactiveSelectOperation.java | 115 ++++++++++- .../core/ReactiveSelectOperationSupport.java | 192 +++++++++++++++--- .../query/CassandraQueryExecution.java | 17 +- .../core/CassandraTemplateUnitTests.java | 12 +- ...electOperationSupportIntegrationTests.java | 27 ++- ...electOperationSupportIntegrationTests.java | 17 +- 15 files changed, 871 insertions(+), 121 deletions(-) create mode 100644 spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityResultConverter.java create mode 100644 spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/QueryResultConverter.java diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/AsyncCassandraOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/AsyncCassandraOperations.java index febc49b4f..90cdc8f76 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/AsyncCassandraOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/AsyncCassandraOperations.java @@ -128,7 +128,7 @@ CompletableFuture select(String cql, Consumer entityConsumer, Class /** * Execute a {@code SELECT} query with paging and convert the result set to a {@link Slice} of entities. A sliced - * query translates the effective {@link Statement#getFetchSize() fetch size} to the page size. + * query translates the effective {@link Statement#getPageSize() fetch size} to the page size. * * @param statement the CQL statement, must not be {@literal null}. * @param entityClass The entity type must not be {@literal null}. diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java index ad4d3289c..a7a970cd2 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraOperations.java @@ -17,7 +17,6 @@ import java.util.Iterator; import java.util.List; -import java.util.function.BiFunction; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; @@ -35,7 +34,6 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.ResultSet; -import com.datastax.oss.driver.api.core.cql.Row; import com.datastax.oss.driver.api.core.cql.Statement; /** @@ -163,23 +161,9 @@ default CassandraBatchOperations batchOps() { */ List select(Statement statement, Class entityClass) throws DataAccessException; - /** - * Execute a {@code SELECT} query and convert the resulting items to a {@link List} of entities considering the given - * {@link BiFunction mapping function}. - * - * @param statement must not be {@literal null}. - * @param entityClass the entity type must not be {@literal null}. - * @param mapper mapping function invoked after materializing {@code entityClass} must not be {@literal null}. - * @return the converted results - * @throws DataAccessException if there is any problem executing the query. - * @since 5.0 - */ - List select(Statement statement, Class entityClass, BiFunction mapper) - throws DataAccessException; - /** * Execute a {@code SELECT} query with paging and convert the result set to a {@link Slice} of entities. A sliced - * query translates the effective {@link Statement#getFetchSize() fetch size} to the page size. + * query translates the effective {@link Statement#getPageSize()} to the page size. * * @param statement the CQL statement, must not be {@literal null}. * @param entityClass the entity type must not be {@literal null}. diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java index 0b1f614da..300bc1c7b 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java @@ -16,7 +16,6 @@ package org.springframework.data.cassandra.core; import java.util.List; -import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -59,6 +58,7 @@ import org.springframework.data.projection.EntityProjection; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.util.Lazy; import org.springframework.util.Assert; import com.datastax.oss.driver.api.core.CqlIdentifier; @@ -350,25 +350,21 @@ public ResultSet execute(Statement statement) { @Override public List select(Statement statement, Class entityClass) { - return select(statement, entityClass, (t, row) -> t); - } - - @Override - public List select(Statement statement, Class entityClass, BiFunction mapper) - throws DataAccessException { Assert.notNull(statement, "Statement must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - Assert.notNull(mapper, "Row Mapper function must not be null"); - Function defaultMapper = getMapper(EntityProjection.nonProjecting(entityClass), - EntityQueryUtils.getTableName(statement)); + return doSelect(statement, entityClass, getTableName(entityClass), entityClass, QueryResultConverter.entity()); + } - return doQuery(statement, (row, rowNum) -> { + List doSelect(Statement statement, Class entityClass, CqlIdentifier tableName, Class returnType, + QueryResultConverter mappingFunction) { - S intermediate = defaultMapper.apply(row); - return mapper.apply(intermediate, row); - }); + EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); + + RowMapper rowMapper = getRowMapper(projection, tableName, mappingFunction); + + return doQuery(statement, rowMapper); } @Override @@ -384,13 +380,14 @@ public Slice slice(Statement statement, Class entityClass) { Assert.notNull(statement, "Statement must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - ResultSet resultSet = doQueryForResultSet(statement); + return doSlice(statement, + getRowMapper(entityClass, EntityQueryUtils.getTableName(statement), QueryResultConverter.entity())); + } - Function mapper = getMapper(EntityProjection.nonProjecting(entityClass), - EntityQueryUtils.getTableName(statement)); + Slice doSlice(Statement statement, RowMapper mapper) { - return EntityQueryUtils.readSlice(resultSet, (row, rowNum) -> mapper.apply(row), 0, - getEffectivePageSize(statement)); + ResultSet resultSet = doQueryForResultSet(statement); + return EntityQueryUtils.readSlice(resultSet, mapper, 0, getEffectivePageSize(statement)); } @Override @@ -399,9 +396,17 @@ public Stream stream(Statement statement, Class entityClass) throws Assert.notNull(statement, "Statement must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - Function mapper = getMapper(EntityProjection.nonProjecting(entityClass), - EntityQueryUtils.getTableName(statement)); - return doQueryForStream(statement, (row, rowNum) -> mapper.apply(row)); + return doStream(statement, entityClass, EntityQueryUtils.getTableName(statement), entityClass, + QueryResultConverter.entity()); + } + + Stream doStream(Statement statement, Class entityClass, CqlIdentifier tableName, Class returnType, + QueryResultConverter mappingFunction) { + + EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); + + RowMapper rowMapper = getRowMapper(projection, tableName, mappingFunction); + return doQueryForStream(statement, rowMapper); } // ------------------------------------------------------------------------- @@ -414,10 +419,11 @@ public List select(Query query, Class entityClass) throws DataAccessEx Assert.notNull(query, "Query must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - return doSelect(query, entityClass, getTableName(entityClass), entityClass); + return doSelect(query, entityClass, getTableName(entityClass), entityClass, QueryResultConverter.entity()); } - List doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType) { + List doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType, + QueryResultConverter mappingFunction) { CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); @@ -427,9 +433,9 @@ List doSelect(Query query, Class entityClass, CqlIdentifier tableName, Query queryToUse = query.columns(columns); StatementBuilder select = getStatementFactory().select(query, getRequiredPersistentEntity(entityClass)); + return doSlice(query, entityClass, getRequiredPersistentEntity(entityClass).getTableName(), entityClass, + QueryResultConverter.entity()); + } + + Slice doSlice(Query query, Class entityClass, CqlIdentifier tableName, Class returnType, + QueryResultConverter mappingFunction) { + + CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); + EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); + Columns columns = getStatementFactory().computeColumnsForProjection(projection, query.getColumns(), entity, + returnType); + + Query queryToUse = query.columns(columns); + + StatementBuilder select = getStatementFactory().select(query, getRequiredPersistentEntity(entityClass), tableName); EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); - Function mapper = getMapper(projection, tableName); - return doQueryForStream(select.build(), (row, rowNum) -> mapper.apply(row)); + RowMapper rowMapper = getRowMapper(projection, tableName, mappingFunction); + + return doQueryForStream(select.build(), rowMapper); } @Override @@ -779,6 +802,16 @@ public ExecutableSelect query(Class domainType) { return new ExecutableSelectOperationSupport(this).query(domainType); } + @Override + public UntypedSelect query(String cql) { + return new ExecutableSelectOperationSupport(this).query(cql); + } + + @Override + public UntypedSelect query(Statement statement) { + return new ExecutableSelectOperationSupport(this).query(statement); + } + @Override public ExecutableInsert insert(Class domainType) { return new ExecutableInsertOperationSupport(this).insert(domainType); @@ -921,6 +954,32 @@ public String getCql() { return getCqlOperations().execute(new GetConfiguredPageSize()); } + @SuppressWarnings("unchecked") + RowMapper getRowMapper(EntityProjection projection, CqlIdentifier tableName, + QueryResultConverter mappingFunction) { + + Function mapper = getMapper(projection, tableName); + + return mappingFunction == QueryResultConverter.entity() ? (row, rowNum) -> (R) mapper.apply(row) + : (row, rowNum) -> { + Lazy reader = Lazy.of(() -> mapper.apply(row)); + return mappingFunction.mapRow(row, reader::get); + }; + } + + @SuppressWarnings("unchecked") + RowMapper getRowMapper(Class domainClass, CqlIdentifier tableName, + QueryResultConverter mappingFunction) { + + Function mapper = getMapper(EntityProjection.nonProjecting(domainClass), tableName); + + return mappingFunction == QueryResultConverter.entity() ? (row, rowNum) -> (R) mapper.apply(row) + : (row, rowNum) -> { + Lazy reader = Lazy.of(() -> mapper.apply(row)); + return mappingFunction.mapRow(row, reader::get); + }; + } + @SuppressWarnings("unchecked") private Function getMapper(EntityProjection projection, CqlIdentifier tableName) { diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityResultConverter.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityResultConverter.java new file mode 100644 index 000000000..440c1df75 --- /dev/null +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityResultConverter.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.cassandra.core; + +import com.datastax.oss.driver.api.core.cql.Row; + +enum EntityResultConverter implements QueryResultConverter { + + INSTANCE; + + @Override + public Object mapRow(Row row, ConversionResultSupplier reader) { + return reader.get(); + } + + @Override + public QueryResultConverter andThen(QueryResultConverter after) { + return (QueryResultConverter) after; + } +} diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperation.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperation.java index 54006e963..664159548 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperation.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperation.java @@ -17,15 +17,21 @@ import java.util.List; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; + +import org.springframework.data.cassandra.core.cql.RowMapper; import org.springframework.data.cassandra.core.query.Query; +import org.springframework.data.domain.Slice; import org.springframework.lang.Contract; import org.springframework.util.Assert; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.cql.Statement; /** * The {@link ExecutableSelectOperation} interface allows creation and execution of Cassandra {@code SELECT} operations @@ -68,6 +74,76 @@ public interface ExecutableSelectOperation { */ ExecutableSelect query(Class domainType); + /** + * Begin creating a Cassandra {@code SELECT} query operation for the given {@code cql}. The given {@code cql} must be + * a {@code SELECT} query. + * + * @param cql {@code SELECT} statement, must not be {@literal null}. + * @return new instance of {@link UntypedSelect}. + * @throws IllegalArgumentException if {@code cql} is {@literal null}. + * @since 5.0 + * @see ExecutableSelect + */ + UntypedSelect query(String cql); + + /** + * Begin creating a Cassandra {@code SELECT} query operation for the given {@link Statement}. The given + * {@link Statement} must be a {@code SELECT} query. + * + * @param statement {@code SELECT} statement, must not be {@literal null}. + * @return new instance of {@link UntypedSelect}. + * @throws IllegalArgumentException if {@link Statement statement} is {@literal null}. + * @since 5.0 + * @see ExecutableSelect + */ + UntypedSelect query(Statement statement); + + /** + * Select query that is not yet associated with a result type. + * + * @since 5.0 + */ + interface UntypedSelect { + + /** + * Define the {@link Class result target type} that the Cassandra Row fields should be mapped to. + * + * @param resultType result type; must not be {@literal null}. + * @param {@link Class type} of the result. + * @return new instance of {@link TerminatingResults}. + * @throws IllegalArgumentException if {@link Class resultType} is {@literal null}. + */ + @Contract("_ -> new") + TerminatingResults as(Class resultType); + + /** + * Configure a {@link Function mapping function} that maps the Cassandra Row to a result type. This is a simplified + * variant of {@link #map(RowMapper)}. + * + * @param mapper row mapping function; must not be {@literal null}. + * @param {@link Class type} of the result. + * @return new instance of {@link TerminatingResults}. + * @throws IllegalArgumentException if {@link Function mapper} is {@literal null}. + * @see #map(RowMapper) + */ + @Contract("_ -> new") + default TerminatingResults map(Function mapper) { + return map((row, rowNum) -> mapper.apply(row)); + } + + /** + * Configure a {@link RowMapper} that maps the Cassandra Row to a result type. + * + * @param mapper the row mapper; must not be {@literal null}. + * @param {@link Class type} of the result. + * @return new instance of {@link TerminatingResults}. + * @throws IllegalArgumentException if {@link RowMapper mapper} is {@literal null}. + */ + @Contract("_ -> new") + TerminatingResults map(RowMapper mapper); + + } + /** * Table override (optional). */ @@ -121,7 +197,7 @@ interface SelectWithProjection extends SelectWithQuery { * @param {@link Class type} of the result. * @param resultType desired {@link Class target type} of the result; must not be {@literal null}. * @return new instance of {@link SelectWithQuery}. - * @throws IllegalArgumentException if resultType is {@literal null}. + * @throws IllegalArgumentException if {@link Class resultType} is {@literal null}. * @see SelectWithQuery */ @Contract("_ -> new") @@ -130,18 +206,19 @@ interface SelectWithProjection extends SelectWithQuery { } /** - * Filtering (optional). + * Define a {@link Query} used as the filter for the {@code SELECT}. */ interface SelectWithQuery extends TerminatingSelect { /** - * Set the {@link Query} to use as a filter. + * Set the {@link Query} used as a filter in the {@code SELECT} statement. * * @param query {@link Query} used as a filter; must not be {@literal null}. * @return new instance of {@link TerminatingSelect}. * @throws IllegalArgumentException if {@link Query} is {@literal null}. * @see TerminatingSelect */ + @Contract("_ -> new") TerminatingSelect matching(Query query); } @@ -149,7 +226,13 @@ interface SelectWithQuery extends TerminatingSelect { /** * Trigger {@code SELECT} query execution by calling one of the terminating methods. */ - interface TerminatingSelect { + interface TerminatingSelect extends TerminatingProjections, TerminatingResults {} + + /** + * Trigger {@code SELECT} query execution by calling one of the terminating methods returning result projections for + * count and exists projections. + */ + interface TerminatingProjections { /** * Get the number of matching elements. @@ -168,6 +251,25 @@ default boolean exists() { return count() > 0; } + } + + /** + * Trigger {@code SELECT} query execution by calling one of the terminating methods and return mapped results. + */ + interface TerminatingResults { + + /** + * Map the query result to a different type using {@link QueryResultConverter}. + * + * @param {@link Class type} of the result. + * @param converter the converter, must not be {@literal null}. + * @return new instance of {@link TerminatingResults}. + * @throws IllegalArgumentException if {@link QueryResultConverter converter} is {@literal null}. + * @since 5.0 + */ + @Contract("_ -> new") + TerminatingResults map(QueryResultConverter converter); + /** * Get the first result, or no result. * @@ -214,6 +316,15 @@ default Optional one() { */ List all(); + /** + * Execute the query with paging and convert the result set to a {@link Slice} of entities. A sliced query + * translates the effective {@link Statement#getPageSize() fetch size} to the page size. + * + * @return the converted results + * @since 5.0 + */ + Slice slice(); + /** * Stream all matching elements. * diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperationSupport.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperationSupport.java index 2e17e3a67..679c5d2b9 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperationSupport.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableSelectOperationSupport.java @@ -19,12 +19,18 @@ import java.util.stream.Stream; import org.jspecify.annotations.Nullable; + import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.data.cassandra.core.cql.QueryExtractorDelegate; +import org.springframework.data.cassandra.core.cql.RowMapper; import org.springframework.data.cassandra.core.query.Query; +import org.springframework.data.domain.Slice; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.Statement; /** * Implementation of {@link ExecutableSelectOperation}. @@ -47,36 +53,182 @@ public ExecutableSelect query(Class domainType) { Assert.notNull(domainType, "DomainType must not be null"); - return new ExecutableSelectSupport<>(this.template, domainType, domainType, Query.empty(), null); + return new ExecutableSelectSupport<>(this.template, domainType, domainType, QueryResultConverter.entity(), + Query.empty(), null); } - static class ExecutableSelectSupport implements ExecutableSelect { + @Override + public UntypedSelect query(String cql) { + + Assert.hasText(cql, "CQL must not be empty"); + + return new UntypedSelectSupport(this.template, SimpleStatement.newInstance(cql)); + } + + @Override + public UntypedSelect query(Statement statement) { + + Assert.notNull(statement, "Statement must not be null"); + + return new UntypedSelectSupport(this.template, statement); + } + + private record UntypedSelectSupport(CassandraTemplate template, Statement statement) implements UntypedSelect { + + @Override + public TerminatingResults as(Class resultType) { + + Assert.notNull(resultType, "Result type must not be null"); + + return new TypedSelectSupport<>(template, statement, resultType); + } + + @Override + public TerminatingResults map(RowMapper mapper) { + + Assert.notNull(mapper, "RowMapper must not be null"); + + return new TerminatingSelectResultSupport<>(template, statement, mapper); + } + + } + + static class TypedSelectSupport extends TerminatingSelectResultSupport implements TerminatingResults { + + private final Class domainType; + + TypedSelectSupport(CassandraTemplate template, Statement statement, Class domainType) { + super(template, statement, + template.getRowMapper(domainType, EntityQueryUtils.getTableName(statement), QueryResultConverter.entity())); + + this.domainType = domainType; + } + + @Override + public TerminatingResults map(QueryResultConverter converter) { + + Assert.notNull(converter, "Mapping function must not be null"); + + return new TerminatingSelectResultSupport<>(this.template, this.statement, this.domainType, converter); + } + + } + + static class TerminatingSelectResultSupport implements TerminatingResults { + + final CassandraTemplate template; + + final Statement statement; + + final RowMapper rowMapper; + + TerminatingSelectResultSupport(CassandraTemplate template, Statement statement, RowMapper rowMapper) { + this.template = template; + this.statement = statement; + this.rowMapper = rowMapper; + } + + TerminatingSelectResultSupport(CassandraTemplate template, Statement statement, Class domainType, + QueryResultConverter mappingFunction) { + this(template, statement, + template.getRowMapper(domainType, EntityQueryUtils.getTableName(statement), mappingFunction)); + } + + @Override + public TerminatingResults map(QueryResultConverter converter) { + + return new TerminatingSelectResultSupport<>(this.template, this.statement, (row, rowNum) -> { + + return converter.mapRow(row, () -> { + return this.rowMapper.mapRow(row, rowNum); + }); + }); + } + + @Override + public @Nullable T firstValue() { + + List result = this.template.getCqlOperations().query(this.statement, this.rowMapper); + + return ObjectUtils.isEmpty(result) ? null : result.iterator().next(); + } + + @Override + public @Nullable T oneValue() { + + List result = this.template.getCqlOperations().query(this.statement, this.rowMapper); + + if (ObjectUtils.isEmpty(result)) { + return null; + } + + if (result.size() > 1) { + throw new IncorrectResultSizeDataAccessException( + String.format("Query [%s] returned non unique result", QueryExtractorDelegate.getCql(this.statement)), 1); + } + + return result.iterator().next(); + } + + @Override + public List all() { + return this.template.getCqlOperations().query(this.statement, this.rowMapper); + } + + @Override + public Slice slice() { + return this.template.doSlice(this.statement, this.rowMapper); + } + + @Override + public Stream stream() { + return this.template.getCqlOperations().queryForStream(this.statement, this.rowMapper); + } + + } + + static class ExecutableSelectSupport implements ExecutableSelect { private final CassandraTemplate template; private final Class domainType; - private final Class returnType; + private final Class returnType; + + private final QueryResultConverter mappingFunction; private final Query query; private final @Nullable CqlIdentifier tableName; - public ExecutableSelectSupport(CassandraTemplate template, Class domainType, Class returnType, Query query, + public ExecutableSelectSupport(CassandraTemplate template, Class domainType, Class returnType, + QueryResultConverter mappingFunction, Query query, @Nullable CqlIdentifier tableName) { + this.template = template; this.domainType = domainType; this.returnType = returnType; + this.mappingFunction = mappingFunction; this.query = query; this.tableName = tableName; } + @Override + public TerminatingResults map(QueryResultConverter converter) { + + Assert.notNull(converter, "Mapping function name must not be null"); + + return new ExecutableSelectSupport<>(this.template, this.domainType, this.returnType, + this.mappingFunction.andThen(converter), this.query, tableName); + } + @Override public SelectWithProjection inTable(CqlIdentifier tableName) { Assert.notNull(tableName, "Table name must not be null"); - return new ExecutableSelectSupport<>(this.template, this.domainType, this.returnType, this.query, tableName); + return new ExecutableSelectSupport<>(this.template, this.domainType, this.returnType, this.mappingFunction, + this.query, tableName); } @Override @@ -84,7 +236,8 @@ public SelectWithQuery as(Class returnType) { Assert.notNull(returnType, "ReturnType must not be null"); - return new ExecutableSelectSupport<>(this.template, this.domainType, returnType, this.query, this.tableName); + return new ExecutableSelectSupport<>(this.template, this.domainType, returnType, QueryResultConverter.entity(), + this.query, this.tableName); } @Override @@ -92,7 +245,8 @@ public TerminatingSelect matching(Query query) { Assert.notNull(query, "Query must not be null"); - return new ExecutableSelectSupport<>(this.template, this.domainType, this.returnType, query, this.tableName); + return new ExecutableSelectSupport<>(this.template, this.domainType, this.returnType, this.mappingFunction, query, + this.tableName); } @Override @@ -108,7 +262,8 @@ public boolean exists() { @Override public @Nullable T firstValue() { - List result = this.template.doSelect(this.query.limit(1), this.domainType, getTableName(), this.returnType); + List result = this.template.doSelect(this.query.limit(1), this.domainType, getTableName(), this.returnType, + this.mappingFunction); return ObjectUtils.isEmpty(result) ? null : result.iterator().next(); } @@ -116,7 +271,8 @@ public boolean exists() { @Override public @Nullable T oneValue() { - List result = this.template.doSelect(this.query.limit(2), this.domainType, getTableName(), this.returnType); + List result = this.template.doSelect(this.query.limit(2), this.domainType, getTableName(), this.returnType, + this.mappingFunction); if (ObjectUtils.isEmpty(result)) { return null; @@ -132,12 +288,17 @@ public boolean exists() { @Override public List all() { - return this.template.doSelect(this.query, this.domainType, getTableName(), this.returnType); + return this.template.doSelect(this.query, this.domainType, getTableName(), this.returnType, this.mappingFunction); + } + + @Override + public Slice slice() { + return this.template.doSlice(this.query, this.domainType, getTableName(), this.returnType, this.mappingFunction); } @Override public Stream stream() { - return this.template.doStream(this.query, this.domainType, getTableName(), this.returnType); + return this.template.doStream(this.query, this.domainType, getTableName(), this.returnType, this.mappingFunction); } private CqlIdentifier getTableName() { @@ -146,4 +307,5 @@ private CqlIdentifier getTableName() { } + } diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/QueryResultConverter.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/QueryResultConverter.java new file mode 100644 index 000000000..fc08e997b --- /dev/null +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/QueryResultConverter.java @@ -0,0 +1,85 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.cassandra.core; + +import com.datastax.oss.driver.api.core.cql.Row; + +/** + * Converter for Cassandra query results. + *

+ * This is a functional interface that allows for mapping a {@link Row} to a result type. + * {@link #mapRow(Row, ConversionResultSupplier) row mapping} can obtain upstream a {@link ConversionResultSupplier + * upstream converter} to enrich the final result object. This is useful when e.g. wrapping result objects where the + * wrapper needs to obtain information from the actual {@link Row}. + * + * @param object type accepted by this converter. + * @param the returned result type. + * @author Mark Paluch + * @since 5.0 + */ +@FunctionalInterface +public interface QueryResultConverter { + + /** + * Returns a function that returns the materialized entity. + * + * @param the type of the input and output entity to the function. + * @return a function that returns the materialized entity. + */ + @SuppressWarnings("unchecked") + static QueryResultConverter entity() { + return (QueryResultConverter) EntityResultConverter.INSTANCE; + } + + /** + * Map a {@link Row} that is read from the Cassandra database to a query result. + * + * @param row the raw row from the Cassandra result. + * @param reader reader object that supplies an upstream result from an earlier converter. + * @return the mapped result. + */ + R mapRow(Row row, ConversionResultSupplier reader); + + /** + * Returns a composed function that first applies this function to its input, and then applies the {@code after} + * function to the result. If evaluation of either function throws an exception, it is relayed to the caller of the + * composed function. + * + * @param the type of output of the {@code after} function, and of the composed function. + * @param after the function to apply after this function is applied. + * @return a composed function that first applies this function and then applies the {@code after} function. + */ + default QueryResultConverter andThen(QueryResultConverter after) { + return (row, reader) -> after.mapRow(row, () -> mapRow(row, reader)); + } + + /** + * A supplier that converts a {@link Row} into {@code T}. Allows for lazy reading of query results. + * + * @param type of the returned result. + */ + interface ConversionResultSupplier { + + /** + * Obtain the upstream conversion result. + * + * @return the upstream conversion result. + */ + T get(); + + } + +} diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraOperations.java index 311d9dfb6..ef9daa813 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraOperations.java @@ -129,7 +129,7 @@ default ReactiveCassandraBatchOperations batchOps() { /** * Execute a {@code SELECT} query with paging and convert the result set to a {@link Slice} of entities. A sliced - * query translates the effective {@link Statement#getFetchSize() fetch size} to the page size. + * query translates the effective {@link Statement#getPageSize() fetch size} to the page size. * * @param statement the CQL statement, must not be {@literal null}. * @param entityClass The entity type must not be {@literal null}. diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraTemplate.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraTemplate.java index 37ae596d0..ae54a522e 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraTemplate.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraTemplate.java @@ -66,6 +66,7 @@ import org.springframework.data.projection.EntityProjection; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.util.Lazy; import org.springframework.util.Assert; import com.datastax.oss.driver.api.core.CqlIdentifier; @@ -391,10 +392,11 @@ public Flux select(Query query, Class entityClass) throws DataAccessEx Assert.notNull(query, "Query must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - return doSelect(query, entityClass, getTableName(entityClass), entityClass); + return doSelect(query, entityClass, getTableName(entityClass), entityClass, QueryResultConverter.entity()); } - Flux doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType) { + Flux doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType, + QueryResultConverter mappingFunction) { CassandraPersistentEntity persistentEntity = getRequiredPersistentEntity(entityClass); EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); @@ -404,9 +406,9 @@ Flux doSelect(Query query, Class entityClass, CqlIdentifier tableName, Query queryToUse = query.columns(columns); StatementBuilder