
Commit 35fe7b5

feat(openai): add token usage stream options to request (#11606)
This PR adds special casing such that any user's streamed OpenAI chat/completion request will, unless explicitly specified otherwise, include token usage as part of the streamed response by default.

### Motivation

OpenAI streamed responses have historically not provided token usage details as part of the streamed response. However, earlier this year OpenAI added a `stream_options: {"include_usage": True}` kwarg option to the chat/completions API that returns token usage details as an additional stream chunk at the end of the streamed response. If this kwarg option is not specified by the user, token usage is not provided by OpenAI, and our current behavior is a best effort to either 1) use the `tiktoken` library to calculate token counts, or 2) use a very crude heuristic to estimate token counts. Neither alternative is ideal, as neither takes function/tool calling into account. **It is simpler and more accurate to request the token counts from OpenAI directly.**

### Proposed design

There are two major components to this feature:

1. If a user does not specify `stream_options: {"include_usage": True}` as a kwarg on the chat/completions call, we manually insert it into the kwargs before the request is made.
2. If the option is added on the integration side rather than by the user, the returned streamed response will include an additional chunk (with empty content) at the end containing token usage information. To avoid disrupting user applications with one more chunk (with different content/fields) than expected, the integration automatically extracts that last chunk under the hood.

Note: if a user explicitly specifies `stream_options: {"include_usage": False}`, we respect their intent and avoid adding token usage to the kwargs. The release note will state that we cannot guarantee 100% accurate token counts in this case.

### Streamed reading logic change

Additionally, we change the `__iter__`/`__aiter__` methods of our traced streamed responses. Previously we returned the traced streamed response itself (and relied on the underlying `__next__`/`__anext__` methods), but to ensure spans are finished even if the streamed response is not fully consumed, the `__iter__`/`__aiter__` methods now implement the stream consumption themselves using try/except/finally. Two caveats:

1. This only helps when users iterate via `__iter__()`/`__aiter__()`; directly calling `__next__()`/`__anext__()` does not let us know when the overall response is fully consumed.
2. Users who use `__aiter__()` and break early are still responsible for calling `resp.close()`, since asynchronous generators are not automatically closed when the context manager exits (closing is deferred until `close()` is called manually or by the garbage collector).

### Testing

This PR simplifies the existing OpenAI streamed completion/chat completion tests (using snapshots where possible instead of large numbers of tedious assertions) and adds coverage for the token extraction behavior: existing tests drop the `include_usage: True` option to assert that automatic extraction works, and a couple of new tests assert the original behavior when `include_usage: False` is explicitly set.
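For context, this is roughly what the opt-in looks like from the OpenAI SDK side. The snippet below is an illustrative sketch only (assuming `openai>=1.26` and a configured API key), not code from this PR: with `include_usage` enabled, the stream ends with one extra chunk that has no choices and carries only token usage.

```python
# Illustrative sketch (assumes openai>=1.26 and a configured API key); not code
# from this PR, just the SDK behavior the integration relies on.
from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Who won the World Series in 2020?"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:
        # Normal content chunks: usage is None, choices carry the deltas.
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:
        # Extra final chunk: empty choices, token usage only. This is the chunk
        # the integration now consumes on the user's behalf.
        print("\ntokens:", chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)
```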
## Checklist

- [x] PR author has checked that all the criteria below are met
  - The PR description includes an overview of the change
  - The PR description articulates the motivation for the change
  - The change includes tests OR the PR description describes a testing strategy
  - The PR description notes risks associated with the change, if any
  - Newly-added code is easy to change
  - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
  - The change includes or references documentation updates if necessary
  - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist

- [x] Reviewer has checked that all the criteria below are met
  - Title is accurate
  - All changes are related to the pull request's stated goal
  - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes
  - Testing strategy adequately addresses listed risks
  - Newly-added code is easy to change
  - Release note makes sense to a user of the library
  - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment
  - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
1 parent 5581f73 commit 35fe7b5

8 files changed: +305 / -237 lines changed

ddtrace/contrib/internal/openai/_endpoint_hooks.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -255,6 +255,14 @@ def _record_request(self, pin, integration, span, args, kwargs):
             span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
             span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
             span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
+        if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
+            if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
+                # Only perform token chunk auto-extraction if this option is not explicitly set
+                return
+            span._set_ctx_item("_dd.auto_extract_token_chunk", True)
+            stream_options = kwargs.get("stream_options", {})
+            stream_options["include_usage"] = True
+            kwargs["stream_options"] = stream_options
 
     def _record_response(self, pin, integration, span, args, kwargs, resp, error):
         resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
```
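To make the hook's effect concrete, here is a minimal standalone sketch of the same decision applied to a plain kwargs dict. `inject_include_usage` is an illustrative helper written for this note, not part of the integration.

```python
# Illustration only (not the integration's code): the effect of the hook above
# on the outgoing request kwargs for a streamed call.
def inject_include_usage(kwargs):
    """Mimic the decision in _record_request: respect an explicit setting, otherwise opt in."""
    if not kwargs.get("stream"):
        return kwargs
    stream_options = kwargs.get("stream_options", {})
    if stream_options.get("include_usage") is not None:
        # The user set include_usage explicitly (True or False): leave the request untouched.
        return kwargs
    stream_options["include_usage"] = True
    kwargs["stream_options"] = stream_options
    return kwargs


print(inject_include_usage({"model": "gpt-3.5-turbo", "stream": True}))
# {'model': 'gpt-3.5-turbo', 'stream': True, 'stream_options': {'include_usage': True}}
print(inject_include_usage({"stream": True, "stream_options": {"include_usage": False}}))
# left as-is, so token counts may fall back to estimation
```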

ddtrace/contrib/internal/openai/utils.py

Lines changed: 66 additions & 3 deletions
```diff
@@ -48,11 +48,28 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)
 
     def __iter__(self):
-        return self
+        exception_raised = False
+        try:
+            for chunk in self.__wrapped__:
+                self._extract_token_chunk(chunk)
+                yield chunk
+                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
+        except Exception:
+            self._dd_span.set_exc_info(*sys.exc_info())
+            exception_raised = True
+            raise
+        finally:
+            if not exception_raised:
+                _process_finished_stream(
+                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
+                )
+            self._dd_span.finish()
+            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
 
     def __next__(self):
         try:
             chunk = self.__wrapped__.__next__()
+            self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopIteration:
@@ -68,6 +85,22 @@ def __next__(self):
             self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise
 
+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            # Only the second-last chunk in the stream with token usage enabled will have finish_reason set
+            return
+        try:
+            # User isn't expecting last token chunk to be present since it's not part of the default streamed response,
+            # so we consume it and extract the token usage metadata before it reaches the user.
+            usage_chunk = self.__wrapped__.__next__()
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            return
+
 
 class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
     async def __aenter__(self):
@@ -77,12 +110,29 @@ async def __aenter__(self):
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
 
-    def __aiter__(self):
-        return self
+    async def __aiter__(self):
+        exception_raised = False
+        try:
+            async for chunk in self.__wrapped__:
+                await self._extract_token_chunk(chunk)
+                yield chunk
+                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
+        except Exception:
+            self._dd_span.set_exc_info(*sys.exc_info())
+            exception_raised = True
+            raise
+        finally:
+            if not exception_raised:
+                _process_finished_stream(
+                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
+                )
+            self._dd_span.finish()
+            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
 
     async def __anext__(self):
         try:
             chunk = await self.__wrapped__.__anext__()
+            await self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
@@ -98,6 +148,19 @@ async def __anext__(self):
             self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise
 
+    async def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = await self.__wrapped__.__anext__()
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopAsyncIteration, GeneratorExit):
+            return
+
 
 def _compute_token_count(content, model):
     # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
```
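A standalone way to see why the `__iter__`/`__aiter__` change matters is sketched below. This is a simplified illustration, not the integration's real code: `FakeSpan`, `process_chunks`, and `TracedStream` are stand-ins for the actual span and helpers. Because the wrapper consumes the stream inside a generator, the `finally` block still runs when the caller abandons or closes the stream early, so the span is always finished.

```python
# Simplified, illustrative sketch of the generator-based consumption pattern above.
# FakeSpan, process_chunks, and TracedStream are stand-ins, not the real integration code.
import sys


class FakeSpan:
    def set_exc_info(self, exc_type, exc_val, exc_tb):
        print("recorded error:", exc_val)

    def finish(self):
        print("span finished")


def process_chunks(chunks):
    print("processed %d chunks" % len(chunks))


class TracedStream:
    def __init__(self, wrapped, span):
        self._wrapped = wrapped
        self._span = span
        self._chunks = []

    def __iter__(self):
        # Consuming the wrapped stream here (instead of returning self) means the
        # finally block runs even if the caller stops iterating early.
        exception_raised = False
        try:
            for chunk in self._wrapped:
                self._chunks.append(chunk)
                yield chunk
        except Exception:
            self._span.set_exc_info(*sys.exc_info())
            exception_raised = True
            raise
        finally:
            if not exception_raised:
                process_chunks(self._chunks)
            self._span.finish()


stream = TracedStream(iter(["chunk-1", "chunk-2", "chunk-3"]), FakeSpan())
gen = iter(stream)
print(next(gen))  # chunk-1
gen.close()       # abandoning the stream early still runs finally: chunks processed, span finished
```

For the async variant, as noted in the PR description, callers that break out early still need to call `resp.close()` themselves, since asynchronous generators are only closed via `aclose()` or garbage collection.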
Lines changed: 6 additions & 0 deletions
```diff
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    openai: Introduces automatic extraction of token usage from streamed chat completions.
+    Unless ``stream_options: {"include_usage": False}`` is explicitly set on your streamed chat completion request,
+    the OpenAI integration will add ``stream_options: {"include_usage": True}`` to your request and automatically extract the token usage chunk from the streamed response.
```

tests/contrib/openai/test_openai_llmobs.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -518,11 +518,17 @@ async def test_chat_completion_azure_async(
             )
         )
 
-    def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
+    @pytest.mark.skipif(
+        parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+    )
+    def test_chat_completion_stream_explicit_no_tokens(
+        self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer
+    ):
         """Ensure llmobs records are emitted for chat completion endpoints when configured.
 
         Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
         """
+
         with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
             with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
                 with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
@@ -534,7 +540,11 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
                     expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
                     client = openai.OpenAI()
                     resp = client.chat.completions.create(
-                        model=model, messages=input_messages, stream=True, user="ddtrace-test"
+                        model=model,
+                        messages=input_messages,
+                        stream=True,
+                        user="ddtrace-test",
+                        stream_options={"include_usage": False},
                     )
                     for chunk in resp:
                         resp_model = chunk.model
@@ -547,7 +557,7 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
                 model_provider="openai",
                 input_messages=input_messages,
                 output_messages=[{"content": expected_completion, "role": "assistant"}],
-                metadata={"stream": True, "user": "ddtrace-test"},
+                metadata={"stream": True, "stream_options": {"include_usage": False}, "user": "ddtrace-test"},
                 token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
                 tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
             )
@@ -557,20 +567,14 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
         parse_version(openai_module.version.VERSION) < (1, 26, 0), reason="Streamed tokens available in 1.26.0+"
     )
     def test_chat_completion_stream_tokens(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
-        """
-        Ensure llmobs records are emitted for chat completion endpoints when configured
-        with the `stream_options={"include_usage": True}`.
-        Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
-        """
+        """Assert that streamed token chunk extraction logic works when options are not explicitly passed from user."""
         with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
             model = "gpt-3.5-turbo"
             resp_model = model
             input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}]
             expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
             client = openai.OpenAI()
-            resp = client.chat.completions.create(
-                model=model, messages=input_messages, stream=True, stream_options={"include_usage": True}
-            )
+            resp = client.chat.completions.create(model=model, messages=input_messages, stream=True)
             for chunk in resp:
                 resp_model = chunk.model
             span = mock_tracer.pop_traces()[0][0]
@@ -671,7 +675,6 @@ def test_chat_completion_tool_call_stream(self, openai, ddtrace_global_config, m
             messages=[{"role": "user", "content": chat_completion_input_description}],
             user="ddtrace-test",
             stream=True,
-            stream_options={"include_usage": True},
         )
         for chunk in resp:
             resp_model = chunk.model
```
