Skip to content

Commit 1c06d7c

Browse files
michael-simonsilayaperumalg
authored andcommitted
Polish Neo4jChatMemoryRepository.
This change turns all the labels into parameters, avoiding the possibility of Cypher injection as the config does not do any sanitization. In addition, the interaction with the driver is changed so that it uses transactional functions, which are retried when any communication with the Neo4j DBMS fails. We can do this here as the repository is not subject to application wide transactions. An alternative to the parameters for labels would be using Cypher-DSL as we did in other parts of this project to sanitize labels proper. Signed-off-by: Michael Simons <[email protected]>
1 parent f2df87b commit 1c06d7c

File tree

1 file changed

+77
-68
lines changed

1 file changed

+77
-68
lines changed

memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java

Lines changed: 77 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package org.springframework.ai.chat.memory.neo4j;
22

3-
import org.neo4j.driver.Result;
43
import org.neo4j.driver.Session;
54
import org.neo4j.driver.Transaction;
5+
import org.neo4j.driver.TransactionContext;
66
import org.springframework.ai.chat.memory.ChatMemoryRepository;
77
import org.springframework.ai.chat.messages.*;
88
import org.springframework.ai.content.Media;
@@ -11,15 +11,17 @@
1111

1212
import java.net.URI;
1313
import java.util.*;
14+
import java.util.stream.Collectors;
1415

1516
/**
1617
* An implementation of {@link ChatMemoryRepository} for Neo4J
1718
*
1819
* @author Enrico Rampazzo
20+
* @author Michael J. Simons
1921
* @since 1.0.0
2022
*/
2123

22-
public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
24+
public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {
2325

2426
private final Neo4jChatMemoryConfig config;
2527

@@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {
2931

3032
@Override
3133
public List<String> findConversationIds() {
32-
try (var session = config.getDriver().session()) {
33-
return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel()))
34-
.stream()
35-
.map(r -> r.get("conversation.id").asString())
36-
.toList();
37-
}
34+
return config.getDriver()
35+
.executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id")
36+
.withParameters(Map.of("sessionLabel", config.getSessionLabel()))
37+
.execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList()));
3838
}
3939

4040
@Override
4141
public List<Message> findByConversationId(String conversationId) {
42-
String statementBuilder = """
43-
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
42+
String statement = """
43+
MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel))
4444
WITH m
45-
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
46-
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC
47-
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48-
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
45+
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel))
46+
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel)) WITH m, metadata, media ORDER BY media.idx ASC
47+
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel)) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48+
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel))
4949
WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
5050
RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
5151
ORDER BY m.idx ASC
52-
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
53-
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
54-
this.config.getToolCallLabel());
55-
Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId));
56-
return res.stream().map(record -> {
57-
Map<String, Object> messageMap = record.get("m").asMap();
58-
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
59-
Message message = null;
60-
List<Media> mediaList = List.of();
61-
if (!record.get("medias").isNull()) {
62-
mediaList = getMedia(record);
63-
}
64-
if (msgType.equals(MessageType.USER.getValue())) {
65-
message = buildUserMessage(record, messageMap, mediaList);
66-
}
67-
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
68-
message = buildAssistantMessage(record, messageMap, mediaList);
69-
}
70-
if (msgType.equals(MessageType.SYSTEM.getValue())) {
71-
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
72-
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
73-
if (!record.get("metadata").isNull()) {
74-
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
75-
systemMessageBuilder.metadata(retrievedMetadata);
52+
""";
53+
54+
return this.config.getDriver()
55+
.executableQuery(statement)
56+
.withParameters(Map.of("conversationId", conversationId, "sessionLabel", this.config.getSessionLabel(),
57+
"messageLabel", this.config.getMessageLabel(), "metadataLabel", this.config.getMetadataLabel(),
58+
"mediaLabel", this.config.getMediaLabel(), "toolResponseLabel", this.config.getToolResponseLabel(),
59+
"toolCallLabel", this.config.getToolCallLabel()))
60+
.execute(Collectors.mapping(record -> {
61+
Map<String, Object> messageMap = record.get("m").asMap();
62+
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
63+
Message message = null;
64+
List<Media> mediaList = List.of();
65+
if (!record.get("medias").isNull()) {
66+
mediaList = getMedia(record);
7667
}
77-
message = systemMessageBuilder.build();
78-
}
79-
if (msgType.equals(MessageType.TOOL.getValue())) {
80-
message = buildToolMessage(record);
81-
}
82-
if (message == null) {
83-
throw new IllegalArgumentException("%s messages are not supported"
84-
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
85-
}
86-
message.getMetadata().put("messageType", message.getMessageType());
87-
return message;
88-
}).toList();
68+
if (msgType.equals(MessageType.USER.getValue())) {
69+
message = buildUserMessage(record, messageMap, mediaList);
70+
}
71+
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
72+
message = buildAssistantMessage(record, messageMap, mediaList);
73+
}
74+
if (msgType.equals(MessageType.SYSTEM.getValue())) {
75+
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
76+
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
77+
if (!record.get("metadata").isNull()) {
78+
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
79+
systemMessageBuilder.metadata(retrievedMetadata);
80+
}
81+
message = systemMessageBuilder.build();
82+
}
83+
if (msgType.equals(MessageType.TOOL.getValue())) {
84+
message = buildToolMessage(record);
85+
}
86+
if (message == null) {
87+
throw new IllegalArgumentException("%s messages are not supported"
88+
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
89+
}
90+
message.getMetadata().put("messageType", message.getMessageType());
91+
return message;
92+
}, Collectors.toList()));
8993

9094
}
9195

@@ -96,12 +100,11 @@ public void saveAll(String conversationId, List<Message> messages) {
96100

97101
// Then add the new messages
98102
try (Session s = this.config.getDriver().session()) {
99-
try (Transaction t = s.beginTransaction()) {
103+
s.executeWriteWithoutResult(tx -> {
100104
for (Message m : messages) {
101-
addMessageToTransaction(t, conversationId, m);
105+
addMessageToTransaction(tx, conversationId, m);
102106
}
103-
t.commit();
104-
}
107+
});
105108
}
106109
}
107110

@@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
196199
return mediaList;
197200
}
198201

199-
private void addMessageToTransaction(Transaction t, String conversationId, Message message) {
202+
private void addMessageToTransaction(TransactionContext t, String conversationId, Message message) {
200203
Map<String, Object> queryParameters = new HashMap<>();
201204
queryParameters.put("conversationId", conversationId);
202205
StringBuilder statementBuilder = new StringBuilder();
203206
statementBuilder.append("""
204-
MERGE (s:%s {id:$conversationId}) WITH s
205-
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
206-
CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
207+
MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s
208+
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel))
209+
WITH coalesce(count(countMsg), 0) as totalMsg, s
210+
CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties
207211
SET msg.idx = totalMsg + 1
208-
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
209-
this.config.getMessageLabel()));
212+
""");
210213
Map<String, Object> attributes = new HashMap<>();
211214

212215
attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue());
213216
attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText());
214217
attributes.put("id", UUID.randomUUID().toString());
215218
queryParameters.put("messageProperties", attributes);
219+
queryParameters.put("sessionLabel", this.config.getSessionLabel());
220+
queryParameters.put("messageLabel", this.config.getMessageLabel());
216221

217222
if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) {
218223
statementBuilder.append("""
219224
WITH msg
220-
CREATE (metadataNode:%s)
225+
CREATE (metadataNode:$($metadataLabel))
221226
CREATE (msg)-[:HAS_METADATA]->(metadataNode)
222227
SET metadataNode = $metadata
223-
""".formatted(this.config.getMetadataLabel()));
228+
""");
224229
Map<String, Object> metadataCopy = new HashMap<>(message.getMetadata());
225230
metadataCopy.remove("messageType");
226231
queryParameters.put("metadata", metadataCopy);
232+
queryParameters.put("metadataLabel", this.config.getMetadataLabel());
227233
}
228234
if (message instanceof AssistantMessage assistantMessage) {
229235
if (assistantMessage.hasToolCalls()) {
230236
statementBuilder.append("""
231237
WITH msg
232-
FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc
238+
FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel)) SET toolCall = tc
233239
CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
234-
""".formatted(this.config.getToolCallLabel()));
240+
""");
241+
queryParameters.put("toolLabel", this.config.getToolCallLabel());
235242
List<Map<String, Object>> toolCallMaps = new ArrayList<>();
236243
for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) {
237244
AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i);
@@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg),
256263
}
257264
statementBuilder.append("""
258265
WITH msg
259-
FOREACH(tr IN $toolResponses | CREATE (tm:%s)
266+
FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel))
260267
SET tm = tr
261268
MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
262-
""".formatted(this.config.getToolResponseLabel()));
269+
""");
263270
queryParameters.put("toolResponses", toolResponseMaps);
271+
queryParameters.put("toolResponseLabel", this.config.getToolResponseLabel());
264272
}
265273
if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) {
266274
List<Map<String, Object>> mediaNodes = convertMediaToMap(messageWithMedia.getMedia());
267275
statementBuilder.append("""
268276
WITH msg
269277
UNWIND $media AS m
270-
CREATE (media:%s) SET media = m
278+
CREATE (media:$($mediaLabel)) SET media = m
271279
WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
272-
""".formatted(this.config.getMediaLabel()));
280+
""");
273281
queryParameters.put("media", mediaNodes);
282+
queryParameters.put("mediaLabel", this.config.getMediaLabel());
274283
}
275284
t.run(statementBuilder.toString(), queryParameters);
276285
}

0 commit comments

Comments
 (0)