Skip to content

Commit b2fff76

Browse files
committed
Fix helper chunk extraction method to be async compatible
1 parent 23e706b commit b2fff76

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

ddtrace/contrib/internal/openai/utils.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,6 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
3838
self._is_completion = is_completion
3939
self._kwargs = kwargs
4040

41-
def _extract_token_chunk(self, chunk):
42-
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
43-
if not self._dd_span._get_ctx_item("openai_stream_magic"):
44-
return
45-
choice = getattr(chunk, "choices", [None])[0]
46-
if not getattr(choice, "finish_reason", None):
47-
return
48-
try:
49-
usage_chunk = next(self.__wrapped__)
50-
self._streamed_chunks[0].insert(0, usage_chunk)
51-
except (StopIteration, GeneratorExit):
52-
pass
53-
5441

5542
class TracedOpenAIStream(BaseTracedOpenAIStream):
5643
def __enter__(self):
@@ -98,6 +85,18 @@ def __next__(self):
9885
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
9986
raise
10087

88+
def _extract_token_chunk(self, chunk):
89+
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
90+
if not self._dd_span._get_ctx_item("openai_stream_magic"):
91+
return
92+
choice = getattr(chunk, "choices", [None])[0]
93+
if not getattr(choice, "finish_reason", None):
94+
return
95+
try:
96+
usage_chunk = next(self)
97+
self._streamed_chunks[0].insert(0, usage_chunk)
98+
except (StopIteration, GeneratorExit):
99+
return
101100

102101
class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
103102
async def __aenter__(self):
@@ -107,11 +106,11 @@ async def __aenter__(self):
107106
async def __aexit__(self, exc_type, exc_val, exc_tb):
108107
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
109108

110-
def __aiter__(self):
109+
async def __aiter__(self):
111110
exception_raised = False
112111
try:
113-
for chunk in self.__wrapped__:
114-
self._extract_token_chunk(chunk)
112+
async for chunk in self.__wrapped__:
113+
await self._extract_token_chunk(chunk)
115114
yield chunk
116115
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
117116
except Exception:
@@ -128,8 +127,8 @@ def __aiter__(self):
128127

129128
async def __anext__(self):
130129
try:
131-
chunk = await self.__wrapped__.__anext__()
132-
self._extract_token_chunk(chunk)
130+
chunk = await anext(self.__wrapped__)
131+
await self._extract_token_chunk(chunk)
133132
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
134133
return chunk
135134
except StopAsyncIteration:
@@ -145,6 +144,19 @@ async def __anext__(self):
145144
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
146145
raise
147146

147+
async def _extract_token_chunk(self, chunk):
148+
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
149+
if not self._dd_span._get_ctx_item("openai_stream_magic"):
150+
return
151+
choice = getattr(chunk, "choices", [None])[0]
152+
if not getattr(choice, "finish_reason", None):
153+
return
154+
try:
155+
usage_chunk = await anext(self)
156+
self._streamed_chunks[0].insert(0, usage_chunk)
157+
except (StopAsyncIteration, GeneratorExit):
158+
return
159+
148160

149161
def _compute_token_count(content, model):
150162
# type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]

0 commit comments

Comments (0)