Skip to content

Commit 8e2effd

Browse files
authored
feat: trace annotation dataloader (#12699)
* refactor: use dataloader for trace_annotations resolver with filter support * Fix trace annotation sort stability * Fix formatting in Trace.py to pass CI lint check * Fix mypy type error in test_TraceAnnotation.py * Fix default sort to preserve original created_at DESC behavior The dataloader refactor inadvertently changed the default sort from created_at DESC to name ASC. Restore the original default so clients that omit the sort argument still get most-recent-first ordering.
1 parent d6b0135 commit 8e2effd

3 files changed

Lines changed: 86 additions & 14 deletions

File tree

app/schema.graphql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3564,7 +3564,7 @@ type Trace implements Node {
35643564
spans(first: Int!, last: Int, after: String, before: String, rootSpansOnly: Boolean, orphanSpanAsRootSpan: Boolean = true): SpanConnection!
35653565

35663566
"""Annotations associated with the trace."""
3567-
traceAnnotations(sort: TraceAnnotationSort = null): [TraceAnnotation!]!
3567+
traceAnnotations(sort: TraceAnnotationSort, filter: AnnotationFilter = null): [TraceAnnotation!]!
35683568

35693569
"""Summarizes each annotation (by name) associated with the trace"""
35703570
traceAnnotationSummaries(filter: AnnotationFilter = null): [AnnotationSummary!]!

src/phoenix/server/api/types/Trace.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from phoenix.server.api.context import Context
2020
from phoenix.server.api.extensions import RequireForwardPaginationExtension
2121
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
22-
from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
22+
from phoenix.server.api.input_types.TraceAnnotationSort import (
23+
TraceAnnotationColumn,
24+
TraceAnnotationSort,
25+
)
2326
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
2427
from phoenix.server.api.types.CostBreakdown import CostBreakdown
2528
from phoenix.server.api.types.pagination import (
@@ -283,19 +286,24 @@ async def spans(
283286
async def trace_annotations(
284287
self,
285288
info: Info[Context, None],
286-
sort: Optional[TraceAnnotationSort] = None,
289+
sort: Optional[TraceAnnotationSort] = UNSET,
290+
filter: Optional[AnnotationFilter] = None,
287291
) -> list[TraceAnnotation]:
288-
async with info.context.db.read() as session:
289-
stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.id)
290-
if sort:
291-
sort_col = getattr(models.TraceAnnotation, sort.col.value)
292-
if sort.dir is SortDir.desc:
293-
stmt = stmt.order_by(sort_col.desc(), models.TraceAnnotation.id.desc())
294-
else:
295-
stmt = stmt.order_by(sort_col.asc(), models.TraceAnnotation.id.asc())
296-
else:
297-
stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
298-
annotations = await session.scalars(stmt)
292+
annotations = list(await info.context.data_loaders.trace_annotations_by_trace.load(self.id))
293+
sort_key = TraceAnnotationColumn.createdAt.value
294+
sort_descending = True
295+
if filter:
296+
annotations = [
297+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
298+
]
299+
if sort:
300+
sort_key = sort.col.value
301+
sort_descending = sort.dir is SortDir.desc
302+
annotations = sorted(
303+
annotations,
304+
key=lambda annotation: (getattr(annotation, sort_key), annotation.id),
305+
reverse=sort_descending,
306+
)
299307
return [
300308
TraceAnnotation(id=annotation.id, db_record=annotation) for annotation in annotations
301309
]

tests/unit/server/api/types/test_TraceAnnotation.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,67 @@ async def test_annotating_a_trace(
194194
select(models.TraceAnnotation).where(models.TraceAnnotation.id == annotation_id)
195195
)
196196
assert not orm_annotation
197+
198+
199+
async def test_trace_annotations_sort_uses_id_as_tiebreaker(
200+
gql_client: AsyncGraphQLClient,
201+
db: DbSessionFactory,
202+
project_with_a_single_trace_and_span: Any,
203+
) -> None:
204+
async with db() as session:
205+
trace_id = await session.scalar(select(models.Trace.id))
206+
assert trace_id is not None
207+
created_at = datetime.fromisoformat("2021-01-01T00:00:00.000+00:00")
208+
session.add_all(
209+
[
210+
models.TraceAnnotation(
211+
trace_rowid=trace_id,
212+
name="same-name",
213+
label="older-id",
214+
score=0.1,
215+
explanation=None,
216+
metadata_={},
217+
annotator_kind="HUMAN",
218+
created_at=created_at,
219+
updated_at=created_at,
220+
identifier="a",
221+
source="API",
222+
user_id=None,
223+
),
224+
models.TraceAnnotation(
225+
trace_rowid=trace_id,
226+
name="same-name",
227+
label="newer-id",
228+
score=0.2,
229+
explanation=None,
230+
metadata_={},
231+
annotator_kind="HUMAN",
232+
created_at=created_at,
233+
updated_at=created_at,
234+
identifier="b",
235+
source="API",
236+
user_id=None,
237+
),
238+
]
239+
)
240+
241+
response = await gql_client.execute(
242+
query="""
243+
query GetTraceAnnotations($id: ID!) {
244+
node(id: $id) {
245+
... on Trace {
246+
traceAnnotations(sort: {col: createdAt, dir: desc}) {
247+
label
248+
}
249+
}
250+
}
251+
}
252+
""",
253+
variables={"id": str(GlobalID("Trace", str(trace_id)))},
254+
)
255+
assert not response.errors
256+
assert response.data is not None
257+
assert response.data["node"]["traceAnnotations"] == [
258+
{"label": "newer-id"},
259+
{"label": "older-id"},
260+
]

0 commit comments

Comments
 (0)