From 73e31d16c916e10b222d968269b2b4abb6f0e723 Mon Sep 17 00:00:00 2001 From: mck Date: Sun, 11 May 2025 13:28:38 +0200 Subject: [PATCH] Optimise CassandraChatMemoryRepository for MessageWindowChatMemory usage pattern Time-series each chat window in Cassandra, keeping past (and deleted) windows still in the db. Add ability to store different MessageTypes. Signed-off-by: mck --- ...ChatMemoryRepositoryAutoConfiguration.java | 3 +- ...ssandraChatMemoryRepositoryProperties.java | 27 +--- ...atMemoryRepositoryAutoConfigurationIT.java | 4 +- ...draChatMemoryRepositoryPropertiesTest.java | 12 +- .../CassandraChatMemoryRepository.java | 133 ++++++++---------- .../CassandraChatMemoryRepositoryConfig.java | 97 ++++++++----- .../CassandraChatMemoryRepositoryIT.java | 86 ++++++----- .../modules/ROOT/pages/api/chat-memory.adoc | 14 +- 8 files changed, 183 insertions(+), 193 deletions(-) diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java index dbc2c767dc1..73561074fe4 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java @@ -49,8 +49,7 @@ public 
CassandraChatMemoryRepository cassandraChatMemoryRepository( builder = builder.withKeyspaceName(properties.getKeyspace()) .withTableName(properties.getTable()) - .withAssistantColumnName(properties.getAssistantColumn()) - .withUserColumnName(properties.getUserColumn()); + .withMessagesColumnName(properties.getMessagesColumn()); if (!properties.isInitializeSchema()) { builder = builder.disallowSchemaChanges(); diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java index 15b778f95cd..2e0054bb251 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java @@ -18,9 +18,6 @@ import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepositoryConfig; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.lang.Nullable; @@ -35,17 +32,13 @@ @ConfigurationProperties(CassandraChatMemoryRepositoryProperties.CONFIG_PREFIX) public class CassandraChatMemoryRepositoryProperties { - public static final String CONFIG_PREFIX = 
"spring.ai.chat.memory.repository.cassandra"; - - private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryProperties.class); + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra"; private String keyspace = CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME; private String table = CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME; - private String assistantColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME; - - private String userColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME; + private String messagesColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME; private boolean initializeSchema = true; @@ -75,20 +68,12 @@ public void setTable(String table) { this.table = table; } - public String getAssistantColumn() { - return this.assistantColumn; - } - - public void setAssistantColumn(String assistantColumn) { - this.assistantColumn = assistantColumn; - } - - public String getUserColumn() { - return this.userColumn; + public String getMessagesColumn() { + return this.messagesColumn; } - public void setUserColumn(String userColumn) { - this.userColumn = userColumn; + public void setMessagesColumn(String messagesColumn) { + this.messagesColumn = messagesColumn; } @Nullable diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java index af898108c87..febe5065c49 100644 --- 
a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java @@ -59,7 +59,7 @@ void addAndGet() { this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) - .withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLive()) + .withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive()) .run(context -> { CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class); @@ -96,7 +96,7 @@ void compareTimeToLive_ISO8601Format() { this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) - .withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLiveString()) + .withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLiveString()) .run(context -> { CassandraChatMemoryRepositoryProperties properties = context .getBean(CassandraChatMemoryRepositoryProperties.class); diff --git 
a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java index e14c55be625..f61dcc2c78b 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java @@ -36,9 +36,9 @@ void defaultValues() { var props = new CassandraChatMemoryRepositoryProperties(); assertThat(props.getKeyspace()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME); assertThat(props.getTable()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME); - assertThat(props.getAssistantColumn()) - .isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME); - assertThat(props.getUserColumn()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME); + assertThat(props.getMessagesColumn()) + .isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME); + assertThat(props.getTimeToLive()).isNull(); assertThat(props.isInitializeSchema()).isTrue(); } @@ -48,15 +48,13 @@ void customValues() { var props = new CassandraChatMemoryRepositoryProperties(); props.setKeyspace("my_keyspace"); props.setTable("my_table"); - 
props.setAssistantColumn("my_assistant_column"); - props.setUserColumn("my_user_column"); + props.setMessagesColumn("my_messages_column"); props.setTimeToLive(Duration.ofDays(1)); props.setInitializeSchema(false); assertThat(props.getKeyspace()).isEqualTo("my_keyspace"); assertThat(props.getTable()).isEqualTo("my_table"); - assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column"); - assertThat(props.getUserColumn()).isEqualTo("my_user_column"); + assertThat(props.getMessagesColumn()).isEqualTo("my_messages_column"); assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1)); assertThat(props.isInitializeSchema()).isFalse(); } diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java index 58c2ec26cce..391427e1c3b 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java @@ -18,24 +18,26 @@ import java.time.Instant; import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Map; import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.data.UdtValue; import com.datastax.oss.driver.api.querybuilder.QueryBuilder; -import com.datastax.oss.driver.api.querybuilder.delete.Delete; -import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; import 
com.datastax.oss.driver.api.querybuilder.insert.InsertInto; import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; import com.datastax.oss.driver.api.querybuilder.select.Select; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; + import org.springframework.ai.chat.memory.ChatMemoryRepository; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.util.Assert; @@ -54,23 +56,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository { private final PreparedStatement allStmt; - private final PreparedStatement addUserStmt; - - private final PreparedStatement addAssistantStmt; + private final PreparedStatement addStmt; private final PreparedStatement getStmt; - private final PreparedStatement deleteStmt; - private CassandraChatMemoryRepository(CassandraChatMemoryRepositoryConfig conf) { Assert.notNull(conf, "conf cannot be null"); this.conf = conf; this.conf.ensureSchemaExists(); this.allStmt = prepareAllStatement(); - this.addUserStmt = prepareAddStmt(this.conf.userColumn); - this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn); + this.addStmt = prepareAddStmt(); this.getStmt = prepareGetStatement(); - this.deleteStmt = prepareDeleteStmt(); } public static CassandraChatMemoryRepository create(CassandraChatMemoryRepositoryConfig conf) { @@ -97,6 +93,10 @@ public List findConversationIds() { @Override public List findByConversationId(String conversationId) { + return findByConversationIdWithLimit(conversationId, 1); + } + + List findByConversationIdWithLimit(String conversationId, int limit) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); List 
primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); @@ -106,19 +106,14 @@ public List findByConversationId(String conversationId) { CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); } + builder = builder.setInt("legacy_limit", limit); List messages = new ArrayList<>(); for (Row r : this.conf.session.execute(builder.build())) { - String assistant = r.getString(this.conf.assistantColumn); - String user = r.getString(this.conf.userColumn); - if (null != assistant) { - messages.add(new AssistantMessage(assistant)); - } - if (null != user) { - messages.add(new UserMessage(user)); + for (UdtValue udt : r.getList(this.conf.messagesColumn, UdtValue.class)) { + messages.add(getMessage(udt)); } } - Collections.reverse(messages); return messages; } @@ -128,58 +123,49 @@ public void saveAll(String conversationId, List messages) { Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); - final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); - messages.forEach(msg -> { - if (msg.getMetadata().containsKey(CONVERSATION_TS)) { - msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement())); - } - save(conversationId, msg); - }); - } - - void save(String conversationId, Message msg) { - - Preconditions.checkArgument( - !msg.getMetadata().containsKey(CONVERSATION_TS) - || msg.getMetadata().get(CONVERSATION_TS) instanceof Instant, - "messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS); - - msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now()); - - PreparedStatement stmt = getStatement(msg); - + Instant instant = Instant.now(); List primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); - BoundStatementBuilder builder = stmt.boundStatementBuilder(); + BoundStatementBuilder 
builder = addStmt.boundStatementBuilder(); for (int k = 0; k < primaryKeys.size(); ++k) { CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); } - Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS); + List msgs = new ArrayList<>(); + for (Message msg : messages) { + + Preconditions.checkArgument( + !msg.getMetadata().containsKey(CONVERSATION_TS) + || msg.getMetadata().get(CONVERSATION_TS) instanceof Instant, + "messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS); + msg.getMetadata().putIfAbsent(CONVERSATION_TS, instant); + + UdtValue udt = this.conf.session.getMetadata() + .getKeyspace(this.conf.schema.keyspace()) + .get() + .getUserDefinedType(this.conf.messageUDT) + .get() + .newValue() + .setInstant(this.conf.messageUdtTimestampColumn, (Instant) msg.getMetadata().get(CONVERSATION_TS)) + .setString(this.conf.messageUdtTypeColumn, msg.getMessageType().name()) + .setString(this.conf.messageUdtContentColumn, msg.getText()); + + msgs.add(udt); + } builder = builder.setInstant(CassandraChatMemoryRepositoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant) - .setString("message", msg.getText()); + .setList("msgs", msgs, UdtValue.class); this.conf.session.execute(builder.build()); } @Override public void deleteByConversationId(String conversationId) { - Assert.hasText(conversationId, "conversationId cannot be null or empty"); - - List primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); - BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(); - - for (int k = 0; k < primaryKeys.size(); ++k) { - CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); - } - - this.conf.session.execute(builder.build()); + saveAll(conversationId, List.of()); } - private 
PreparedStatement prepareAddStmt(String column) { + private PreparedStatement prepareAddStmt() { RegularInsert stmt = null; InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); for (var c : this.conf.schema.partitionKeys()) { @@ -188,7 +174,7 @@ private PreparedStatement prepareAddStmt(String column) { for (var c : this.conf.schema.clusteringKeys()) { stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); } - stmt = stmt.value(column, QueryBuilder.bindMarker("message")); + stmt = stmt.value(this.conf.messagesColumn, QueryBuilder.bindMarker("msgs")); return this.conf.session.prepare(stmt.build()); } @@ -214,28 +200,27 @@ private PreparedStatement prepareGetStatement() { String columnName = this.conf.schema.clusteringKeys().get(i).name(); stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); } + stmt = stmt.limit(QueryBuilder.bindMarker("legacy_limit")); return this.conf.session.prepare(stmt.build()); } - private PreparedStatement prepareDeleteStmt() { - Delete stmt = null; - DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); - for (var c : this.conf.schema.partitionKeys()) { - stmt = (null != stmt ? 
stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); - } - for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) { - String columnName = this.conf.schema.clusteringKeys().get(i).name(); - stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); + private Message getMessage(UdtValue udt) { + String content = udt.getString(this.conf.messageUdtContentColumn); + Map props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn)); + switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) { + case ASSISTANT: + return new AssistantMessage(content, props); + case USER: + return UserMessage.builder().text(content).metadata(props).build(); + case SYSTEM: + return SystemMessage.builder().text(content).metadata(props).build(); + case TOOL: + // todo – persist ToolResponse somehow + return new ToolResponseMessage(List.of(), props); + default: + throw new IllegalStateException( + String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn))); } - return this.conf.session.prepare(stmt.build()); - } - - private PreparedStatement getStatement(Message msg) { - return switch (msg.getMessageType()) { - case USER -> this.addUserStmt; - case ASSISTANT -> this.addAssistantStmt; - default -> throw new IllegalArgumentException("Cant add type " + msg); - }; } } diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryConfig.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryConfig.java index 3d232726613..90ba1d2e8a1 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryConfig.java +++ 
b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryConfig.java @@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; import com.datastax.oss.driver.api.core.type.DataType; import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.UserDefinedType; import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; import com.datastax.oss.driver.api.core.type.reflect.GenericType; import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; @@ -61,9 +62,7 @@ public final class CassandraChatMemoryRepositoryConfig { // todo – make configurable public static final String DEFAULT_EXCHANGE_ID_NAME = "message_timestamp"; - public static final String DEFAULT_ASSISTANT_COLUMN_NAME = "assistant"; - - public static final String DEFAULT_USER_COLUMN_NAME = "user"; + public static final String DEFAULT_MESSAGES_COLUMN_NAME = "messages"; private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryConfig.class); @@ -71,9 +70,18 @@ public final class CassandraChatMemoryRepositoryConfig { final Schema schema; - final String assistantColumn; + final String messageUDT = "ai_chat_message"; + + final String messagesColumn; + + // todo – make configurable + final String messageUdtTimestampColumn = "msg_timestamp"; + + // todo – make configurable + final String messageUdtTypeColumn = "msg_type"; - final String userColumn; + // todo – make configurable + final String messageUdtContentColumn = "msg_content"; final SessionIdToPrimaryKeysTranslator primaryKeyTranslator; @@ -84,8 +92,7 @@ public final class CassandraChatMemoryRepositoryConfig { private CassandraChatMemoryRepositoryConfig(Builder builder) { this.session = builder.session; this.schema = new Schema(builder.keyspace, builder.table, builder.partitionKeys, builder.clusteringKeys); - this.assistantColumn = builder.assistantColumn; - 
this.userColumn = builder.userColumn; + this.messagesColumn = builder.messagesColumn; this.timeToLiveSeconds = builder.timeToLiveSeconds; this.disallowSchemaChanges = builder.disallowSchemaChanges; this.primaryKeyTranslator = builder.primaryKeyTranslator; @@ -109,6 +116,7 @@ void dropKeyspace() { void ensureSchemaExists() { if (!this.disallowSchemaChanges) { SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); + ensureMessageTypeExist(); ensureTableExists(); ensureTableColumnsExist(); SchemaUtil.checkSchemaAgreement(this.session); @@ -129,17 +137,35 @@ void checkSchemaValid() { .getTable(this.schema.table) .isPresent(), "table %s does not exist"); + Preconditions.checkState(this.session.getMetadata() + .getKeyspace(this.schema.keyspace()) + .get() + .getUserDefinedType(messageUDT) + .isPresent(), "table %s does not exist"); + + UserDefinedType udt = this.session.getMetadata() + .getKeyspace(this.schema.keyspace()) + .get() + .getUserDefinedType(messageUDT) + .get(); + + Preconditions.checkState(udt.contains(this.messageUdtTimestampColumn), "field %s does not exist", + this.messageUdtTimestampColumn); + + Preconditions.checkState(udt.contains(this.messageUdtTypeColumn), "field %s does not exist", + this.messageUdtTypeColumn); + + Preconditions.checkState(udt.contains(this.messageUdtContentColumn), "field %s does not exist", + this.messageUdtContentColumn); + TableMetadata tableMetadata = this.session.getMetadata() .getKeyspace(this.schema.keyspace) .get() .getTable(this.schema.table) .get(); - Preconditions.checkState(tableMetadata.getColumn(this.assistantColumn).isPresent(), "column %s does not exist", - this.assistantColumn); - - Preconditions.checkState(tableMetadata.getColumn(this.userColumn).isPresent(), "column %s does not exist", - this.userColumn); + Preconditions.checkState(tableMetadata.getColumn(this.messagesColumn).isPresent(), "column %s does not exist", + this.messagesColumn); } private void ensureTableExists() { @@ -159,9 +185,11 @@ 
private void ensureTableExists() { String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name(); - CreateTableWithOptions createTableWithOptions = createTable.withColumn(this.userColumn, DataTypes.TEXT) + CreateTableWithOptions createTableWithOptions = createTable + .withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true))) .withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC) - // TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() is available + // TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() when + // available .withOption("compaction", Map.of("class", "UnifiedCompactionStrategy")); if (null != this.timeToLiveSeconds) { @@ -171,6 +199,18 @@ private void ensureTableExists() { } } + private void ensureMessageTypeExist() { + + SimpleStatement stmt = SchemaBuilder.createType(messageUDT) + .ifNotExists() + .withField(messageUdtTimestampColumn, DataTypes.TIMESTAMP) + .withField(messageUdtTypeColumn, DataTypes.TEXT) + .withField(messageUdtContentColumn, DataTypes.TEXT) + .build(); + + this.session.execute(stmt.setKeyspace(this.schema.keyspace)); + } + private void ensureTableColumnsExist() { TableMetadata tableMetadata = this.session.getMetadata() @@ -179,18 +219,12 @@ private void ensureTableColumnsExist() { .getTable(this.schema.table()) .get(); - boolean addAssistantColumn = tableMetadata.getColumn(this.assistantColumn).isEmpty(); - boolean addUserColumn = tableMetadata.getColumn(this.userColumn).isEmpty(); + if (tableMetadata.getColumn(this.messagesColumn).isEmpty()) { + + SimpleStatement stmt = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table()) + .addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true))) + .build(); - if (addAssistantColumn || addUserColumn) { - AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table()); - if (addAssistantColumn) { - 
alterTable = alterTable.addColumn(this.assistantColumn, DataTypes.TEXT); - } - if (addUserColumn) { - alterTable = alterTable.addColumn(this.userColumn, DataTypes.TEXT); - } - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); logger.debug("Executing {}", stmt.getQuery()); this.session.execute(stmt); } @@ -228,9 +262,7 @@ public static final class Builder { private List clusteringKeys = List .of(new SchemaColumn(DEFAULT_EXCHANGE_ID_NAME, DataTypes.TIMESTAMP)); - private String assistantColumn = DEFAULT_ASSISTANT_COLUMN_NAME; - - private String userColumn = DEFAULT_USER_COLUMN_NAME; + private String messagesColumn = DEFAULT_MESSAGES_COLUMN_NAME; private Integer timeToLiveSeconds = null; @@ -289,13 +321,8 @@ public Builder withClusteringKeys(List clusteringKeys) { return this; } - public Builder withAssistantColumnName(String name) { - this.assistantColumn = name; - return this; - } - - public Builder withUserColumnName(String name) { - this.userColumn = name; + public Builder withMessagesColumnName(String name) { + this.messagesColumn = name; return this; } diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java index 0bf52898553..14b259d317f 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java @@ -23,6 +23,7 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.data.UdtValue; import 
org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -32,6 +33,7 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.chat.memory.ChatMemoryRepository; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -43,7 +45,6 @@ import org.springframework.context.annotation.Bean; import static org.assertj.core.api.Assertions.assertThat; -import org.springframework.ai.chat.memory.ChatMemoryRepository; /** * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryRepositoryIT` @@ -84,30 +85,26 @@ void add_shouldInsertSingleMessage(String content, MessageType messageType) { }; chatMemory.saveAll(sessionId, List.of(message)); + assertThat(chatMemory.findConversationIds()).isNotEmpty(); var cqlSession = context.getBean(CqlSession.class); + var query = """ - SELECT session_id, message_timestamp, a, u + SELECT session_id, message_timestamp, msgs FROM test_springframework.ai_chat_memory WHERE session_id = ? 
"""; - ResultSet resultSet = cqlSession.execute(query, sessionId); - var rows = resultSet.all(); - assertThat(rows.size()).isEqualTo(1); + var result = cqlSession.execute(query, sessionId).one(); - var firstRow = rows.get(0); + assertThat(result.getString("session_id")).isNotNull(); + assertThat(result.getString("session_id")).isEqualTo(sessionId); + assertThat(result.getInstant("message_timestamp")).isNotNull(); + List msgUdts = result.getList("msgs", UdtValue.class); + assertThat(msgUdts.size()).isEqualTo(1); - assertThat(firstRow.getString("session_id")).isEqualTo(sessionId); - assertThat(firstRow.getInstant("message_timestamp")).isNotNull(); - if (messageType == MessageType.ASSISTANT) { - assertThat(firstRow.getString("a")).isEqualTo(content); - assertThat(firstRow.getString("u")).isNull(); - } - else if (messageType == MessageType.USER) { - assertThat(firstRow.getString("a")).isNull(); - assertThat(firstRow.getString("u")).isEqualTo(content); - } + assertThat(msgUdts.get(0).getString("msg_type")).isEqualTo(messageType.name()); + assertThat(msgUdts.get(0).getString("msg_content")).isEqualTo(content); }); } @@ -121,35 +118,31 @@ void add_shouldInsertMessages() { new UserMessage("Message from user")); chatMemory.saveAll(sessionId, messages); + assertThat(chatMemory.findConversationIds()).isNotEmpty(); var cqlSession = context.getBean(CqlSession.class); + var query = """ - SELECT session_id, message_timestamp, a, u + SELECT session_id, message_timestamp, msgs FROM test_springframework.ai_chat_memory WHERE session_id = ? 
- ORDER BY message_timestamp ASC """; - ResultSet resultSet = cqlSession.execute(query, sessionId); - var rows = resultSet.all(); - assertThat(rows.size()).isEqualTo(messages.size()); + var result = cqlSession.execute(query, sessionId).one(); - for (var i = 0; i < messages.size(); i++) { - var message = messages.get(i); - var result = rows.get(i); - - assertThat(result.getString("session_id")).isNotNull(); - assertThat(result.getString("session_id")).isEqualTo(sessionId); - if (message.getMessageType() == MessageType.ASSISTANT) { - assertThat(result.getString("a")).isEqualTo(message.getText()); - assertThat(result.getString("u")).isNull(); - } - else if (message.getMessageType() == MessageType.USER) { - assertThat(result.getString("a")).isNull(); - assertThat(result.getString("u")).isEqualTo(message.getText()); - } - assertThat(result.getInstant("message_timestamp")).isNotNull(); - } + assertThat(result.getString("session_id")).isNotNull(); + assertThat(result.getString("session_id")).isEqualTo(sessionId); + assertThat(result.getInstant("message_timestamp")).isNotNull(); + List msgUdts = result.getList("msgs", UdtValue.class); + assertThat(msgUdts.size()).isEqualTo(2); + + assertThat(msgUdts.get(0).getInstant("msg_timestamp").toEpochMilli()) + .isLessThanOrEqualTo(msgUdts.get(1).getInstant("msg_timestamp").toEpochMilli()); + + assertThat(msgUdts.get(0).getString("msg_type")).isEqualTo(MessageType.ASSISTANT.name()); + assertThat(msgUdts.get(0).getString("msg_content")).isEqualTo("Message from assistant"); + assertThat(msgUdts.get(1).getString("msg_type")).isEqualTo(MessageType.USER.name()); + assertThat(msgUdts.get(1).getString("msg_content")).isEqualTo("Message from user"); }); } @@ -159,16 +152,15 @@ void get_shouldReturnMessages() { var chatMemory = context.getBean(ChatMemoryRepository.class); assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from 
assistant 1 - " + sessionId), new AssistantMessage("Message from assistant 2 - " + sessionId), new UserMessage("Message from user - " + sessionId)); chatMemory.saveAll(sessionId, messages); - assertThat(chatMemory.findConversationIds()).isNotEmpty(); var results = chatMemory.findByConversationId(sessionId); - assertThat(results.size()).isEqualTo(messages.size()); for (var i = 0; i < messages.size(); i++) { @@ -191,11 +183,9 @@ void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId); chatMemory.saveAll(sessionId, List.of(userMessage, assistantMessage)); - assertThat(chatMemory.findConversationIds()).isNotEmpty(); var results = chatMemory.findByConversationId(sessionId); - assertThat(results.size()).isEqualTo(2); var messages = List.of(userMessage, assistantMessage); @@ -215,23 +205,28 @@ void clear_shouldDeleteMessages() { var chatMemory = context.getBean(ChatMemoryRepository.class); assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + sessionId), new UserMessage("Message from user - " + sessionId)); chatMemory.saveAll(sessionId, messages); - assertThat(chatMemory.findConversationIds()).isNotEmpty(); chatMemory.deleteByConversationId(sessionId); + var results = chatMemory.findByConversationId(sessionId); + + assertThat(results.size()).isEqualTo(0); var cqlSession = context.getBean(CqlSession.class); + var query = """ - SELECT COUNT(*) + SELECT msgs FROM test_springframework.ai_chat_memory WHERE session_id = ? 
"""; + ResultSet resultSet = cqlSession.execute(query, sessionId); - var count = resultSet.all().get(0).getLong(0); + var count = resultSet.all().get(0).getList("msgs", UdtValue.class).size(); assertThat(count).isZero(); }); @@ -247,8 +242,7 @@ public CassandraChatMemoryRepository memory(CqlSession cqlSession) { var conf = CassandraChatMemoryRepositoryConfig.builder() .withCqlSession(cqlSession) .withKeyspaceName("test_" + CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME) - .withAssistantColumnName("a") - .withUserColumnName("u") + .withMessagesColumnName("msgs") .withTimeToLive(Duration.ofMinutes(1)) .build(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc index 893f350cf1a..dc2444034d3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc @@ -175,9 +175,11 @@ ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder() === CassandraChatMemoryRepository -`CassandraChatMemoryRepository` uses Apache Cassandra to store messages. It is suitable for applications that require persistent storage of chat memory, especially for availability, at scale, or when taking advantage of time-to-live (TTL) messages. +`CassandraChatMemoryRepository` uses Apache Cassandra to store messages. It is suitable for applications that require persistent storage of chat memory, especially for availability, durability, scale, and when taking advantage of time-to-live (TTL) feature. -First, add the following dependency to your project: +`CassandraChatMemoryRepository` has a time-series schema, keeping record of all past chat windows, valuable for governance and auditing. Setting time-to-live to some value, for example three years, is recommended. 
+ +To use `CassandraChatMemoryRepository`, first add the following dependency to your project: [tabs] ====== @@ -235,10 +237,10 @@ ChatMemory chatMemory = MessageWindowChatMemory.builder() | `spring.cassandra.contactPoints` | Host(s) to initiate cluster discovery | `127.0.0.1` | `spring.cassandra.port` | Cassandra native protocol port to connect to | `9042` | `spring.cassandra.localDatacenter` | Cassandra datacenter to connect to | `datacenter1` -| `spring.ai.chat.memory.repository.cassandra.time-to-live` | Time to live (TTL) for messages written in Cassandra | -| `spring.ai.chat.memory.repository.cassandra.keyspace` | Cassandra keyspace | `springframework` -| `spring.ai.chat.memory.repository.cassandra.table` | Cassandra table | `ai_chat_memory` -| `spring.ai.chat.memory.repository.cassandra.initialize-schema` | Whether to initialize the schema on startup. | `true` +| `spring.ai.chat.memory.cassandra.time-to-live` | Time to live (TTL) for messages written in Cassandra | +| `spring.ai.chat.memory.cassandra.keyspace` | Cassandra keyspace | `springframework` +| `spring.ai.chat.memory.cassandra.table` | Cassandra table | `ai_chat_memory` +| `spring.ai.chat.memory.cassandra.initialize-schema` | Whether to initialize the schema on startup. | `true` |=== ==== Schema Initialization