Skip to content

Commit 591af01

Browse files
author
Manfred Ng
committed
BAEL-9331: Chat Memory in Spring AI
1 parent ec6adc6 commit 591af01

File tree

10 files changed

+307
-0
lines changed

10 files changed

+307
-0
lines changed

pom.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@
764764
<module>spring-ai</module>
765765
<module>spring-ai-2</module>
766766
<module>spring-ai-3</module>
767+
<module>spring-ai-4</module>
767768
<module>spring-ai-modules</module>
768769
<module>spring-aop</module>
769770
<module>spring-aop-2</module>
@@ -1196,6 +1197,7 @@
11961197
<module>spring-ai</module>
11971198
<module>spring-ai-2</module>
11981199
<module>spring-ai-3</module>
1200+
<module>spring-ai-4</module>
11991201
<module>spring-ai-modules</module>
12001202
<module>spring-aop</module>
12011203
<module>spring-aop-2</module>

spring-ai-4/pom.xml

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
<artifactId>spring-ai-4</artifactId>
7+
<version>0.0.1</version>
8+
<packaging>jar</packaging>
9+
<name>spring-ai-4</name>
10+
11+
<parent>
12+
<groupId>com.baeldung</groupId>
13+
<artifactId>parent-boot-3</artifactId>
14+
<version>0.0.1-SNAPSHOT</version>
15+
<relativePath>../parent-boot-3</relativePath>
16+
</parent>
17+
18+
<repositories>
19+
<repository>
20+
<id>spring-milestones</id>
21+
<name>Spring Milestones</name>
22+
<url>https://repo.spring.io/milestone</url>
23+
<snapshots>
24+
<enabled>false</enabled>
25+
</snapshots>
26+
<releases>
27+
<enabled>true</enabled>
28+
</releases>
29+
</repository>
30+
<repository>
31+
<id>spring-snapshots</id>
32+
<name>Spring Snapshots</name>
33+
<url>https://repo.spring.io/snapshot</url>
34+
<snapshots>
35+
<enabled>true</enabled>
36+
</snapshots>
37+
<releases>
38+
<enabled>false</enabled>
39+
</releases>
40+
</repository>
41+
</repositories>
42+
43+
<dependencyManagement>
44+
<dependencies>
45+
<dependency>
46+
<groupId>org.springframework.ai</groupId>
47+
<artifactId>spring-ai-bom</artifactId>
48+
<version>${spring-ai.version}</version>
49+
<type>pom</type>
50+
<scope>import</scope>
51+
</dependency>
52+
</dependencies>
53+
</dependencyManagement>
54+
55+
<dependencies>
56+
<dependency>
57+
<groupId>org.springframework.boot</groupId>
58+
<artifactId>spring-boot-starter-web</artifactId>
59+
</dependency>
60+
<dependency>
61+
<groupId>org.springframework.ai</groupId>
62+
<artifactId>spring-ai-starter-model-openai</artifactId>
63+
</dependency>
64+
<dependency>
65+
<groupId>org.springframework.ai</groupId>
66+
<artifactId>spring-ai-model-chat-memory-repository-jdbc</artifactId>
67+
</dependency>
68+
<dependency>
69+
<groupId>org.hsqldb</groupId>
70+
<artifactId>hsqldb</artifactId>
71+
<scope>runtime</scope>
72+
</dependency>
73+
74+
<!-- Test dependencies -->
75+
<dependency>
76+
<groupId>org.springframework.boot</groupId>
77+
<artifactId>spring-boot-starter-test</artifactId>
78+
<scope>test</scope>
79+
</dependency>
80+
</dependencies>
81+
82+
<profiles>
83+
<profile>
84+
<id>chat-memory</id>
85+
<activation>
86+
<activeByDefault>true</activeByDefault>
87+
</activation>
88+
<properties>
89+
<spring.boot.mainclass>com.baeldung.springai.memory.Application</spring.boot.mainclass>
90+
</properties>
91+
</profile>
92+
</profiles>
93+
94+
<build>
95+
<plugins>
96+
<plugin>
97+
<groupId>org.springframework.boot</groupId>
98+
<artifactId>spring-boot-maven-plugin</artifactId>
99+
<configuration>
100+
<mainClass>${spring.boot.mainclass}</mainClass>
101+
</configuration>
102+
</plugin>
103+
<plugin>
104+
<groupId>org.apache.maven.plugins</groupId>
105+
<artifactId>maven-compiler-plugin</artifactId>
106+
<configuration>
107+
<release>21</release>
108+
</configuration>
109+
</plugin>
110+
</plugins>
111+
</build>
112+
113+
<properties>
114+
<junit-jupiter.version>5.9.0</junit-jupiter.version>
115+
<spring-boot.version>3.5.0</spring-boot.version>
116+
<spring-ai.version>1.0.0</spring-ai.version>
117+
</properties>
118+
119+
</project>
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.baeldung.springai.memory;
2+
3+
import org.springframework.boot.SpringApplication;
4+
import org.springframework.boot.autoconfigure.SpringBootApplication;
5+
6+
@SpringBootApplication
7+
public class Application {
8+
9+
public static void main(String[] args) {
10+
SpringApplication app = new SpringApplication(Application.class);
11+
app.setAdditionalProfiles("memory");
12+
app.run(args);
13+
}
14+
15+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.baeldung.springai.memory;
2+
3+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
4+
import org.springframework.ai.chat.memory.repository.jdbc.HsqldbChatMemoryRepositoryDialect;
5+
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
6+
import org.springframework.context.annotation.Bean;
7+
import org.springframework.context.annotation.Configuration;
8+
import org.springframework.jdbc.core.JdbcTemplate;
9+
10+
@Configuration
11+
public class ChatConfig {
12+
13+
@Bean
14+
public ChatMemoryRepository getChatMemoryRepository(JdbcTemplate jdbcTemplate) {
15+
return JdbcChatMemoryRepository.builder()
16+
.jdbcTemplate(jdbcTemplate)
17+
.dialect(new HsqldbChatMemoryRepositoryDialect())
18+
.build();
19+
}
20+
21+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.baeldung.springai.memory;
2+
3+
import org.springframework.http.ResponseEntity;
4+
import org.springframework.web.bind.annotation.PostMapping;
5+
import org.springframework.web.bind.annotation.RequestBody;
6+
import org.springframework.web.bind.annotation.RestController;
7+
8+
import javax.validation.Valid;
9+
10+
@RestController
11+
public class ChatController {
12+
13+
private final ChatService chatService;
14+
15+
public ChatController(ChatService chatService) {
16+
this.chatService = chatService;
17+
}
18+
19+
@PostMapping("/chat")
20+
public ResponseEntity<String> chat(@RequestBody @Valid ChatRequest request) {
21+
String response = chatService.chat(request.getPrompt());
22+
return ResponseEntity.ok(response);
23+
}
24+
25+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.baeldung.springai.memory;
2+
3+
import javax.validation.constraints.NotNull;
4+
5+
public class ChatRequest {
6+
7+
@NotNull
8+
private String prompt;
9+
10+
public String getPrompt() {
11+
return prompt;
12+
}
13+
14+
public void setPrompt(String prompt) {
15+
this.prompt = prompt;
16+
}
17+
18+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.baeldung.springai.memory;
2+
3+
import org.springframework.ai.chat.client.ChatClient;
4+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
5+
import org.springframework.ai.chat.memory.ChatMemory;
6+
import org.springframework.ai.chat.model.ChatModel;
7+
import org.springframework.stereotype.Component;
8+
import org.springframework.web.context.annotation.SessionScope;
9+
10+
import java.util.UUID;
11+
12+
@Component
13+
@SessionScope
14+
public class ChatService {
15+
16+
private final ChatClient chatClient;
17+
private final String conversationId;
18+
19+
public ChatService(ChatModel chatModel, ChatMemory chatMemory) {
20+
this.chatClient = ChatClient.builder(chatModel)
21+
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
22+
.build();
23+
this.conversationId = UUID.randomUUID().toString();
24+
}
25+
26+
public String getConversationId() {
27+
return conversationId;
28+
}
29+
30+
public String chat(String prompt) {
31+
return chatClient.prompt()
32+
.user(userMessage -> userMessage.text(prompt))
33+
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
34+
.call()
35+
.content();
36+
}
37+
38+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
spring:
2+
ai:
3+
openai:
4+
api-key: "<YOUR-API-KEY>"
5+
6+
datasource:
7+
url: jdbc:hsqldb:mem:chatdb
8+
driver-class-name: org.hsqldb.jdbc.JDBCDriver
9+
username: sa
10+
password:
11+
12+
sql:
13+
init:
14+
mode: always
15+
schema-locations: classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-hsqldb.sql
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
<configuration>
2+
<appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
3+
<encoder>
4+
<pattern>[%d{yyyy-MM-dd HH:mm:ss}] [%p] [%c{1}] - %m%n</pattern>
5+
</encoder>
6+
</appender>
7+
8+
<root level="INFO">
9+
<appender-ref ref="CONSOLE" />
10+
</root>
11+
12+
<logger name="org.springframework" level="INFO" additivity="false">
13+
<appender-ref ref="CONSOLE" />
14+
</logger>
15+
</configuration>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.baeldung.springai.memory;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.springframework.ai.chat.memory.ChatMemory;
5+
import org.springframework.beans.factory.annotation.Autowired;
6+
import org.springframework.boot.test.context.SpringBootTest;
7+
import org.springframework.test.context.ActiveProfiles;
8+
9+
import static org.assertj.core.api.Assertions.assertThat;
10+
11+
@SpringBootTest
12+
@ActiveProfiles("memory")
13+
class ChatServiceLiveTest {
14+
15+
private static final String PROMPT_1ST = "Tell me a joke";
16+
private static final String PROMPT_2ND = "Tell me another one";
17+
18+
@Autowired
19+
private ChatMemory chatMemory;
20+
21+
@Autowired
22+
private ChatService chatService;
23+
24+
@Test
25+
void whenChatServiceIsCalledTwice_thenChatMemoryHasCorrectNumberOfEntries() {
26+
String conversationId = chatService.getConversationId();
27+
28+
// 1st request
29+
String response1 = chatService.chat(PROMPT_1ST);
30+
assertThat(response1).isNotEmpty();
31+
assertThat(chatMemory.get(conversationId)).hasSize(2);
32+
33+
// 2nd request
34+
String response2 = chatService.chat(PROMPT_2ND);
35+
assertThat(response2).isNotEmpty();
36+
assertThat(chatMemory.get(conversationId)).hasSize(4);
37+
}
38+
39+
}

0 commit comments

Comments
 (0)