18
18
19
19
import java .time .Instant ;
20
20
import java .util .ArrayList ;
21
- import java .util .Collections ;
22
21
import java .util .List ;
23
- import java .util .concurrent . atomic . AtomicLong ;
22
+ import java .util .Map ;
24
23
25
24
import com .datastax .oss .driver .api .core .cql .BoundStatement ;
26
25
import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
27
26
import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
28
27
import com .datastax .oss .driver .api .core .cql .Row ;
28
+ import com .datastax .oss .driver .api .core .data .UdtValue ;
29
29
import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
30
- import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
31
- import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
32
30
import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
33
31
import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
34
32
import com .datastax .oss .driver .api .querybuilder .select .Select ;
35
33
import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
34
+
36
35
import org .springframework .ai .chat .memory .ChatMemoryRepository ;
37
36
import org .springframework .ai .chat .messages .AssistantMessage ;
38
37
import org .springframework .ai .chat .messages .Message ;
38
+ import org .springframework .ai .chat .messages .MessageType ;
39
+ import org .springframework .ai .chat .messages .SystemMessage ;
40
+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
39
41
import org .springframework .ai .chat .messages .UserMessage ;
40
42
import org .springframework .util .Assert ;
41
43
@@ -54,23 +56,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
54
56
55
57
private final PreparedStatement allStmt ;
56
58
57
- private final PreparedStatement addUserStmt ;
58
-
59
- private final PreparedStatement addAssistantStmt ;
59
+ private final PreparedStatement addStmt ;
60
60
61
61
private final PreparedStatement getStmt ;
62
62
63
- private final PreparedStatement deleteStmt ;
64
-
65
63
private CassandraChatMemoryRepository (CassandraChatMemoryRepositoryConfig conf ) {
66
64
Assert .notNull (conf , "conf cannot be null" );
67
65
this .conf = conf ;
68
66
this .conf .ensureSchemaExists ();
69
67
this .allStmt = prepareAllStatement ();
70
- this .addUserStmt = prepareAddStmt (this .conf .userColumn );
71
- this .addAssistantStmt = prepareAddStmt (this .conf .assistantColumn );
68
+ this .addStmt = prepareAddStmt ();
72
69
this .getStmt = prepareGetStatement ();
73
- this .deleteStmt = prepareDeleteStmt ();
74
70
}
75
71
76
72
public static CassandraChatMemoryRepository create (CassandraChatMemoryRepositoryConfig conf ) {
@@ -97,6 +93,10 @@ public List<String> findConversationIds() {
97
93
98
94
@ Override
99
95
public List <Message > findByConversationId (String conversationId ) {
96
+ return findByConversationIdWithLimit (conversationId , 1 );
97
+ }
98
+
99
+ List <Message > findByConversationIdWithLimit (String conversationId , int limit ) {
100
100
Assert .hasText (conversationId , "conversationId cannot be null or empty" );
101
101
102
102
List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
@@ -106,19 +106,14 @@ public List<Message> findByConversationId(String conversationId) {
106
106
CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
107
107
builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
108
108
}
109
+ builder = builder .setInt ("legacy_limit" , limit );
109
110
110
111
List <Message > messages = new ArrayList <>();
111
112
for (Row r : this .conf .session .execute (builder .build ())) {
112
- String assistant = r .getString (this .conf .assistantColumn );
113
- String user = r .getString (this .conf .userColumn );
114
- if (null != assistant ) {
115
- messages .add (new AssistantMessage (assistant ));
116
- }
117
- if (null != user ) {
118
- messages .add (new UserMessage (user ));
113
+ for (UdtValue udt : r .getList (this .conf .messagesColumn , UdtValue .class )) {
114
+ messages .add (getMessage (udt ));
119
115
}
120
116
}
121
- Collections .reverse (messages );
122
117
return messages ;
123
118
}
124
119
@@ -128,58 +123,49 @@ public void saveAll(String conversationId, List<Message> messages) {
128
123
Assert .notNull (messages , "messages cannot be null" );
129
124
Assert .noNullElements (messages , "messages cannot contain null elements" );
130
125
131
- final AtomicLong instantSeq = new AtomicLong (Instant .now ().toEpochMilli ());
132
- messages .forEach (msg -> {
133
- if (msg .getMetadata ().containsKey (CONVERSATION_TS )) {
134
- msg .getMetadata ().put (CONVERSATION_TS , Instant .ofEpochMilli (instantSeq .getAndIncrement ()));
135
- }
136
- save (conversationId , msg );
137
- });
138
- }
139
-
140
- void save (String conversationId , Message msg ) {
141
-
142
- Preconditions .checkArgument (
143
- !msg .getMetadata ().containsKey (CONVERSATION_TS )
144
- || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
145
- "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
146
-
147
- msg .getMetadata ().putIfAbsent (CONVERSATION_TS , Instant .now ());
148
-
149
- PreparedStatement stmt = getStatement (msg );
150
-
126
+ Instant instant = Instant .now ();
151
127
List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
152
- BoundStatementBuilder builder = stmt .boundStatementBuilder ();
128
+ BoundStatementBuilder builder = addStmt .boundStatementBuilder ();
153
129
154
130
for (int k = 0 ; k < primaryKeys .size (); ++k ) {
155
131
CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
156
132
builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
157
133
}
158
134
159
- Instant instant = (Instant ) msg .getMetadata ().get (CONVERSATION_TS );
135
+ List <UdtValue > msgs = new ArrayList <>();
136
+ for (Message msg : messages ) {
137
+
138
+ Preconditions .checkArgument (
139
+ !msg .getMetadata ().containsKey (CONVERSATION_TS )
140
+ || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
141
+ "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
160
142
143
+ msg .getMetadata ().putIfAbsent (CONVERSATION_TS , instant );
144
+
145
+ UdtValue udt = this .conf .session .getMetadata ()
146
+ .getKeyspace (this .conf .schema .keyspace ())
147
+ .get ()
148
+ .getUserDefinedType (this .conf .messageUDT )
149
+ .get ()
150
+ .newValue ()
151
+ .setInstant (this .conf .messageUdtTimestampColumn , (Instant ) msg .getMetadata ().get (CONVERSATION_TS ))
152
+ .setString (this .conf .messageUdtTypeColumn , msg .getMessageType ().name ())
153
+ .setString (this .conf .messageUdtContentColumn , msg .getText ());
154
+
155
+ msgs .add (udt );
156
+ }
161
157
builder = builder .setInstant (CassandraChatMemoryRepositoryConfig .DEFAULT_EXCHANGE_ID_NAME , instant )
162
- .setString ( "message " , msg . getText () );
158
+ .setList ( "msgs " , msgs , UdtValue . class );
163
159
164
160
this .conf .session .execute (builder .build ());
165
161
}
166
162
167
163
@ Override
168
164
public void deleteByConversationId (String conversationId ) {
169
- Assert .hasText (conversationId , "conversationId cannot be null or empty" );
170
-
171
- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
172
- BoundStatementBuilder builder = this .deleteStmt .boundStatementBuilder ();
173
-
174
- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
175
- CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
176
- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
177
- }
178
-
179
- this .conf .session .execute (builder .build ());
165
+ saveAll (conversationId , List .of ());
180
166
}
181
167
182
- private PreparedStatement prepareAddStmt (String column ) {
168
+ private PreparedStatement prepareAddStmt () {
183
169
RegularInsert stmt = null ;
184
170
InsertInto stmtStart = QueryBuilder .insertInto (this .conf .schema .keyspace (), this .conf .schema .table ());
185
171
for (var c : this .conf .schema .partitionKeys ()) {
@@ -188,7 +174,7 @@ private PreparedStatement prepareAddStmt(String column) {
188
174
for (var c : this .conf .schema .clusteringKeys ()) {
189
175
stmt = stmt .value (c .name (), QueryBuilder .bindMarker (c .name ()));
190
176
}
191
- stmt = stmt .value (column , QueryBuilder .bindMarker ("message " ));
177
+ stmt = stmt .value (this . conf . messagesColumn , QueryBuilder .bindMarker ("msgs " ));
192
178
return this .conf .session .prepare (stmt .build ());
193
179
}
194
180
@@ -214,28 +200,27 @@ private PreparedStatement prepareGetStatement() {
214
200
String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
215
201
stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
216
202
}
203
+ stmt = stmt .limit (QueryBuilder .bindMarker ("legacy_limit" ));
217
204
return this .conf .session .prepare (stmt .build ());
218
205
}
219
206
220
- private PreparedStatement prepareDeleteStmt () {
221
- Delete stmt = null ;
222
- DeleteSelection stmtStart = QueryBuilder .deleteFrom (this .conf .schema .keyspace (), this .conf .schema .table ());
223
- for (var c : this .conf .schema .partitionKeys ()) {
224
- stmt = (null != stmt ? stmt : stmtStart ).whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
225
- }
226
- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
227
- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
228
- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
207
+ private Message getMessage (UdtValue udt ) {
208
+ String content = udt .getString (this .conf .messageUdtContentColumn );
209
+ Map <String , Object > props = Map .of (CONVERSATION_TS , udt .getInstant (this .conf .messageUdtTimestampColumn ));
210
+ switch (MessageType .valueOf (udt .getString (this .conf .messageUdtTypeColumn ))) {
211
+ case ASSISTANT :
212
+ return new AssistantMessage (content , props );
213
+ case USER :
214
+ return UserMessage .builder ().text (content ).metadata (props ).build ();
215
+ case SYSTEM :
216
+ return SystemMessage .builder ().text (content ).metadata (props ).build ();
217
+ case TOOL :
218
+ // todo – persist ToolResponse somehow
219
+ return new ToolResponseMessage (List .of (), props );
220
+ default :
221
+ throw new IllegalStateException (
222
+ String .format ("unknown message type %s" , udt .getString (this .conf .messageUdtTypeColumn )));
229
223
}
230
- return this .conf .session .prepare (stmt .build ());
231
- }
232
-
233
- private PreparedStatement getStatement (Message msg ) {
234
- return switch (msg .getMessageType ()) {
235
- case USER -> this .addUserStmt ;
236
- case ASSISTANT -> this .addAssistantStmt ;
237
- default -> throw new IllegalArgumentException ("Cant add type " + msg );
238
- };
239
224
}
240
225
241
226
}
0 commit comments