1
1
package org .springframework .ai .chat .memory .neo4j ;
2
2
3
- import org .neo4j .driver .Result ;
4
3
import org .neo4j .driver .Session ;
5
4
import org .neo4j .driver .Transaction ;
5
+ import org .neo4j .driver .TransactionContext ;
6
6
import org .springframework .ai .chat .memory .ChatMemoryRepository ;
7
7
import org .springframework .ai .chat .messages .*;
8
8
import org .springframework .ai .content .Media ;
11
11
12
12
import java .net .URI ;
13
13
import java .util .*;
14
+ import java .util .stream .Collectors ;
14
15
15
16
/**
16
17
* An implementation of {@link ChatMemoryRepository} for Neo4J
17
18
*
18
19
* @author Enrico Rampazzo
20
+ * @author Michael J. Simons
19
21
* @since 1.0.0
20
22
*/
21
23
22
- public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
24
+ public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {
23
25
24
26
private final Neo4jChatMemoryConfig config ;
25
27
@@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {
29
31
30
32
@ Override
31
33
public List <String > findConversationIds () {
32
- try (var session = config .getDriver ().session ()) {
33
- return session .run ("MATCH (conversation:%s) RETURN conversation.id" .formatted (config .getSessionLabel ()))
34
- .stream ()
35
- .map (r -> r .get ("conversation.id" ).asString ())
36
- .toList ();
37
- }
34
+ return config .getDriver ()
35
+ .executableQuery ("MATCH (conversation:$($sessionLabel)) RETURN conversation.id" )
36
+ .withParameters (Map .of ("sessionLabel" , config .getSessionLabel ()))
37
+ .execute (Collectors .mapping (r -> r .get ("conversation.id" ).asString (), Collectors .toList ()));
38
38
}
39
39
40
40
@ Override
41
41
public List <Message > findByConversationId (String conversationId ) {
42
- String statementBuilder = """
43
- MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s )
42
+ String statement = """
43
+ MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel) )
44
44
WITH m
45
- OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s )
46
- OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s ) WITH m, metadata, media ORDER BY media.idx ASC
47
- OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s ) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48
- OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s )
45
+ OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel) )
46
+ OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel) ) WITH m, metadata, media ORDER BY media.idx ASC
47
+ OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel) ) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48
+ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel) )
49
49
WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
50
50
RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
51
51
ORDER BY m.idx ASC
52
- """ .formatted (this .config .getSessionLabel (), this .config .getMessageLabel (),
53
- this .config .getMetadataLabel (), this .config .getMediaLabel (), this .config .getToolResponseLabel (),
54
- this .config .getToolCallLabel ());
55
- Result res = this .config .getDriver ().session ().run (statementBuilder , Map .of ("conversationId" , conversationId ));
56
- return res .stream ().map (record -> {
57
- Map <String , Object > messageMap = record .get ("m" ).asMap ();
58
- String msgType = messageMap .get (MessageAttributes .MESSAGE_TYPE .getValue ()).toString ();
59
- Message message = null ;
60
- List <Media > mediaList = List .of ();
61
- if (!record .get ("medias" ).isNull ()) {
62
- mediaList = getMedia (record );
63
- }
64
- if (msgType .equals (MessageType .USER .getValue ())) {
65
- message = buildUserMessage (record , messageMap , mediaList );
66
- }
67
- if (msgType .equals (MessageType .ASSISTANT .getValue ())) {
68
- message = buildAssistantMessage (record , messageMap , mediaList );
69
- }
70
- if (msgType .equals (MessageType .SYSTEM .getValue ())) {
71
- SystemMessage .Builder systemMessageBuilder = SystemMessage .builder ()
72
- .text (messageMap .get (MessageAttributes .TEXT_CONTENT .getValue ()).toString ());
73
- if (!record .get ("metadata" ).isNull ()) {
74
- Map <String , Object > retrievedMetadata = record .get ("metadata" ).asMap ();
75
- systemMessageBuilder .metadata (retrievedMetadata );
52
+ """ ;
53
+
54
+ return this .config .getDriver ()
55
+ .executableQuery (statement )
56
+ .withParameters (Map .of ("conversationId" , conversationId , "sessionLabel" , this .config .getSessionLabel (),
57
+ "messageLabel" , this .config .getMessageLabel (), "metadataLabel" , this .config .getMetadataLabel (),
58
+ "mediaLabel" , this .config .getMediaLabel (), "toolResponseLabel" , this .config .getToolResponseLabel (),
59
+ "toolCallLabel" , this .config .getToolCallLabel ()))
60
+ .execute (Collectors .mapping (record -> {
61
+ Map <String , Object > messageMap = record .get ("m" ).asMap ();
62
+ String msgType = messageMap .get (MessageAttributes .MESSAGE_TYPE .getValue ()).toString ();
63
+ Message message = null ;
64
+ List <Media > mediaList = List .of ();
65
+ if (!record .get ("medias" ).isNull ()) {
66
+ mediaList = getMedia (record );
76
67
}
77
- message = systemMessageBuilder .build ();
78
- }
79
- if (msgType .equals (MessageType .TOOL .getValue ())) {
80
- message = buildToolMessage (record );
81
- }
82
- if (message == null ) {
83
- throw new IllegalArgumentException ("%s messages are not supported"
84
- .formatted (record .get (MessageAttributes .MESSAGE_TYPE .getValue ()).asString ()));
85
- }
86
- message .getMetadata ().put ("messageType" , message .getMessageType ());
87
- return message ;
88
- }).toList ();
68
+ if (msgType .equals (MessageType .USER .getValue ())) {
69
+ message = buildUserMessage (record , messageMap , mediaList );
70
+ }
71
+ if (msgType .equals (MessageType .ASSISTANT .getValue ())) {
72
+ message = buildAssistantMessage (record , messageMap , mediaList );
73
+ }
74
+ if (msgType .equals (MessageType .SYSTEM .getValue ())) {
75
+ SystemMessage .Builder systemMessageBuilder = SystemMessage .builder ()
76
+ .text (messageMap .get (MessageAttributes .TEXT_CONTENT .getValue ()).toString ());
77
+ if (!record .get ("metadata" ).isNull ()) {
78
+ Map <String , Object > retrievedMetadata = record .get ("metadata" ).asMap ();
79
+ systemMessageBuilder .metadata (retrievedMetadata );
80
+ }
81
+ message = systemMessageBuilder .build ();
82
+ }
83
+ if (msgType .equals (MessageType .TOOL .getValue ())) {
84
+ message = buildToolMessage (record );
85
+ }
86
+ if (message == null ) {
87
+ throw new IllegalArgumentException ("%s messages are not supported"
88
+ .formatted (record .get (MessageAttributes .MESSAGE_TYPE .getValue ()).asString ()));
89
+ }
90
+ message .getMetadata ().put ("messageType" , message .getMessageType ());
91
+ return message ;
92
+ }, Collectors .toList ()));
89
93
90
94
}
91
95
@@ -96,12 +100,11 @@ public void saveAll(String conversationId, List<Message> messages) {
96
100
97
101
// Then add the new messages
98
102
try (Session s = this .config .getDriver ().session ()) {
99
- try ( Transaction t = s . beginTransaction ()) {
103
+ s . executeWriteWithoutResult ( tx -> {
100
104
for (Message m : messages ) {
101
- addMessageToTransaction (t , conversationId , m );
105
+ addMessageToTransaction (tx , conversationId , m );
102
106
}
103
- t .commit ();
104
- }
107
+ });
105
108
}
106
109
}
107
110
@@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
196
199
return mediaList ;
197
200
}
198
201
199
- private void addMessageToTransaction (Transaction t , String conversationId , Message message ) {
202
+ private void addMessageToTransaction (TransactionContext t , String conversationId , Message message ) {
200
203
Map <String , Object > queryParameters = new HashMap <>();
201
204
queryParameters .put ("conversationId" , conversationId );
202
205
StringBuilder statementBuilder = new StringBuilder ();
203
206
statementBuilder .append ("""
204
- MERGE (s:%s {id:$conversationId}) WITH s
205
- OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
206
- CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
207
+ MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s
208
+ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel))
209
+ WITH coalesce(count(countMsg), 0) as totalMsg, s
210
+ CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties
207
211
SET msg.idx = totalMsg + 1
208
- """ .formatted (this .config .getSessionLabel (), this .config .getMessageLabel (),
209
- this .config .getMessageLabel ()));
212
+ """ );
210
213
Map <String , Object > attributes = new HashMap <>();
211
214
212
215
attributes .put (MessageAttributes .MESSAGE_TYPE .getValue (), message .getMessageType ().getValue ());
213
216
attributes .put (MessageAttributes .TEXT_CONTENT .getValue (), message .getText ());
214
217
attributes .put ("id" , UUID .randomUUID ().toString ());
215
218
queryParameters .put ("messageProperties" , attributes );
219
+ queryParameters .put ("sessionLabel" , this .config .getSessionLabel ());
220
+ queryParameters .put ("messageLabel" , this .config .getMessageLabel ());
216
221
217
222
if (!Optional .ofNullable (message .getMetadata ()).orElse (Map .of ()).isEmpty ()) {
218
223
statementBuilder .append ("""
219
224
WITH msg
220
- CREATE (metadataNode:%s )
225
+ CREATE (metadataNode:$($metadataLabel) )
221
226
CREATE (msg)-[:HAS_METADATA]->(metadataNode)
222
227
SET metadataNode = $metadata
223
- """ . formatted ( this . config . getMetadataLabel ()) );
228
+ """ );
224
229
Map <String , Object > metadataCopy = new HashMap <>(message .getMetadata ());
225
230
metadataCopy .remove ("messageType" );
226
231
queryParameters .put ("metadata" , metadataCopy );
232
+ queryParameters .put ("metadataLabel" , this .config .getMetadataLabel ());
227
233
}
228
234
if (message instanceof AssistantMessage assistantMessage ) {
229
235
if (assistantMessage .hasToolCalls ()) {
230
236
statementBuilder .append ("""
231
237
WITH msg
232
- FOREACH(tc in $toolCalls | CREATE (toolCall:%s ) SET toolCall = tc
238
+ FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel) ) SET toolCall = tc
233
239
CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
234
- """ .formatted (this .config .getToolCallLabel ()));
240
+ """ );
241
+ queryParameters .put ("toolLabel" , this .config .getToolCallLabel ());
235
242
List <Map <String , Object >> toolCallMaps = new ArrayList <>();
236
243
for (int i = 0 ; i < assistantMessage .getToolCalls ().size (); i ++) {
237
244
AssistantMessage .ToolCall tc = assistantMessage .getToolCalls ().get (i );
@@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg),
256
263
}
257
264
statementBuilder .append ("""
258
265
WITH msg
259
- FOREACH(tr IN $toolResponses | CREATE (tm:%s )
266
+ FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel) )
260
267
SET tm = tr
261
268
MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
262
- """ . formatted ( this . config . getToolResponseLabel ()) );
269
+ """ );
263
270
queryParameters .put ("toolResponses" , toolResponseMaps );
271
+ queryParameters .put ("toolResponseLabel" , this .config .getToolResponseLabel ());
264
272
}
265
273
if (message instanceof MediaContent messageWithMedia && !messageWithMedia .getMedia ().isEmpty ()) {
266
274
List <Map <String , Object >> mediaNodes = convertMediaToMap (messageWithMedia .getMedia ());
267
275
statementBuilder .append ("""
268
276
WITH msg
269
277
UNWIND $media AS m
270
- CREATE (media:%s ) SET media = m
278
+ CREATE (media:$($mediaLabel) ) SET media = m
271
279
WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
272
- """ . formatted ( this . config . getMediaLabel ()) );
280
+ """ );
273
281
queryParameters .put ("media" , mediaNodes );
282
+ queryParameters .put ("mediaLabel" , this .config .getMediaLabel ());
274
283
}
275
284
t .run (statementBuilder .toString (), queryParameters );
276
285
}
0 commit comments