16
16
17
17
package org .springframework .ai .chat .memory .cassandra ;
18
18
19
- import java .time .Instant ;
20
- import java .util .ArrayList ;
21
- import java .util .Collections ;
22
19
import java .util .List ;
23
- import java .util .concurrent .atomic .AtomicLong ;
24
-
25
- import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
26
- import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
27
- import com .datastax .oss .driver .api .core .cql .Row ;
28
- import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
29
- import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
30
- import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
31
- import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
32
- import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
33
- import com .datastax .oss .driver .api .querybuilder .select .Select ;
34
- import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
35
20
36
21
import org .springframework .ai .chat .memory .ChatMemory ;
37
- import org .springframework .ai .chat .memory .cassandra .CassandraChatMemoryConfig .SchemaColumn ;
38
- import org .springframework .ai .chat .messages .AssistantMessage ;
39
22
import org .springframework .ai .chat .messages .Message ;
40
- import org .springframework .ai .chat .messages .UserMessage ;
41
23
42
24
/**
25
+ * @deprecated Use CassandraChatMemoryRepository
26
+ *
43
27
* Create a CassandraChatMemory like <code>
44
28
CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
45
29
</code>
46
30
*
47
31
* For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory
48
- *
49
32
* @author Mick Semb Wever
50
33
* @since 1.0.0
51
34
*/
35
+ @ Deprecated
52
36
public final class CassandraChatMemory implements ChatMemory {
53
37
54
- public static final String CONVERSATION_TS = CassandraChatMemory .class .getSimpleName () + "_message_timestamp" ;
55
-
56
38
final CassandraChatMemoryConfig conf ;
57
39
58
- private final PreparedStatement addUserStmt ;
59
-
60
- private final PreparedStatement addAssistantStmt ;
61
-
62
- private final PreparedStatement getStmt ;
63
-
64
- private final PreparedStatement deleteStmt ;
40
+ final CassandraChatMemoryRepository repo ;
65
41
66
42
public CassandraChatMemory (CassandraChatMemoryConfig config ) {
67
43
this .conf = config ;
68
- this .conf .ensureSchemaExists ();
69
- this .addUserStmt = prepareAddStmt (this .conf .userColumn );
70
- this .addAssistantStmt = prepareAddStmt (this .conf .assistantColumn );
71
- this .getStmt = prepareGetStatement ();
72
- this .deleteStmt = prepareDeleteStmt ();
44
+ repo = CassandraChatMemoryRepository .create (conf );
73
45
}
74
46
75
47
public static CassandraChatMemory create (CassandraChatMemoryConfig conf ) {
@@ -78,128 +50,22 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
78
50
79
51
@ Override
80
52
public void add (String conversationId , List <Message > messages ) {
81
- final AtomicLong instantSeq = new AtomicLong (Instant .now ().toEpochMilli ());
82
- messages .forEach (msg -> {
83
- if (msg .getMetadata ().containsKey (CONVERSATION_TS )) {
84
- msg .getMetadata ().put (CONVERSATION_TS , Instant .ofEpochMilli (instantSeq .getAndIncrement ()));
85
- }
86
- add (conversationId , msg );
87
- });
53
+ repo .saveAll (conversationId , messages );
88
54
}
89
55
90
56
@ Override
91
57
public void add (String sessionId , Message msg ) {
92
-
93
- Preconditions .checkArgument (
94
- !msg .getMetadata ().containsKey (CONVERSATION_TS )
95
- || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
96
- "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
97
-
98
- msg .getMetadata ().putIfAbsent (CONVERSATION_TS , Instant .now ());
99
-
100
- PreparedStatement stmt = getStatement (msg );
101
-
102
- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
103
- BoundStatementBuilder builder = stmt .boundStatementBuilder ();
104
-
105
- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
106
- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
107
- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
108
- }
109
-
110
- Instant instant = (Instant ) msg .getMetadata ().get (CONVERSATION_TS );
111
-
112
- builder = builder .setInstant (CassandraChatMemoryConfig .DEFAULT_EXCHANGE_ID_NAME , instant )
113
- .setString ("message" , msg .getText ());
114
-
115
- this .conf .session .execute (builder .build ());
116
- }
117
-
118
- PreparedStatement getStatement (Message msg ) {
119
- return switch (msg .getMessageType ()) {
120
- case USER -> this .addUserStmt ;
121
- case ASSISTANT -> this .addAssistantStmt ;
122
- default -> throw new IllegalArgumentException ("Cant add type " + msg );
123
- };
58
+ repo .save (sessionId , msg );
124
59
}
125
60
126
61
@ Override
127
62
public void clear (String sessionId ) {
128
-
129
- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
130
- BoundStatementBuilder builder = this .deleteStmt .boundStatementBuilder ();
131
-
132
- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
133
- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
134
- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
135
- }
136
-
137
- this .conf .session .execute (builder .build ());
63
+ repo .deleteByConversationId (sessionId );
138
64
}
139
65
140
66
@ Override
141
67
public List <Message > get (String sessionId , int lastN ) {
142
-
143
- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
144
- BoundStatementBuilder builder = this .getStmt .boundStatementBuilder ().setInt ("lastN" , lastN );
145
-
146
- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
147
- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
148
- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
149
- }
150
-
151
- List <Message > messages = new ArrayList <>();
152
- for (Row r : this .conf .session .execute (builder .build ())) {
153
- String assistant = r .getString (this .conf .assistantColumn );
154
- String user = r .getString (this .conf .userColumn );
155
- if (null != assistant ) {
156
- messages .add (new AssistantMessage (assistant ));
157
- }
158
- if (null != user ) {
159
- messages .add (new UserMessage (user ));
160
- }
161
- }
162
- Collections .reverse (messages );
163
- return messages ;
164
- }
165
-
166
- private PreparedStatement prepareAddStmt (String column ) {
167
- RegularInsert stmt = null ;
168
- InsertInto stmtStart = QueryBuilder .insertInto (this .conf .schema .keyspace (), this .conf .schema .table ());
169
- for (var c : this .conf .schema .partitionKeys ()) {
170
- stmt = (null != stmt ? stmt : stmtStart ).value (c .name (), QueryBuilder .bindMarker (c .name ()));
171
- }
172
- for (var c : this .conf .schema .clusteringKeys ()) {
173
- stmt = stmt .value (c .name (), QueryBuilder .bindMarker (c .name ()));
174
- }
175
- stmt = stmt .value (column , QueryBuilder .bindMarker ("message" ));
176
- return this .conf .session .prepare (stmt .build ());
177
- }
178
-
179
- private PreparedStatement prepareGetStatement () {
180
- Select stmt = QueryBuilder .selectFrom (this .conf .schema .keyspace (), this .conf .schema .table ()).all ();
181
- for (var c : this .conf .schema .partitionKeys ()) {
182
- stmt = stmt .whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
183
- }
184
- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
185
- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
186
- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
187
- }
188
- stmt = stmt .limit (QueryBuilder .bindMarker ("lastN" ));
189
- return this .conf .session .prepare (stmt .build ());
190
- }
191
-
192
- private PreparedStatement prepareDeleteStmt () {
193
- Delete stmt = null ;
194
- DeleteSelection stmtStart = QueryBuilder .deleteFrom (this .conf .schema .keyspace (), this .conf .schema .table ());
195
- for (var c : this .conf .schema .partitionKeys ()) {
196
- stmt = (null != stmt ? stmt : stmtStart ).whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
197
- }
198
- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
199
- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
200
- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
201
- }
202
- return this .conf .session .prepare (stmt .build ());
68
+ return repo .findByConversationId (sessionId ).subList (0 , lastN );
203
69
}
204
70
205
71
}
0 commit comments