Skip to content

Optimise CassandraChatMemoryRepository for MessageWindowChatMemory usage pattern #3097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public CassandraChatMemoryRepository cassandraChatMemoryRepository(

builder = builder.withKeyspaceName(properties.getKeyspace())
.withTableName(properties.getTable())
.withAssistantColumnName(properties.getAssistantColumn())
.withUserColumnName(properties.getUserColumn());
.withMessagesColumnName(properties.getMessagesColumn());

if (!properties.isInitializeSchema()) {
builder = builder.disallowSchemaChanges();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

import java.time.Duration;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepositoryConfig;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.lang.Nullable;
Expand All @@ -35,17 +32,13 @@
@ConfigurationProperties(CassandraChatMemoryRepositoryProperties.CONFIG_PREFIX)
public class CassandraChatMemoryRepositoryProperties {

public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra";

private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryProperties.class);
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra";

private String keyspace = CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME;

private String table = CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME;

private String assistantColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME;

private String userColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME;
private String messagesColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME;

private boolean initializeSchema = true;

Expand Down Expand Up @@ -75,20 +68,12 @@ public void setTable(String table) {
this.table = table;
}

public String getAssistantColumn() {
return this.assistantColumn;
}

public void setAssistantColumn(String assistantColumn) {
this.assistantColumn = assistantColumn;
}

public String getUserColumn() {
return this.userColumn;
public String getMessagesColumn() {
return this.messagesColumn;
}

public void setUserColumn(String userColumn) {
this.userColumn = userColumn;
public void setMessagesColumn(String messagesColumn) {
this.messagesColumn = messagesColumn;
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void addAndGet() {
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLive())
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive())
.run(context -> {
CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class);

Expand Down Expand Up @@ -96,7 +96,7 @@ void compareTimeToLive_ISO8601Format() {
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLiveString())
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLiveString())
.run(context -> {
CassandraChatMemoryRepositoryProperties properties = context
.getBean(CassandraChatMemoryRepositoryProperties.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ void defaultValues() {
var props = new CassandraChatMemoryRepositoryProperties();
assertThat(props.getKeyspace()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME);
assertThat(props.getTable()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME);
assertThat(props.getAssistantColumn())
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME);
assertThat(props.getUserColumn()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME);
assertThat(props.getMessagesColumn())
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME);

assertThat(props.getTimeToLive()).isNull();
assertThat(props.isInitializeSchema()).isTrue();
}
Expand All @@ -48,15 +48,13 @@ void customValues() {
var props = new CassandraChatMemoryRepositoryProperties();
props.setKeyspace("my_keyspace");
props.setTable("my_table");
props.setAssistantColumn("my_assistant_column");
props.setUserColumn("my_user_column");
props.setMessagesColumn("my_messages_column");
props.setTimeToLive(Duration.ofDays(1));
props.setInitializeSchema(false);

assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
assertThat(props.getTable()).isEqualTo("my_table");
assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column");
assertThat(props.getUserColumn()).isEqualTo("my_user_column");
assertThat(props.getMessagesColumn()).isEqualTo("my_messages_column");
assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1));
assertThat(props.isInitializeSchema()).isFalse();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,26 @@

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.Map;

import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;

import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.util.Assert;

Expand All @@ -54,23 +56,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {

private final PreparedStatement allStmt;

private final PreparedStatement addUserStmt;

private final PreparedStatement addAssistantStmt;
private final PreparedStatement addStmt;

private final PreparedStatement getStmt;

private final PreparedStatement deleteStmt;

private CassandraChatMemoryRepository(CassandraChatMemoryRepositoryConfig conf) {
Assert.notNull(conf, "conf cannot be null");
this.conf = conf;
this.conf.ensureSchemaExists();
this.allStmt = prepareAllStatement();
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
this.addStmt = prepareAddStmt();
this.getStmt = prepareGetStatement();
this.deleteStmt = prepareDeleteStmt();
}

public static CassandraChatMemoryRepository create(CassandraChatMemoryRepositoryConfig conf) {
Expand All @@ -97,6 +93,10 @@ public List<String> findConversationIds() {

@Override
public List<Message> findByConversationId(String conversationId) {
return findByConversationIdWithLimit(conversationId, 1);
}

List<Message> findByConversationIdWithLimit(String conversationId, int limit) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");

List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
Expand All @@ -106,19 +106,14 @@ public List<Message> findByConversationId(String conversationId) {
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}
builder = builder.setInt("legacy_limit", limit);

List<Message> messages = new ArrayList<>();
for (Row r : this.conf.session.execute(builder.build())) {
String assistant = r.getString(this.conf.assistantColumn);
String user = r.getString(this.conf.userColumn);
if (null != assistant) {
messages.add(new AssistantMessage(assistant));
}
if (null != user) {
messages.add(new UserMessage(user));
for (UdtValue udt : r.getList(this.conf.messagesColumn, UdtValue.class)) {
messages.add(getMessage(udt));
}
}
Collections.reverse(messages);
return messages;
}

Expand All @@ -128,58 +123,49 @@ public void saveAll(String conversationId, List<Message> messages) {
Assert.notNull(messages, "messages cannot be null");
Assert.noNullElements(messages, "messages cannot contain null elements");

final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
messages.forEach(msg -> {
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
}
save(conversationId, msg);
});
}

void save(String conversationId, Message msg) {

Preconditions.checkArgument(
!msg.getMetadata().containsKey(CONVERSATION_TS)
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);

msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());

PreparedStatement stmt = getStatement(msg);

Instant instant = Instant.now();
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
BoundStatementBuilder builder = stmt.boundStatementBuilder();
BoundStatementBuilder builder = addStmt.boundStatementBuilder();

for (int k = 0; k < primaryKeys.size(); ++k) {
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}

Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);
List<UdtValue> msgs = new ArrayList<>();
for (Message msg : messages) {

Preconditions.checkArgument(
!msg.getMetadata().containsKey(CONVERSATION_TS)
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);

msg.getMetadata().putIfAbsent(CONVERSATION_TS, instant);

UdtValue udt = this.conf.session.getMetadata()
.getKeyspace(this.conf.schema.keyspace())
.get()
.getUserDefinedType(this.conf.messageUDT)
.get()
.newValue()
.setInstant(this.conf.messageUdtTimestampColumn, (Instant) msg.getMetadata().get(CONVERSATION_TS))
.setString(this.conf.messageUdtTypeColumn, msg.getMessageType().name())
.setString(this.conf.messageUdtContentColumn, msg.getText());

msgs.add(udt);
}
builder = builder.setInstant(CassandraChatMemoryRepositoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
.setString("message", msg.getText());
.setList("msgs", msgs, UdtValue.class);

this.conf.session.execute(builder.build());
}

@Override
public void deleteByConversationId(String conversationId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");

List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();

for (int k = 0; k < primaryKeys.size(); ++k) {
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}

this.conf.session.execute(builder.build());
saveAll(conversationId, List.of());
}

private PreparedStatement prepareAddStmt(String column) {
private PreparedStatement prepareAddStmt() {
RegularInsert stmt = null;
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
for (var c : this.conf.schema.partitionKeys()) {
Expand All @@ -188,7 +174,7 @@ private PreparedStatement prepareAddStmt(String column) {
for (var c : this.conf.schema.clusteringKeys()) {
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
}
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
stmt = stmt.value(this.conf.messagesColumn, QueryBuilder.bindMarker("msgs"));
return this.conf.session.prepare(stmt.build());
}

Expand All @@ -214,28 +200,27 @@ private PreparedStatement prepareGetStatement() {
String columnName = this.conf.schema.clusteringKeys().get(i).name();
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
}
stmt = stmt.limit(QueryBuilder.bindMarker("legacy_limit"));
return this.conf.session.prepare(stmt.build());
}

private PreparedStatement prepareDeleteStmt() {
Delete stmt = null;
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
for (var c : this.conf.schema.partitionKeys()) {
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
}
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
String columnName = this.conf.schema.clusteringKeys().get(i).name();
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
private Message getMessage(UdtValue udt) {
String content = udt.getString(this.conf.messageUdtContentColumn);
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
Copy link
Contributor Author

@michaelsembwever michaelsembwever May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

@ZYMCao ZYMCao May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michaelsembwever
unix timestamp is great, but I am not sure about fraction part:

[
{
"messageType": "USER",
"metadata": {
"CassandraChatMemoryRepository_message_timestamp": 1747564565.737000000,
"messageType": "USER"
},
"media": [],
"text": "SSE development using springboot?"
},
{
"messageType": "ASSISTANT",
"metadata": {
"CassandraChatMemoryRepository_message_timestamp": 1747564594.962000000,
"messageType": "ASSISTANT"
},
"toolCalls": [],
"media": [],
"text": "..."
}
]

switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
case ASSISTANT:
return new AssistantMessage(content, props);
case USER:
return UserMessage.builder().text(content).metadata(props).build();
case SYSTEM:
return SystemMessage.builder().text(content).metadata(props).build();
case TOOL:
// todo – persist ToolResponse somehow
return new ToolResponseMessage(List.of(), props);
default:
throw new IllegalStateException(
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
}
return this.conf.session.prepare(stmt.build());
}

private PreparedStatement getStatement(Message msg) {
return switch (msg.getMessageType()) {
case USER -> this.addUserStmt;
case ASSISTANT -> this.addAssistantStmt;
default -> throw new IllegalArgumentException("Cant add type " + msg);
};
}

}
Loading