Skip to content

Polish Neo4jChatMemoryRepository. #3025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.springframework.ai.chat.memory.neo4j;

import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.Transaction;
import org.neo4j.driver.TransactionContext;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.content.Media;
Expand All @@ -11,15 +11,17 @@

import java.net.URI;
import java.util.*;
import java.util.stream.Collectors;

/**
* An implementation of {@link ChatMemoryRepository} for Neo4J
*
* @author Enrico Rampazzo
* @author Michael J. Simons
* @since 1.0.0
*/

public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {

private final Neo4jChatMemoryConfig config;

Expand All @@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {

@Override
public List<String> findConversationIds() {
try (var session = config.getDriver().session()) {
return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel()))
.stream()
.map(r -> r.get("conversation.id").asString())
.toList();
}
return config.getDriver()
.executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id")
.withParameters(Map.of("sessionLabel", config.getSessionLabel()))
.execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList()));
}

@Override
public List<Message> findByConversationId(String conversationId) {
String statementBuilder = """
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
String statement = """
MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel))
WITH m
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel))
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel)) WITH m, metadata, media ORDER BY media.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel)) WITH m, metadata, media, tr ORDER BY tr.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel))
WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
ORDER BY m.idx ASC
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
this.config.getToolCallLabel());
Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId));
return res.stream().map(record -> {
Map<String, Object> messageMap = record.get("m").asMap();
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
Message message = null;
List<Media> mediaList = List.of();
if (!record.get("medias").isNull()) {
mediaList = getMedia(record);
}
if (msgType.equals(MessageType.USER.getValue())) {
message = buildUserMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
message = buildAssistantMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.SYSTEM.getValue())) {
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
if (!record.get("metadata").isNull()) {
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
systemMessageBuilder.metadata(retrievedMetadata);
""";

return this.config.getDriver()
.executableQuery(statement)
.withParameters(Map.of("conversationId", conversationId, "sessionLabel", this.config.getSessionLabel(),
"messageLabel", this.config.getMessageLabel(), "metadataLabel", this.config.getMetadataLabel(),
"mediaLabel", this.config.getMediaLabel(), "toolResponseLabel", this.config.getToolResponseLabel(),
"toolCallLabel", this.config.getToolCallLabel()))
.execute(Collectors.mapping(record -> {
Map<String, Object> messageMap = record.get("m").asMap();
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
Message message = null;
List<Media> mediaList = List.of();
if (!record.get("medias").isNull()) {
mediaList = getMedia(record);
}
message = systemMessageBuilder.build();
}
if (msgType.equals(MessageType.TOOL.getValue())) {
message = buildToolMessage(record);
}
if (message == null) {
throw new IllegalArgumentException("%s messages are not supported"
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
}
message.getMetadata().put("messageType", message.getMessageType());
return message;
}).toList();
if (msgType.equals(MessageType.USER.getValue())) {
message = buildUserMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
message = buildAssistantMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.SYSTEM.getValue())) {
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
if (!record.get("metadata").isNull()) {
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
systemMessageBuilder.metadata(retrievedMetadata);
}
message = systemMessageBuilder.build();
}
if (msgType.equals(MessageType.TOOL.getValue())) {
message = buildToolMessage(record);
}
if (message == null) {
throw new IllegalArgumentException("%s messages are not supported"
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
}
message.getMetadata().put("messageType", message.getMessageType());
return message;
}, Collectors.toList()));

}

Expand All @@ -96,12 +100,11 @@ public void saveAll(String conversationId, List<Message> messages) {

// Then add the new messages
try (Session s = this.config.getDriver().session()) {
try (Transaction t = s.beginTransaction()) {
s.executeWriteWithoutResult(tx -> {
for (Message m : messages) {
addMessageToTransaction(t, conversationId, m);
addMessageToTransaction(tx, conversationId, m);
}
t.commit();
}
});
}
}

Expand Down Expand Up @@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
return mediaList;
}

private void addMessageToTransaction(Transaction t, String conversationId, Message message) {
private void addMessageToTransaction(TransactionContext t, String conversationId, Message message) {
Map<String, Object> queryParameters = new HashMap<>();
queryParameters.put("conversationId", conversationId);
StringBuilder statementBuilder = new StringBuilder();
statementBuilder.append("""
MERGE (s:%s {id:$conversationId}) WITH s
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel))
WITH coalesce(count(countMsg), 0) as totalMsg, s
CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties
SET msg.idx = totalMsg + 1
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
this.config.getMessageLabel()));
""");
Map<String, Object> attributes = new HashMap<>();

attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue());
attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText());
attributes.put("id", UUID.randomUUID().toString());
queryParameters.put("messageProperties", attributes);
queryParameters.put("sessionLabel", this.config.getSessionLabel());
queryParameters.put("messageLabel", this.config.getMessageLabel());

if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) {
statementBuilder.append("""
WITH msg
CREATE (metadataNode:%s)
CREATE (metadataNode:$($metadataLabel))
CREATE (msg)-[:HAS_METADATA]->(metadataNode)
SET metadataNode = $metadata
""".formatted(this.config.getMetadataLabel()));
""");
Map<String, Object> metadataCopy = new HashMap<>(message.getMetadata());
metadataCopy.remove("messageType");
queryParameters.put("metadata", metadataCopy);
queryParameters.put("metadataLabel", this.config.getMetadataLabel());
}
if (message instanceof AssistantMessage assistantMessage) {
if (assistantMessage.hasToolCalls()) {
statementBuilder.append("""
WITH msg
FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc
FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel)) SET toolCall = tc
CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
""".formatted(this.config.getToolCallLabel()));
""");
queryParameters.put("toolLabel", this.config.getToolCallLabel());
List<Map<String, Object>> toolCallMaps = new ArrayList<>();
for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) {
AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i);
Expand All @@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg),
}
statementBuilder.append("""
WITH msg
FOREACH(tr IN $toolResponses | CREATE (tm:%s)
FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel))
SET tm = tr
MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
""".formatted(this.config.getToolResponseLabel()));
""");
queryParameters.put("toolResponses", toolResponseMaps);
queryParameters.put("toolResponseLabel", this.config.getToolResponseLabel());
}
if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) {
List<Map<String, Object>> mediaNodes = convertMediaToMap(messageWithMedia.getMedia());
statementBuilder.append("""
WITH msg
UNWIND $media AS m
CREATE (media:%s) SET media = m
CREATE (media:$($mediaLabel)) SET media = m
WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
""".formatted(this.config.getMediaLabel()));
""");
queryParameters.put("media", mediaNodes);
queryParameters.put("mediaLabel", this.config.getMediaLabel());
}
t.run(statementBuilder.toString(), queryParameters);
}
Expand Down