Skip to content

Commit 2450ae5

Browse files
michaelsembwevermarkpollack
authored andcommitted
Implement CassandraChatMemoryRepository
ref: #2998 Signed-off-by: mck <[email protected]>
1 parent 20b560f commit 2450ae5

File tree

13 files changed

+421
-197
lines changed

13 files changed

+421
-197
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
</parent>
1212
<artifactId>spring-ai-autoconfigure-model-chat-memory-cassandra</artifactId>
1313
<packaging>jar</packaging>
14-
<name>Spring AI Cassandra Chat Memory Auto Configuration</name>
15-
<description>Spring AI Cassandra Chat Memory Auto Configuration</description>
14+
<name>Spring AI Apache Cassandra Chat Memory Auto Configuration</name>
15+
<description>Spring AI Apache Cassandra Chat Memory Auto Configuration</description>
1616
<url>https://github.com/spring-projects/spring-ai</url>
1717

1818
<scm>
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.model.chat.memory.cassandra.autoconfigure;
17+
package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfigure;
1818

1919
import com.datastax.oss.driver.api.core.CqlSession;
2020

21-
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory;
2221
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig;
22+
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository;
2323
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
2424
import org.springframework.boot.autoconfigure.AutoConfiguration;
2525
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
@@ -36,13 +36,13 @@
3636
* @since 1.0.0
3737
*/
3838
@AutoConfiguration(after = CassandraAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class)
39-
@ConditionalOnClass({ CassandraChatMemory.class, CqlSession.class })
39+
@ConditionalOnClass({ CassandraChatMemoryRepository.class, CqlSession.class })
4040
@EnableConfigurationProperties(CassandraChatMemoryProperties.class)
4141
public class CassandraChatMemoryAutoConfiguration {
4242

4343
@Bean
4444
@ConditionalOnMissingBean
45-
public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) {
45+
public CassandraChatMemoryRepository chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) {
4646

4747
var builder = CassandraChatMemoryConfig.builder().withCqlSession(cqlSession);
4848

@@ -58,7 +58,7 @@ public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties,
5858
builder = builder.withTimeToLive(properties.getTimeToLive());
5959
}
6060

61-
return CassandraChatMemory.create(builder.build());
61+
return CassandraChatMemoryRepository.create(builder.build());
6262
}
6363

6464
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.model.chat.memory.cassandra.autoconfigure;
17+
package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfigure;
1818

1919
import java.time.Duration;
2020

@@ -35,7 +35,7 @@
3535
@ConfigurationProperties(CassandraChatMemoryProperties.CONFIG_PREFIX)
3636
public class CassandraChatMemoryProperties {
3737

38-
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra";
38+
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra";
3939

4040
private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryProperties.class);
4141

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
org.springframework.ai.model.chat.memory.cassandra.autoconfigure.CassandraChatMemoryAutoConfiguration
16+
org.springframework.ai.model.chat.memory.repository.cassandra.autoconfigure.CassandraChatMemoryAutoConfiguration
Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.model.chat.memory.cassandra.autoconfigure;
17+
package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfigure;
1818

1919
import java.time.Duration;
2020
import java.util.List;
@@ -26,7 +26,7 @@
2626
import org.testcontainers.junit.jupiter.Testcontainers;
2727
import org.testcontainers.utility.DockerImageName;
2828

29-
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory;
29+
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository;
3030
import org.springframework.ai.chat.messages.AssistantMessage;
3131
import org.springframework.ai.chat.messages.MessageType;
3232
import org.springframework.ai.chat.messages.UserMessage;
@@ -61,30 +61,29 @@ void addAndGet() {
6161
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
6262
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive())
6363
.run(context -> {
64-
CassandraChatMemory memory = context.getBean(CassandraChatMemory.class);
64+
CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class);
6565

6666
String sessionId = UUIDs.timeBased().toString();
67-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
67+
assertThat(memory.findByConversationId(sessionId)).isEmpty();
6868

69-
memory.add(sessionId, new UserMessage("test question"));
69+
memory.saveAll(sessionId, List.of(new UserMessage("test question")));
7070

71-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(1);
72-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
73-
.isEqualTo(MessageType.USER);
74-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question");
71+
assertThat(memory.findByConversationId(sessionId)).hasSize(1);
72+
assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER);
73+
assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question");
7574

76-
memory.clear(sessionId);
77-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
75+
memory.deleteByConversationId(sessionId);
76+
assertThat(memory.findByConversationId(sessionId)).isEmpty();
7877

79-
memory.add(sessionId, List.of(new UserMessage("test question"), new AssistantMessage("test answer")));
78+
memory.saveAll(sessionId,
79+
List.of(new UserMessage("test question"), new AssistantMessage("test answer")));
8080

81-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(2);
82-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getMessageType())
81+
assertThat(memory.findByConversationId(sessionId)).hasSize(2);
82+
assertThat(memory.findByConversationId(sessionId).get(1).getMessageType())
8383
.isEqualTo(MessageType.ASSISTANT);
84-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getText()).isEqualTo("test answer");
85-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
86-
.isEqualTo(MessageType.USER);
87-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question");
84+
assertThat(memory.findByConversationId(sessionId).get(1).getText()).isEqualTo("test answer");
85+
assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER);
86+
assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question");
8887

8988
CassandraChatMemoryProperties properties = context.getBean(CassandraChatMemoryProperties.class);
9089
assertThat(properties.getTimeToLive()).isEqualTo(getTimeToLive());
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.model.chat.memory.cassandra.autoconfigure;
17+
package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfigure;
1818

1919
import java.time.Duration;
2020

auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-cassandra/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
</parent>
2828
<artifactId>spring-ai-autoconfigure-vector-store-cassandra</artifactId>
2929
<packaging>jar</packaging>
30-
<name>Spring AI Auto Configuration for Cassandra vector store</name>
31-
<description>Spring AI Auto Configuration for Cassandra vector store</description>
30+
<name>Spring AI Auto Configuration for Apache Cassandra vector store</name>
31+
<description>Spring AI Auto Configuration for Apache Cassandra vector store</description>
3232
<url>https://github.com/spring-projects/spring-ai</url>
3333

3434
<scm>

memory/spring-ai-model-chat-memory-cassandra/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
</parent>
2828

2929
<artifactId>spring-ai-model-chat-memory-cassandra</artifactId>
30-
<name>Spring AI Cassandra Chat Memory</name>
31-
<description>Spring AI Cassandra Chat Memory implementation</description>
30+
<name>Spring AI Apache Cassandra Chat Memory</name>
31+
<description>Spring AI Apache Cassandra Chat Memory implementation</description>
3232

3333
<url>https://github.com/spring-projects/spring-ai</url>
3434

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java

Lines changed: 9 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -16,60 +16,32 @@
1616

1717
package org.springframework.ai.chat.memory.cassandra;
1818

19-
import java.time.Instant;
20-
import java.util.ArrayList;
21-
import java.util.Collections;
2219
import java.util.List;
23-
import java.util.concurrent.atomic.AtomicLong;
24-
25-
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
26-
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
27-
import com.datastax.oss.driver.api.core.cql.Row;
28-
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
29-
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
30-
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
31-
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
32-
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
33-
import com.datastax.oss.driver.api.querybuilder.select.Select;
34-
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
3520

3621
import org.springframework.ai.chat.memory.ChatMemory;
37-
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig.SchemaColumn;
38-
import org.springframework.ai.chat.messages.AssistantMessage;
3922
import org.springframework.ai.chat.messages.Message;
40-
import org.springframework.ai.chat.messages.UserMessage;
4123

4224
/**
25+
* @deprecated Use CassandraChatMemoryRepository
26+
*
4327
* Create a CassandraChatMemory like <code>
4428
CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
4529
</code>
4630
*
4731
* For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory
48-
*
4932
* @author Mick Semb Wever
5033
* @since 1.0.0
5134
*/
35+
@Deprecated
5236
public final class CassandraChatMemory implements ChatMemory {
5337

54-
public static final String CONVERSATION_TS = CassandraChatMemory.class.getSimpleName() + "_message_timestamp";
55-
5638
final CassandraChatMemoryConfig conf;
5739

58-
private final PreparedStatement addUserStmt;
59-
60-
private final PreparedStatement addAssistantStmt;
61-
62-
private final PreparedStatement getStmt;
63-
64-
private final PreparedStatement deleteStmt;
40+
final CassandraChatMemoryRepository repo;
6541

6642
public CassandraChatMemory(CassandraChatMemoryConfig config) {
6743
this.conf = config;
68-
this.conf.ensureSchemaExists();
69-
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
70-
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
71-
this.getStmt = prepareGetStatement();
72-
this.deleteStmt = prepareDeleteStmt();
44+
repo = CassandraChatMemoryRepository.create(conf);
7345
}
7446

7547
public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
@@ -78,128 +50,22 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
7850

7951
@Override
8052
public void add(String conversationId, List<Message> messages) {
81-
final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
82-
messages.forEach(msg -> {
83-
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
84-
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
85-
}
86-
add(conversationId, msg);
87-
});
53+
repo.saveAll(conversationId, messages);
8854
}
8955

9056
@Override
9157
public void add(String sessionId, Message msg) {
92-
93-
Preconditions.checkArgument(
94-
!msg.getMetadata().containsKey(CONVERSATION_TS)
95-
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
96-
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
97-
98-
msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());
99-
100-
PreparedStatement stmt = getStatement(msg);
101-
102-
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
103-
BoundStatementBuilder builder = stmt.boundStatementBuilder();
104-
105-
for (int k = 0; k < primaryKeys.size(); ++k) {
106-
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
107-
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
108-
}
109-
110-
Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);
111-
112-
builder = builder.setInstant(CassandraChatMemoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
113-
.setString("message", msg.getText());
114-
115-
this.conf.session.execute(builder.build());
116-
}
117-
118-
PreparedStatement getStatement(Message msg) {
119-
return switch (msg.getMessageType()) {
120-
case USER -> this.addUserStmt;
121-
case ASSISTANT -> this.addAssistantStmt;
122-
default -> throw new IllegalArgumentException("Cant add type " + msg);
123-
};
58+
repo.save(sessionId, msg);
12459
}
12560

12661
@Override
12762
public void clear(String sessionId) {
128-
129-
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
130-
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();
131-
132-
for (int k = 0; k < primaryKeys.size(); ++k) {
133-
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
134-
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
135-
}
136-
137-
this.conf.session.execute(builder.build());
63+
repo.deleteByConversationId(sessionId);
13864
}
13965

14066
@Override
14167
public List<Message> get(String sessionId, int lastN) {
142-
143-
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
144-
BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", lastN);
145-
146-
for (int k = 0; k < primaryKeys.size(); ++k) {
147-
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
148-
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
149-
}
150-
151-
List<Message> messages = new ArrayList<>();
152-
for (Row r : this.conf.session.execute(builder.build())) {
153-
String assistant = r.getString(this.conf.assistantColumn);
154-
String user = r.getString(this.conf.userColumn);
155-
if (null != assistant) {
156-
messages.add(new AssistantMessage(assistant));
157-
}
158-
if (null != user) {
159-
messages.add(new UserMessage(user));
160-
}
161-
}
162-
Collections.reverse(messages);
163-
return messages;
164-
}
165-
166-
private PreparedStatement prepareAddStmt(String column) {
167-
RegularInsert stmt = null;
168-
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
169-
for (var c : this.conf.schema.partitionKeys()) {
170-
stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name()));
171-
}
172-
for (var c : this.conf.schema.clusteringKeys()) {
173-
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
174-
}
175-
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
176-
return this.conf.session.prepare(stmt.build());
177-
}
178-
179-
private PreparedStatement prepareGetStatement() {
180-
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()).all();
181-
for (var c : this.conf.schema.partitionKeys()) {
182-
stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
183-
}
184-
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
185-
String columnName = this.conf.schema.clusteringKeys().get(i).name();
186-
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
187-
}
188-
stmt = stmt.limit(QueryBuilder.bindMarker("lastN"));
189-
return this.conf.session.prepare(stmt.build());
190-
}
191-
192-
private PreparedStatement prepareDeleteStmt() {
193-
Delete stmt = null;
194-
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
195-
for (var c : this.conf.schema.partitionKeys()) {
196-
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
197-
}
198-
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
199-
String columnName = this.conf.schema.clusteringKeys().get(i).name();
200-
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
201-
}
202-
return this.conf.session.prepare(stmt.build());
68+
return repo.findByConversationId(sessionId).subList(0, lastN);
20369
}
20470

20571
}

0 commit comments

Comments
 (0)