diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/TokenWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/TokenWindowChatMemory.java new file mode 100644 index 00000000000..dbd3d0b515c --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/TokenWindowChatMemory.java @@ -0,0 +1,167 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; +import org.springframework.ai.tokenizer.TokenCountEstimator; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A chat memory implementation that maintains a message window of a specified size, + * ensuring that the total number of tokens does not exceed the specified limit. Messages + * are treated as indivisible units; when eviction is necessary due to exceeding the token + * limit, the oldest complete message is removed. + *

+ * Messages of type {@link SystemMessage} are treated specially: if a new + * {@link SystemMessage} is added, all previous {@link SystemMessage} instances are + * removed from the memory. + * + * @author Sun Yuhan + * @since 1.1.0 + */ +public final class TokenWindowChatMemory implements ChatMemory { + + private static final long DEFAULT_MAX_TOKENS = 128000L; + + private final ChatMemoryRepository chatMemoryRepository; + + private final TokenCountEstimator tokenCountEstimator; + + private final long maxTokens; + + public TokenWindowChatMemory(ChatMemoryRepository chatMemoryRepository, TokenCountEstimator tokenCountEstimator, + Long maxTokens) { + Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null"); + Assert.notNull(tokenCountEstimator, "tokenCountEstimator cannot be null"); + Assert.isTrue(maxTokens > 0, "maxTokens must be greater than 0"); + this.chatMemoryRepository = chatMemoryRepository; + this.tokenCountEstimator = tokenCountEstimator; + this.maxTokens = maxTokens; + } + + @Override + public void add(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + + List memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId); + List processedMessages = process(memoryMessages, messages); + this.chatMemoryRepository.saveAll(conversationId, processedMessages); + } + + @Override + public List get(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return this.chatMemoryRepository.findByConversationId(conversationId); + } + + @Override + public void clear(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.chatMemoryRepository.deleteByConversationId(conversationId); + } + + private List process(List memoryMessages, List newMessages) { + List processedMessages = new ArrayList<>(); + + Set memoryMessagesSet = new HashSet<>(memoryMessages); + boolean hasNewSystemMessage = newMessages.stream() + .filter(SystemMessage.class::isInstance) + .anyMatch(message -> !memoryMessagesSet.contains(message)); + + memoryMessages.stream() + .filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage)) + .forEach(processedMessages::add); + + processedMessages.addAll(newMessages); + + int tokens = processedMessages.stream() + .mapToInt(processedMessage -> tokenCountEstimator.estimate(processedMessage.getText())) + .sum(); + + if (tokens <= this.maxTokens) { + return processedMessages; + } + + int removeMessageIndex = 0; + while (tokens > this.maxTokens && !processedMessages.isEmpty() + && removeMessageIndex < processedMessages.size()) { + if (processedMessages.get(removeMessageIndex) instanceof SystemMessage) { + if (processedMessages.size() == 1) { + break; + } + removeMessageIndex += 1; + continue; + } + Message removedMessage = processedMessages.remove(removeMessageIndex); + tokens -= tokenCountEstimator.estimate(removedMessage.getText()); + } + + return processedMessages; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ChatMemoryRepository chatMemoryRepository; + + private TokenCountEstimator tokenCountEstimator; + + private long maxTokens = DEFAULT_MAX_TOKENS; + + private Builder() { + } + + public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) { + this.chatMemoryRepository = chatMemoryRepository; + return this; + } + + public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) { + this.tokenCountEstimator = tokenCountEstimator; + return this; + } + + public Builder maxTokens(long maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public TokenWindowChatMemory build() { + if (this.chatMemoryRepository == null) { + this.chatMemoryRepository = new InMemoryChatMemoryRepository(); + } + if (this.tokenCountEstimator == null) { + this.tokenCountEstimator = new JTokkitTokenCountEstimator(); + } + return new TokenWindowChatMemory(this.chatMemoryRepository, this.tokenCountEstimator, this.maxTokens); + } + + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/TokenWindowChatMemoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/TokenWindowChatMemoryTests.java new file mode 100644 index 00000000000..940926619c7 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/TokenWindowChatMemoryTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link TokenWindowChatMemory}. + * + * @author Sun Yuhan + * @since 1.1.0 + */ +public class TokenWindowChatMemoryTests { + + private final TokenWindowChatMemory chatMemory = TokenWindowChatMemory.builder().build(); + + @Test + void zeroMaxMessagesNotAllowed() { + assertThatThrownBy(() -> TokenWindowChatMemory.builder().maxTokens(0).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("maxTokens must be greater than 0"); + } + + @Test + void negativeMaxTokensNotAllowed() { + assertThatThrownBy(() -> TokenWindowChatMemory.builder().maxTokens(-1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("maxTokens must be greater than 0"); + } + + @Test + void handleMultipleMessagesInConversation() { + String conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("I, Robot"), new UserMessage("Hello")); + + this.chatMemory.add(conversationId, messages); + + assertThat(this.chatMemory.get(conversationId)).containsAll(messages); + + this.chatMemory.clear(conversationId); + + assertThat(this.chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void handleSingleMessageInConversation() { + String conversationId = UUID.randomUUID().toString(); + Message message = new UserMessage("Hello"); + + this.chatMemory.add(conversationId, message); + + assertThat(this.chatMemory.get(conversationId)).contains(message); + + this.chatMemory.clear(conversationId); + + assertThat(this.chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void nullConversationIdNotAllowed() { + assertThatThrownBy(() -> this.chatMemory.add(null, List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.get(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.clear(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void emptyConversationIdNotAllowed() { + assertThatThrownBy(() -> this.chatMemory.add("", List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.get("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> this.chatMemory.clear("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void nullMessagesNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> this.chatMemory.add(conversationId, (List) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot be null"); + } + + @Test + void nullMessageNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> this.chatMemory.add(conversationId, (Message) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("message cannot be null"); + } + + @Test + void messagesWithNullElementsNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + List messagesWithNull = new ArrayList<>(); + messagesWithNull.add(null); + + assertThatThrownBy(() -> this.chatMemory.add(conversationId, messagesWithNull)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + + @Test + void customMaxTokens() { + String conversationId = UUID.randomUUID().toString(); + int customMaxTokens = 15; + + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(customMaxTokens).build(); + + List messages = List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"), + new UserMessage("Message 2"), new AssistantMessage("Response 2"), new UserMessage("Message 3")); + + customChatMemory.add(conversationId, messages); + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(5); + } + + @Test + void customTokenCountEstimator() { + String conversationId = UUID.randomUUID().toString(); + + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder() + .tokenCountEstimator(new JTokkitTokenCountEstimator()) + .build(); + + List messages = List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"), + new UserMessage("Message 2"), new AssistantMessage("Response 2"), new UserMessage("Message 3")); + + customChatMemory.add(conversationId, messages); + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(5); + } + + @Test + void noEvictionWhenMessagesWithinLimit() { + int maxTokens = 10; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>(List.of(new UserMessage("Message 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new UserMessage("Message 1"), new AssistantMessage("Response 1"), + new UserMessage("Message 2")); + } + + @Test + void evictionWhenMessagesExceedLimit() { + int maxTokens = 3; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>(List.of(new UserMessage("Message 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(1); + assertThat(result).containsExactly(new UserMessage("Message 2")); + } + + @Test + void systemMessageIsPreservedDuringEviction() { + int maxTokens = 9; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>(List.of(new SystemMessage("System 1"), + new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new SystemMessage("System 1"), new UserMessage("Message 2"), + new AssistantMessage("Response 2")); + } + + @Test + void multipleSystemMessagesArePreservedDuringEviction() { + int maxTokens = 9; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>(List.of(new SystemMessage("System 1"), + new SystemMessage("System 2"), new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new SystemMessage("System 1"), new SystemMessage("System 2"), + new AssistantMessage("Response 2")); + } + + @Test + void emptyMessageList() { + String conversationId = UUID.randomUUID().toString(); + + List result = this.chatMemory.get(conversationId); + + assertThat(result).isEmpty(); + } + + @Test + void oldSystemMessagesAreRemovedWhenNewOneAdded() { + int maxTokens = 3; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new SystemMessage("System 1"), new SystemMessage("System 2"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>(List.of(new SystemMessage("System 3"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(1); + assertThat(result).containsExactly(new SystemMessage("System 3")); + } + + @Test + void mixedMessagesWithLimitEqualToSystemMessageCount() { + int maxTokens = 6; + TokenWindowChatMemory customChatMemory = TokenWindowChatMemory.builder().maxTokens(maxTokens).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new SystemMessage("System 1"), new SystemMessage("System 2"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(2); + assertThat(result).containsExactly(new SystemMessage("System 1"), new SystemMessage("System 2")); + } + +}