
Commit 94d63ce

refactor(tracing): Migrate tracing models from dataclasses to Pydantic for improved type safety
- Replace dataclass implementations with Pydantic BaseModel in tracing_models.py
- Remove manual to_dict() methods in favor of Pydantic's built-in model_dump() functionality
- Simplify attribute access and type handling in responder.py
- Update method calls to use model_dump() and to_api_dict() for consistent serialization
- Improve type safety and reduce manual type conversion code
- Remove redundant attribute checks and simplify conditional logic
1 parent f73ffe8 commit 94d63ce
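
The heart of the change is mechanical: each tracing model drops its hand-written to_dict() and relies on Pydantic validation plus the built-in model_dump(). A minimal sketch of the before/after pattern, using the UsageDetails fields from the diff below (the example values and printed output are illustrative, not taken from the repository):

    from pydantic import BaseModel

    # Before this commit: a @dataclass with a hand-written to_dict() that
    # enumerated every field by name.
    # After: a BaseModel, where model_dump() produces the same dictionary
    # with no manual serialization code.
    class UsageDetails(BaseModel):
        """Token usage details from API response."""

        input: int
        output: int
        total: int
        unit: str
        reasoning_tokens: int | None = None

    usage = UsageDetails(input=120, output=45, total=165, unit="TOKENS")
    print(usage.model_dump())
    # {'input': 120, 'output': 45, 'total': 165, 'unit': 'TOKENS', 'reasoning_tokens': None}

One behavioral difference worth noting: the old to_dict() omitted reasoning_tokens unless it was positive, whereas model_dump() emits every field unless callers opt out with options such as exclude_none=True.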

2 files changed: +28 -77 lines changed

agentle/responses/responder.py

Lines changed: 15 additions & 24 deletions
@@ -974,27 +974,18 @@ def _prepare_trace_input_data(
         tools_count = len(tools)

         # Extract structured output information
-        has_structured_output = (
-            create_response.text is not None
-            and hasattr(create_response.text, "format")
-            and create_response.text.format is not None
-            and hasattr(create_response.text.format, "type")
-            and create_response.text.format.type == "json_schema"
-        )
+        has_structured_output = False
+        if create_response.text is not None and create_response.text.format is not None:
+            format_type = getattr(create_response.text.format, "type", None)
+            has_structured_output = format_type == "json_schema"

         # Extract reasoning information
         reasoning_enabled = create_response.reasoning is not None
         reasoning_effort: str | None = None
-        if create_response.reasoning is not None and hasattr(
-            create_response.reasoning, "effort"
-        ):
+        if create_response.reasoning is not None:
             effort_value = create_response.reasoning.effort
             if effort_value is not None:
-                reasoning_effort = (
-                    effort_value.value
-                    if hasattr(effort_value, "value")
-                    else str(effort_value)
-                )
+                reasoning_effort = effort_value.value

         return TraceInputData(
             input=input_data,
@@ -1191,8 +1182,8 @@ async def _create_tracing_contexts(
             # Create trace context
             trace_gen = client.trace_context(
                 name="responder_api_call",
-                input_data=input_data.to_dict(),
-                metadata=metadata.to_dict(),
+                input_data=input_data.model_dump(),
+                metadata=metadata.to_api_dict(),
             )
             trace_ctx = await trace_gen.__anext__()

@@ -1205,8 +1196,8 @@ async def _create_tracing_contexts(
                 name="response_generation",
                 model=model,
                 provider=metadata.provider,
-                input_data=input_data.to_dict(),
-                metadata=metadata.to_dict(),
+                input_data=input_data.model_dump(),
+                metadata=metadata.to_api_dict(),
             )
             generation_ctx = await generation_gen.__anext__()

@@ -1350,10 +1341,10 @@ async def _update_tracing_success(
         }

         if usage_details:
-            metadata["usage"] = usage_details.to_dict()
+            metadata["usage"] = usage_details.model_dump()

         if cost_details:
-            metadata["cost"] = cost_details.to_dict()
+            metadata["cost"] = cost_details.model_dump()

         # Update contexts for each client
         for ctx in active_contexts:
@@ -1369,10 +1360,10 @@ async def _update_tracing_success(
                 await ctx.client.update_generation(
                     ctx.generation_ctx,
                     output_data=output_data,
-                    usage_details=usage_details.to_dict()
+                    usage_details=usage_details.model_dump()
                     if usage_details
                     else None,
-                    cost_details=cost_details.to_dict()
+                    cost_details=cost_details.model_dump()
                     if cost_details
                     else None,
                     metadata=metadata,
@@ -1457,7 +1448,7 @@ async def _update_tracing_error(
         logger.debug(f"Latency until error: {latency:.3f}s")

         # Prepare error metadata
-        error_metadata = metadata.to_dict()
+        error_metadata = metadata.to_api_dict()
         error_metadata.update(
             {
                 "error_type": type(error).__name__,

agentle/responses/tracing_models.py

Lines changed: 13 additions & 53 deletions
@@ -3,15 +3,15 @@
 from __future__ import annotations

 from collections.abc import AsyncGenerator
-from dataclasses import dataclass
 from typing import Any

+from pydantic import BaseModel, Field, computed_field
+
 from agentle.generations.tracing.otel_client import GenerationContext, TraceContext
 from agentle.generations.tracing.otel_client_type import OtelClientType


-@dataclass
-class TraceInputData:
+class TraceInputData(BaseModel):
     """Structured input data for trace context."""

     input: str | list[dict[str, Any]] | None
@@ -26,34 +26,17 @@ class TraceInputData:
     max_output_tokens: int | None
     stream: bool

-    def to_dict(self) -> dict[str, Any]:
-        """Convert to dictionary for API calls."""
-        return {
-            "input": self.input,
-            "model": self.model,
-            "has_tools": self.has_tools,
-            "tools_count": self.tools_count,
-            "has_structured_output": self.has_structured_output,
-            "reasoning_enabled": self.reasoning_enabled,
-            "reasoning_effort": self.reasoning_effort,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "max_output_tokens": self.max_output_tokens,
-            "stream": self.stream,
-        }
-

-@dataclass
-class TraceMetadata:
+class TraceMetadata(BaseModel):
     """Metadata for trace context."""

     model: str
     provider: str
     base_url: str
-    custom_metadata: dict[str, Any]
+    custom_metadata: dict[str, Any] = Field(default_factory=dict)

-    def to_dict(self) -> dict[str, Any]:
-        """Convert to dictionary for API calls."""
+    def to_api_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for API calls, merging custom_metadata."""
         result = {
             "model": self.model,
             "provider": self.provider,
@@ -63,8 +46,7 @@ def to_dict(self) -> dict[str, Any]:
         return result


-@dataclass
-class UsageDetails:
+class UsageDetails(BaseModel):
     """Token usage details from API response."""

     input: int
@@ -73,21 +55,8 @@ class UsageDetails:
     unit: str
     reasoning_tokens: int | None = None

-    def to_dict(self) -> dict[str, Any]:
-        """Convert to dictionary for API calls."""
-        result = {
-            "input": self.input,
-            "output": self.output,
-            "total": self.total,
-            "unit": self.unit,
-        }
-        if self.reasoning_tokens is not None and self.reasoning_tokens > 0:
-            result["reasoning_tokens"] = self.reasoning_tokens
-        return result
-

-@dataclass
-class CostDetails:
+class CostDetails(BaseModel):
     """Cost calculation details."""

     input: float
@@ -97,28 +66,19 @@ class CostDetails:
     input_tokens: int
     output_tokens: int

-    def to_dict(self) -> dict[str, Any]:
-        """Convert to dictionary for API calls."""
-        return {
-            "input": self.input,
-            "output": self.output,
-            "total": self.total,
-            "currency": self.currency,
-            "input_tokens": self.input_tokens,
-            "output_tokens": self.output_tokens,
-        }

-
-@dataclass
-class TracingContext:
+class TracingContext(BaseModel):
     """Container for a single client's tracing contexts."""

+    model_config = {"arbitrary_types_allowed": True}
+
     client: OtelClientType
     trace_gen: AsyncGenerator[TraceContext | None, None]
     trace_ctx: TraceContext | None
     generation_gen: AsyncGenerator[GenerationContext | None, None]
     generation_ctx: GenerationContext | None

+    @computed_field
     @property
     def client_name(self) -> str:
         """Get the client class name for logging."""

0 commit comments
