@@ -38,19 +38,6 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
         self._is_completion = is_completion
         self._kwargs = kwargs

-    def _extract_token_chunk(self, chunk):
-        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
-        if not self._dd_span._get_ctx_item("openai_stream_magic"):
-            return
-        choice = getattr(chunk, "choices", [None])[0]
-        if not getattr(choice, "finish_reason", None):
-            return
-        try:
-            usage_chunk = next(self.__wrapped__)
-            self._streamed_chunks[0].insert(0, usage_chunk)
-        except (StopIteration, GeneratorExit):
-            pass
-

 class TracedOpenAIStream(BaseTracedOpenAIStream):
     def __enter__(self):
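A side note on why the helper leaves the shared base class: a synchronous `next()` cannot drive an async stream, so the sync and async subclasses each need their own variant (added in the hunks below). A quick illustration, with a hypothetical `agen` that is not part of the integration:

```python
async def agen():
    yield 1

try:
    next(agen())  # async generators do not implement __next__
except TypeError as err:
    print(err)    # 'async_generator' object is not an iterator
```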
@@ -98,6 +85,18 @@ def __next__(self):
                 self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise

+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = next(self)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            return

 class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
     async def __aenter__(self):
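For readers following along, here is a minimal sketch of the peek-ahead pattern this method implements, using hypothetical `SimpleNamespace` stand-ins rather than the real OpenAI SDK chunk types. It assumes the stream ends with a usage-only chunk, which OpenAI sends when the request is created with `stream_options={"include_usage": True}`:

```python
from types import SimpleNamespace as NS

def chunks():
    yield NS(choices=[NS(finish_reason=None)], usage=None)
    yield NS(choices=[NS(finish_reason="stop")], usage=None)
    # Trailing usage-only chunk, sent when the request was created with
    # stream_options={"include_usage": True}.
    yield NS(choices=[], usage=NS(prompt_tokens=3, completion_tokens=2))

stream = chunks()
streamed_chunks = [[]]  # stands in for self._streamed_chunks
for chunk in stream:
    streamed_chunks[0].append(chunk)
    if chunk.choices and chunk.choices[0].finish_reason:
        try:
            usage_chunk = next(stream)                 # peek one chunk ahead
            streamed_chunks[0].insert(0, usage_chunk)  # prepend for later aggregation
        except StopIteration:
            pass

print(streamed_chunks[0][0].usage)  # the usage chunk now sits at the front
```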
@@ -107,11 +106,11 @@ async def __aenter__(self):
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

-    def __aiter__(self):
+    async def __aiter__(self):
         exception_raised = False
         try:
-            for chunk in self.__wrapped__:
-                self._extract_token_chunk(chunk)
+            async for chunk in self.__wrapped__:
+                await self._extract_token_chunk(chunk)
                 yield chunk
                 _loop_handler(self._dd_span, chunk, self._streamed_chunks)
         except Exception:
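The signature change above is forced by Python's grammar: a body that combines `async for` with `yield` must be an `async def` (an async generator function), since `async for` inside a plain `def` generator is a SyntaxError. A self-contained sketch with a hypothetical `Wrapped` class, not the integration's:

```python
import asyncio

class Wrapped:
    def __init__(self, inner):
        self._inner = inner

    async def __aiter__(self):
        # Async generator: calling __aiter__() returns an async iterator,
        # which is exactly what `async for` expects.
        async for item in self._inner:
            yield item

async def numbers():
    for i in range(3):
        yield i

async def main():
    async for n in Wrapped(numbers()):
        print(n)  # 0, 1, 2

asyncio.run(main())
```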
@@ -128,8 +127,8 @@ def __aiter__(self):

     async def __anext__(self):
         try:
-            chunk = await self.__wrapped__.__anext__()
-            self._extract_token_chunk(chunk)
+            chunk = await anext(self.__wrapped__)
+            await self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
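`anext()` is the async counterpart of `next()` and exists as a builtin only since Python 3.10, so this spelling assumes a 3.10+ target (or a backport shim). A quick equivalence check:

```python
import asyncio

async def agen():
    yield "chunk"

async def main():
    it = agen()
    print(await anext(it))           # Python 3.10+ builtin
    it = agen()
    print(await it.__anext__())      # the dunder spelling it replaces
    it = agen()
    await anext(it)                  # consume the only item
    print(await anext(it, "done"))   # optional default instead of StopAsyncIteration

asyncio.run(main())
```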
@@ -145,6 +144,19 @@ async def __anext__(self):
                 self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise

+    async def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = await anext(self)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopAsyncIteration, GeneratorExit):
+            return
+

 def _compute_token_count(content, model):
     # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
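And the async twin of the earlier sync sketch, again with stand-in chunk objects rather than real SDK types:

```python
import asyncio
from types import SimpleNamespace as NS

async def chunks():
    yield NS(choices=[NS(finish_reason="stop")], usage=None)
    yield NS(choices=[], usage=NS(prompt_tokens=3, completion_tokens=2))

async def main():
    stream = chunks()
    streamed_chunks = [[]]
    async for chunk in stream:
        streamed_chunks[0].append(chunk)
        if chunk.choices and chunk.choices[0].finish_reason:
            try:
                usage_chunk = await anext(stream)  # async peek-ahead
                streamed_chunks[0].insert(0, usage_chunk)
            except StopAsyncIteration:
                pass
    print(streamed_chunks[0][0].usage)

asyncio.run(main())
```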