Skip to content

Commit 73e31d1

Browse files
Optimise CassandraChatMemoryRepository for MessageWindowChatMemory usage pattern
Time-series each chat window in Cassandra, keeping past (and deleted) windows still in the db. Add ability to store different MessageTypes. Signed-off-by: mck <[email protected]>
1 parent 30eb3ce commit 73e31d1

File tree

8 files changed

+183
-193
lines changed

8 files changed

+183
-193
lines changed

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ public CassandraChatMemoryRepository cassandraChatMemoryRepository(
4949

5050
builder = builder.withKeyspaceName(properties.getKeyspace())
5151
.withTableName(properties.getTable())
52-
.withAssistantColumnName(properties.getAssistantColumn())
53-
.withUserColumnName(properties.getUserColumn());
52+
.withMessagesColumnName(properties.getMessagesColumn());
5453

5554
if (!properties.isInitializeSchema()) {
5655
builder = builder.disallowSchemaChanges();

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
import java.time.Duration;
2020

21-
import org.slf4j.Logger;
22-
import org.slf4j.LoggerFactory;
23-
2421
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepositoryConfig;
2522
import org.springframework.boot.context.properties.ConfigurationProperties;
2623
import org.springframework.lang.Nullable;
@@ -35,17 +32,13 @@
3532
@ConfigurationProperties(CassandraChatMemoryRepositoryProperties.CONFIG_PREFIX)
3633
public class CassandraChatMemoryRepositoryProperties {
3734

38-
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra";
39-
40-
private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryProperties.class);
35+
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra";
4136

4237
private String keyspace = CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME;
4338

4439
private String table = CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME;
4540

46-
private String assistantColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME;
47-
48-
private String userColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME;
41+
private String messagesColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME;
4942

5043
private boolean initializeSchema = true;
5144

@@ -75,20 +68,12 @@ public void setTable(String table) {
7568
this.table = table;
7669
}
7770

78-
public String getAssistantColumn() {
79-
return this.assistantColumn;
80-
}
81-
82-
public void setAssistantColumn(String assistantColumn) {
83-
this.assistantColumn = assistantColumn;
84-
}
85-
86-
public String getUserColumn() {
87-
return this.userColumn;
71+
public String getMessagesColumn() {
72+
return this.messagesColumn;
8873
}
8974

90-
public void setUserColumn(String userColumn) {
91-
this.userColumn = userColumn;
75+
public void setMessagesColumn(String messagesColumn) {
76+
this.messagesColumn = messagesColumn;
9277
}
9378

9479
@Nullable

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void addAndGet() {
5959
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
6060
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
6161
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
62-
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLive())
62+
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive())
6363
.run(context -> {
6464
CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class);
6565

@@ -96,7 +96,7 @@ void compareTimeToLive_ISO8601Format() {
9696
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
9797
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
9898
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
99-
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLiveString())
99+
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLiveString())
100100
.run(context -> {
101101
CassandraChatMemoryRepositoryProperties properties = context
102102
.getBean(CassandraChatMemoryRepositoryProperties.class);

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ void defaultValues() {
3636
var props = new CassandraChatMemoryRepositoryProperties();
3737
assertThat(props.getKeyspace()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME);
3838
assertThat(props.getTable()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME);
39-
assertThat(props.getAssistantColumn())
40-
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME);
41-
assertThat(props.getUserColumn()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME);
39+
assertThat(props.getMessagesColumn())
40+
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME);
41+
4242
assertThat(props.getTimeToLive()).isNull();
4343
assertThat(props.isInitializeSchema()).isTrue();
4444
}
@@ -48,15 +48,13 @@ void customValues() {
4848
var props = new CassandraChatMemoryRepositoryProperties();
4949
props.setKeyspace("my_keyspace");
5050
props.setTable("my_table");
51-
props.setAssistantColumn("my_assistant_column");
52-
props.setUserColumn("my_user_column");
51+
props.setMessagesColumn("my_messages_column");
5352
props.setTimeToLive(Duration.ofDays(1));
5453
props.setInitializeSchema(false);
5554

5655
assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
5756
assertThat(props.getTable()).isEqualTo("my_table");
58-
assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column");
59-
assertThat(props.getUserColumn()).isEqualTo("my_user_column");
57+
assertThat(props.getMessagesColumn()).isEqualTo("my_messages_column");
6058
assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1));
6159
assertThat(props.isInitializeSchema()).isFalse();
6260
}

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@
1818

1919
import java.time.Instant;
2020
import java.util.ArrayList;
21-
import java.util.Collections;
2221
import java.util.List;
23-
import java.util.concurrent.atomic.AtomicLong;
22+
import java.util.Map;
2423

2524
import com.datastax.oss.driver.api.core.cql.BoundStatement;
2625
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
2726
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
2827
import com.datastax.oss.driver.api.core.cql.Row;
28+
import com.datastax.oss.driver.api.core.data.UdtValue;
2929
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
30-
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
31-
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
3230
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
3331
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
3432
import com.datastax.oss.driver.api.querybuilder.select.Select;
3533
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
34+
3635
import org.springframework.ai.chat.memory.ChatMemoryRepository;
3736
import org.springframework.ai.chat.messages.AssistantMessage;
3837
import org.springframework.ai.chat.messages.Message;
38+
import org.springframework.ai.chat.messages.MessageType;
39+
import org.springframework.ai.chat.messages.SystemMessage;
40+
import org.springframework.ai.chat.messages.ToolResponseMessage;
3941
import org.springframework.ai.chat.messages.UserMessage;
4042
import org.springframework.util.Assert;
4143

@@ -54,23 +56,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
5456

5557
private final PreparedStatement allStmt;
5658

57-
private final PreparedStatement addUserStmt;
58-
59-
private final PreparedStatement addAssistantStmt;
59+
private final PreparedStatement addStmt;
6060

6161
private final PreparedStatement getStmt;
6262

63-
private final PreparedStatement deleteStmt;
64-
6563
private CassandraChatMemoryRepository(CassandraChatMemoryRepositoryConfig conf) {
6664
Assert.notNull(conf, "conf cannot be null");
6765
this.conf = conf;
6866
this.conf.ensureSchemaExists();
6967
this.allStmt = prepareAllStatement();
70-
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
71-
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
68+
this.addStmt = prepareAddStmt();
7269
this.getStmt = prepareGetStatement();
73-
this.deleteStmt = prepareDeleteStmt();
7470
}
7571

7672
public static CassandraChatMemoryRepository create(CassandraChatMemoryRepositoryConfig conf) {
@@ -97,6 +93,10 @@ public List<String> findConversationIds() {
9793

9894
@Override
9995
public List<Message> findByConversationId(String conversationId) {
96+
return findByConversationIdWithLimit(conversationId, 1);
97+
}
98+
99+
List<Message> findByConversationIdWithLimit(String conversationId, int limit) {
100100
Assert.hasText(conversationId, "conversationId cannot be null or empty");
101101

102102
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
@@ -106,19 +106,14 @@ public List<Message> findByConversationId(String conversationId) {
106106
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
107107
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
108108
}
109+
builder = builder.setInt("legacy_limit", limit);
109110

110111
List<Message> messages = new ArrayList<>();
111112
for (Row r : this.conf.session.execute(builder.build())) {
112-
String assistant = r.getString(this.conf.assistantColumn);
113-
String user = r.getString(this.conf.userColumn);
114-
if (null != assistant) {
115-
messages.add(new AssistantMessage(assistant));
116-
}
117-
if (null != user) {
118-
messages.add(new UserMessage(user));
113+
for (UdtValue udt : r.getList(this.conf.messagesColumn, UdtValue.class)) {
114+
messages.add(getMessage(udt));
119115
}
120116
}
121-
Collections.reverse(messages);
122117
return messages;
123118
}
124119

@@ -128,58 +123,49 @@ public void saveAll(String conversationId, List<Message> messages) {
128123
Assert.notNull(messages, "messages cannot be null");
129124
Assert.noNullElements(messages, "messages cannot contain null elements");
130125

131-
final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
132-
messages.forEach(msg -> {
133-
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
134-
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
135-
}
136-
save(conversationId, msg);
137-
});
138-
}
139-
140-
void save(String conversationId, Message msg) {
141-
142-
Preconditions.checkArgument(
143-
!msg.getMetadata().containsKey(CONVERSATION_TS)
144-
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
145-
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
146-
147-
msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());
148-
149-
PreparedStatement stmt = getStatement(msg);
150-
126+
Instant instant = Instant.now();
151127
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
152-
BoundStatementBuilder builder = stmt.boundStatementBuilder();
128+
BoundStatementBuilder builder = addStmt.boundStatementBuilder();
153129

154130
for (int k = 0; k < primaryKeys.size(); ++k) {
155131
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
156132
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
157133
}
158134

159-
Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);
135+
List<UdtValue> msgs = new ArrayList<>();
136+
for (Message msg : messages) {
137+
138+
Preconditions.checkArgument(
139+
!msg.getMetadata().containsKey(CONVERSATION_TS)
140+
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
141+
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
160142

143+
msg.getMetadata().putIfAbsent(CONVERSATION_TS, instant);
144+
145+
UdtValue udt = this.conf.session.getMetadata()
146+
.getKeyspace(this.conf.schema.keyspace())
147+
.get()
148+
.getUserDefinedType(this.conf.messageUDT)
149+
.get()
150+
.newValue()
151+
.setInstant(this.conf.messageUdtTimestampColumn, (Instant) msg.getMetadata().get(CONVERSATION_TS))
152+
.setString(this.conf.messageUdtTypeColumn, msg.getMessageType().name())
153+
.setString(this.conf.messageUdtContentColumn, msg.getText());
154+
155+
msgs.add(udt);
156+
}
161157
builder = builder.setInstant(CassandraChatMemoryRepositoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
162-
.setString("message", msg.getText());
158+
.setList("msgs", msgs, UdtValue.class);
163159

164160
this.conf.session.execute(builder.build());
165161
}
166162

167163
@Override
168164
public void deleteByConversationId(String conversationId) {
169-
Assert.hasText(conversationId, "conversationId cannot be null or empty");
170-
171-
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
172-
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();
173-
174-
for (int k = 0; k < primaryKeys.size(); ++k) {
175-
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
176-
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
177-
}
178-
179-
this.conf.session.execute(builder.build());
165+
saveAll(conversationId, List.of());
180166
}
181167

182-
private PreparedStatement prepareAddStmt(String column) {
168+
private PreparedStatement prepareAddStmt() {
183169
RegularInsert stmt = null;
184170
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
185171
for (var c : this.conf.schema.partitionKeys()) {
@@ -188,7 +174,7 @@ private PreparedStatement prepareAddStmt(String column) {
188174
for (var c : this.conf.schema.clusteringKeys()) {
189175
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
190176
}
191-
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
177+
stmt = stmt.value(this.conf.messagesColumn, QueryBuilder.bindMarker("msgs"));
192178
return this.conf.session.prepare(stmt.build());
193179
}
194180

@@ -214,28 +200,27 @@ private PreparedStatement prepareGetStatement() {
214200
String columnName = this.conf.schema.clusteringKeys().get(i).name();
215201
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
216202
}
203+
stmt = stmt.limit(QueryBuilder.bindMarker("legacy_limit"));
217204
return this.conf.session.prepare(stmt.build());
218205
}
219206

220-
private PreparedStatement prepareDeleteStmt() {
221-
Delete stmt = null;
222-
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
223-
for (var c : this.conf.schema.partitionKeys()) {
224-
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
225-
}
226-
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
227-
String columnName = this.conf.schema.clusteringKeys().get(i).name();
228-
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
207+
private Message getMessage(UdtValue udt) {
208+
String content = udt.getString(this.conf.messageUdtContentColumn);
209+
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
210+
switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
211+
case ASSISTANT:
212+
return new AssistantMessage(content, props);
213+
case USER:
214+
return UserMessage.builder().text(content).metadata(props).build();
215+
case SYSTEM:
216+
return SystemMessage.builder().text(content).metadata(props).build();
217+
case TOOL:
218+
// todo – persist ToolResponse somehow
219+
return new ToolResponseMessage(List.of(), props);
220+
default:
221+
throw new IllegalStateException(
222+
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
229223
}
230-
return this.conf.session.prepare(stmt.build());
231-
}
232-
233-
private PreparedStatement getStatement(Message msg) {
234-
return switch (msg.getMessageType()) {
235-
case USER -> this.addUserStmt;
236-
case ASSISTANT -> this.addAssistantStmt;
237-
default -> throw new IllegalArgumentException("Cant add type " + msg);
238-
};
239224
}
240225

241226
}

0 commit comments

Comments
 (0)