Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ async def run_agent_stream(
# Create session (with service session support)
if config.use_service_session:
supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId")
session = AgentSession(service_session_id=supplied_thread_id)
session = AgentSession(session_id=thread_id, service_session_id=supplied_thread_id)
else:
session = AgentSession()
session = AgentSession(session_id=thread_id)

# Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation)
base_metadata: dict[str, Any] = {
Expand Down
2 changes: 2 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
self.client = client or SimpleNamespace(function_invocation_configuration=None)
self.messages_received: list[Any] = []
self.tools_received: list[Any] | None = None
self.last_session: AgentSession | None = None

@overload
def run(
Expand Down Expand Up @@ -216,6 +217,7 @@ def run(

async def _stream() -> AsyncIterator[AgentResponseUpdate]:
self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type]
self.last_session = session
self.tools_received = kwargs.get("tools")
for update in self.updates:
yield update
Expand Down
87 changes: 87 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,3 +1640,90 @@ def test_reasoning_distinct_ids_close_previous_block(self):
# close: MsgEnd(block2) + End(block2)
assert isinstance(close[0], ReasoningMessageEndEvent)
assert close[0].message_id == "block2"


async def test_session_id_matches_thread_id():
"""Session created by run_agent_stream uses the client thread_id as session_id."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"thread_id": "my-thread-123",
"run_id": "run-1",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "my-thread-123"


async def test_session_id_matches_camel_case_thread_id():
"""Session uses threadId (camelCase) as session_id when snake_case is absent."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"threadId": "camel-thread-456",
"run_id": "run-2",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "camel-thread-456"


async def test_session_id_matches_thread_id_with_service_session():
"""Session uses thread_id as session_id even when use_service_session is enabled."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub, use_service_session=True)

payload = {
"thread_id": "service-thread-789",
"run_id": "run-3",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "service-thread-789"
assert stub.last_session.service_session_id == "service-thread-789"


async def test_session_id_generated_when_no_thread_id():
"""Session gets a generated UUID as session_id when no thread_id is provided."""
import uuid

from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"run_id": "run-4",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
# Should be a valid UUID (auto-generated)
uuid.UUID(stub.last_session.session_id)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Loading