Skip to content

Commit 11e23ab

Browse files
committed
refactor(conversations): Update conversation store message type handling
- Add support for GeneratedAssistantMessage in conversation store implementations - Update type hints to include generic type parameter for GeneratedAssistantMessage - Modify add_message_async method signatures across conversation store classes - Ensure consistent message type handling across different conversation store implementations - Improve type flexibility for message storage and retrieval
1 parent 56a3112 commit 11e23ab

File tree

7 files changed

+99
-24
lines changed

7 files changed

+99
-24
lines changed

agentle/agents/conversations/callback_conversation_store.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ async def clear_messages(chat_id: str):
4646
from __future__ import annotations
4747

4848
from collections.abc import Awaitable, Callable, Sequence
49-
from typing import override
49+
from typing import Any, override
5050

5151
from agentle.agents.conversations.conversation_store import ConversationStore
5252
from agentle.generations.models.messages.assistant_message import AssistantMessage
5353
from agentle.generations.models.messages.developer_message import DeveloperMessage
54+
from agentle.generations.models.messages.generated_assistant_message import (
55+
GeneratedAssistantMessage,
56+
)
5457
from agentle.generations.models.messages.user_message import UserMessage
5558

5659

@@ -78,10 +81,23 @@ def __init__(
7881
self,
7982
get_callback: Callable[
8083
[str],
81-
Awaitable[Sequence[DeveloperMessage | UserMessage | AssistantMessage]],
84+
Awaitable[
85+
Sequence[
86+
DeveloperMessage
87+
| UserMessage
88+
| AssistantMessage
89+
]
90+
],
8291
],
8392
add_callback: Callable[
84-
[str, DeveloperMessage | UserMessage | AssistantMessage], Awaitable[None]
93+
[
94+
str,
95+
DeveloperMessage
96+
| UserMessage
97+
| AssistantMessage
98+
| GeneratedAssistantMessage[Any],
99+
],
100+
Awaitable[None],
85101
],
86102
clear_callback: Callable[[str], Awaitable[None]],
87103
message_limit: int | None = None,
@@ -146,8 +162,13 @@ async def get_conversation_history_async(
146162
return await self._get_callback(chat_id)
147163

148164
@override
149-
async def add_message_async(
150-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
165+
async def add_message_async[T = Any](
166+
self,
167+
chat_id: str,
168+
message: DeveloperMessage
169+
| UserMessage
170+
| AssistantMessage
171+
| GeneratedAssistantMessage[T],
151172
) -> None:
152173
"""
153174
Add a message to the conversation using the user-provided add callback.

agentle/agents/conversations/conversation_store.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import abc
22
from collections.abc import Sequence
3+
from typing import Any
34

45
from agentle.generations.models.messages.assistant_message import AssistantMessage
56
from agentle.generations.models.messages.developer_message import DeveloperMessage
7+
from agentle.generations.models.messages.generated_assistant_message import GeneratedAssistantMessage
68
from agentle.generations.models.messages.user_message import UserMessage
79

810

@@ -42,8 +44,8 @@ async def get_conversation_history_async(
4244
) -> Sequence[DeveloperMessage | UserMessage | AssistantMessage]: ...
4345

4446
@abc.abstractmethod
45-
async def add_message_async(
46-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
47+
async def add_message_async[T = Any](
48+
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage | GeneratedAssistantMessage[T]
4749
) -> None: ...
4850

4951
@abc.abstractmethod

agentle/agents/conversations/firebase_conversation_store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4-
from typing import TYPE_CHECKING, override
4+
from typing import TYPE_CHECKING, Any, override
55

66
from agentle.agents.conversations.conversation_store import ConversationStore
77
from agentle.generations.models.messages.assistant_message import AssistantMessage
88
from agentle.generations.models.messages.developer_message import DeveloperMessage
9+
from agentle.generations.models.messages.generated_assistant_message import (
10+
GeneratedAssistantMessage,
11+
)
912
from agentle.generations.models.messages.user_message import UserMessage
1013

1114
if TYPE_CHECKING:
@@ -28,8 +31,8 @@ def __init__(
2831
self._collection_name = collection_name
2932

3033
@override
31-
async def add_message_async(
32-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
34+
async def add_message_async[T = Any](
35+
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage | GeneratedAssistantMessage[T]
3336
) -> None:
3437
from google.cloud import firestore
3538

agentle/agents/conversations/json_file_conversation_store.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from agentle.agents.conversations.conversation_store import ConversationStore
88
from agentle.generations.models.messages.assistant_message import AssistantMessage
99
from agentle.generations.models.messages.developer_message import DeveloperMessage
10+
from agentle.generations.models.messages.generated_assistant_message import (
11+
GeneratedAssistantMessage,
12+
)
1013
from agentle.generations.models.messages.user_message import UserMessage
1114

1215

@@ -87,8 +90,12 @@ def _json_serializer(self, obj: Any) -> Any:
8790
# For other non-serializable objects, convert to string
8891
return str(obj)
8992

90-
def _message_to_dict(
91-
self, message: DeveloperMessage | UserMessage | AssistantMessage
93+
def _message_to_dict[T](
94+
self,
95+
message: DeveloperMessage
96+
| UserMessage
97+
| AssistantMessage
98+
| GeneratedAssistantMessage[T],
9299
) -> dict[str, Any]:
93100
"""Convert a message object to a dictionary for JSON serialization."""
94101
# Get the basic message dictionary
@@ -177,8 +184,13 @@ def _dict_to_message(
177184
return UserMessage.model_validate(message_data)
178185

179186
@override
180-
async def add_message_async(
181-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
187+
async def add_message_async[T = Any](
188+
self,
189+
chat_id: str,
190+
message: DeveloperMessage
191+
| UserMessage
192+
| AssistantMessage
193+
| GeneratedAssistantMessage[T],
182194
) -> None:
183195
"""Add a message to the conversation."""
184196
messages_data = self._load_messages(chat_id)

agentle/agents/conversations/local_conversation_store.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from collections.abc import MutableMapping, MutableSequence, Sequence
2-
from typing import override
2+
from typing import Any, override
3+
34
from agentle.agents.conversations.conversation_store import ConversationStore
45
from agentle.generations.models.messages.assistant_message import AssistantMessage
56
from agentle.generations.models.messages.developer_message import DeveloperMessage
7+
from agentle.generations.models.messages.generated_assistant_message import (
8+
GeneratedAssistantMessage,
9+
)
610
from agentle.generations.models.messages.user_message import UserMessage
711

812

913
class LocalConversationStore(ConversationStore):
1014
__messages: MutableMapping[
11-
str, MutableSequence[DeveloperMessage | UserMessage | AssistantMessage]
15+
str,
16+
MutableSequence[DeveloperMessage | UserMessage | AssistantMessage],
1217
]
1318

1419
def __init__(
@@ -20,8 +25,13 @@ def __init__(
2025
self.__messages = {}
2126

2227
@override
23-
async def add_message_async(
24-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
28+
async def add_message_async[T = Any](
29+
self,
30+
chat_id: str,
31+
message: DeveloperMessage
32+
| UserMessage
33+
| AssistantMessage
34+
| GeneratedAssistantMessage[T],
2535
) -> None:
2636
if chat_id not in self.__messages:
2737
self.__messages[chat_id] = []
@@ -41,6 +51,9 @@ async def add_message_async(
4151
# Don't add message if limit reached and not overriding
4252
return
4353

54+
if isinstance(message, GeneratedAssistantMessage):
55+
message = message.to_assistant_message()
56+
4457
self.__messages[chat_id].append(message)
4558

4659
@override

agentle/agents/conversations/mysql_conversation_store.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from agentle.agents.conversations.conversation_store import ConversationStore
99
from agentle.generations.models.messages.assistant_message import AssistantMessage
1010
from agentle.generations.models.messages.developer_message import DeveloperMessage
11+
from agentle.generations.models.messages.generated_assistant_message import (
12+
GeneratedAssistantMessage,
13+
)
1114
from agentle.generations.models.messages.user_message import UserMessage
1215

1316
if TYPE_CHECKING:
@@ -74,7 +77,11 @@ async def _ensure_table_exists(self) -> None:
7477
await conn.commit()
7578

7679
def _message_to_dict(
77-
self, message: DeveloperMessage | UserMessage | AssistantMessage
80+
self,
81+
message: DeveloperMessage
82+
| UserMessage
83+
| AssistantMessage
84+
| GeneratedAssistantMessage[Any],
7885
) -> dict[str, Any]:
7986
"""
8087
Convert a Message object to a dictionary for JSON serialization.
@@ -125,8 +132,13 @@ def _dict_to_message(
125132
return UserMessage.model_validate(message_dict)
126133

127134
@override
128-
async def add_message_async(
129-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
135+
async def add_message_async[T = Any](
136+
self,
137+
chat_id: str,
138+
message: DeveloperMessage
139+
| UserMessage
140+
| AssistantMessage
141+
| GeneratedAssistantMessage[T],
130142
) -> None:
131143
"""
132144
Add a message to the conversation store.

agentle/agents/conversations/postgres_conversation_store.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from agentle.agents.conversations.conversation_store import ConversationStore
88
from agentle.generations.models.messages.assistant_message import AssistantMessage
99
from agentle.generations.models.messages.developer_message import DeveloperMessage
10+
from agentle.generations.models.messages.generated_assistant_message import (
11+
GeneratedAssistantMessage,
12+
)
1013
from agentle.generations.models.messages.user_message import UserMessage
1114

1215
if TYPE_CHECKING:
@@ -84,7 +87,11 @@ async def ensure_table_exists(self) -> None:
8487
""")
8588

8689
def _message_to_dict(
87-
self, message: DeveloperMessage | UserMessage | AssistantMessage
90+
self,
91+
message: DeveloperMessage
92+
| UserMessage
93+
| AssistantMessage
94+
| GeneratedAssistantMessage[Any],
8895
) -> dict[str, Any]:
8996
"""Convert a message object to a dictionary for JSON serialization."""
9097
message_dict = message.model_dump()
@@ -119,8 +126,13 @@ def _dict_to_message(
119126
return UserMessage.model_validate(message_dict)
120127

121128
@override
122-
async def add_message_async(
123-
self, chat_id: str, message: DeveloperMessage | UserMessage | AssistantMessage
129+
async def add_message_async[T = Any](
130+
self,
131+
chat_id: str,
132+
message: DeveloperMessage
133+
| UserMessage
134+
| AssistantMessage
135+
| GeneratedAssistantMessage[T],
124136
) -> None:
125137
"""Add a message to the conversation."""
126138
async with self._pool.acquire() as conn:

0 commit comments

Comments
 (0)