Commit 63484f9

feat: add thinking_budget (#6089)

1 parent dff0ab9 commit 63484f9

File tree

9 files changed: +196 −5 lines changed

docs/backend/sampling_params.md

Lines changed: 27 additions & 0 deletions

@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
 | ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
 | skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
 | custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
+| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |

 ## Examples

@@ -296,3 +297,29 @@ response = requests.post(
 )
 print(response.json())
 ```
+
+### Thinking Budget
+
+Launch a server with `--reasoning-parser`.
+
+```bash
+python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
+```
+
+Send a request:
+
+```python
+import requests
+response = requests.post(
+    "http://localhost:30000/generate",
+    json={
+        "text": "9.11 and 9.8, which is greater?",
+        "sampling_params": {
+            "temperature": 0.3,
+            "max_new_tokens": 256,
+            "thinking_budget": 20,
+        },
+    },
+)
+print(response.json())
+```
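Note: `/generate` returns the raw completion, so the budget can be sanity-checked by re-tokenizing the text that precedes the think-end marker, mirroring what the new test file does against the chat endpoint. A minimal sketch, assuming the server launched above is running and the completion embeds a literal `</think>`:

```python
import requests
from transformers import AutoTokenizer

# Sketch only (not part of the commit): count reasoning tokens client-side.
# Assumptions: server on localhost:30000, response body has a "text" field,
# and the raw completion contains a literal "</think>".
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "9.11 and 9.8, which is greater?",
        "sampling_params": {
            "temperature": 0.3,
            "max_new_tokens": 256,
            "thinking_budget": 20,
        },
    },
)
completion = response.json()["text"]

# Everything before "</think>" is reasoning; re-tokenize it to count tokens.
reasoning = completion.split("</think>")[0]
print(len(tokenizer.encode(reasoning)), "reasoning tokens")
```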

python/sglang/srt/model_executor/model_runner.py

Lines changed: 5 additions & 1 deletion

@@ -1145,7 +1145,9 @@ def sample(
             [self.sample(values, forward_batch) for values in logits_output],
             axis=-1,
         )
-
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
         self._preprocess_logits(logits_output, forward_batch.sampling_info)

         # Sample the next tokens
@@ -1156,6 +1158,8 @@ def sample(
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids

     @property
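The two hooks bracket the existing sampling path: `apply_thinking_budgets` rewrites the logits before `_preprocess_logits`, and `update_thinking_budgets` inspects the sampled ids afterwards. The loop below is a standalone re-enactment of that ordering with a simplified exhaustion rule; `BudgetState` is a stand-in for `SamplingBatchInfo`, and only the two method roles come from the diff:

```python
import torch

# Standalone sketch of the per-step hook order in ModelRunner.sample.
class BudgetState:
    def __init__(self, budget: int, think_end_id: int):
        self.budget = budget
        self.think_end_id = think_end_id
        self.steps = 0

    def apply(self, logits: torch.Tensor) -> None:
        # Hook 1 (before logits preprocessing): once the budget is
        # exhausted, leave only the think-end token reachable.
        self.steps += 1
        if 0 < self.budget < self.steps:
            logits.fill_(float("-inf"))
            logits[self.think_end_id] = 0.0

    def update(self, token_id: int) -> None:
        # Hook 2 (after sampling): seeing think-end deactivates the budget.
        if token_id == self.think_end_id:
            self.budget = -1

state = BudgetState(budget=3, think_end_id=0)
for step in range(6):
    logits = torch.randn(8)       # fake next-token logits
    state.apply(logits)           # before _preprocess_logits
    token = int(logits.argmax())  # stand-in for the real sampler
    state.update(token)           # after sampling
    print(step, token, state.budget)
```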

python/sglang/srt/openai_api/adapter.py

Lines changed: 2 additions & 0 deletions

@@ -529,6 +529,7 @@ def v1_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens,
         "min_new_tokens": request.min_tokens,
+        "thinking_budget": request.thinking_budget,
         "stop": request.stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,
@@ -1101,6 +1102,7 @@ def v1_chat_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens or request.max_completion_tokens,
         "min_new_tokens": request.min_tokens,
+        "thinking_budget": request.thinking_budget,
         "stop": stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,

python/sglang/srt/openai_api/protocol.py

Lines changed: 8 additions & 0 deletions

@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -350,6 +351,13 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description="The maximum number of reasoning tokens that can be generated for a request. "
+        "This setting does not affect the thinking process of models. "
+        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+        "the reasoning content will be truncated and the final response content will be generated immediately.",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
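Since `thinking_budget` is a plain optional field on both request models, it parses like any other pydantic attribute. A sketch of the round-trip, with the import path taken from the diff and illustrative payload values, assuming message dicts coerce as usual:

```python
from sglang.srt.openai_api.protocol import ChatCompletionRequest

# Sketch: the new field defaults to None and accepts an int.
req = ChatCompletionRequest(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
    thinking_budget=20,
)
print(req.thinking_budget)  # 20

default = ChatCompletionRequest(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "hi"}],
)
print(default.thinking_budget)  # None
```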

python/sglang/srt/reasoning_parser.py

Lines changed: 2 additions & 2 deletions

@@ -32,7 +32,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
             normal_text = current_text[end_idx + len(self.think_end_token) :]

             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+                normal_text=normal_text, reasoning_text=reasoning_text
             )

         # Continue with reasoning content
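Dropping `.strip()`/`.rstrip()` keeps the reasoning text byte-identical to what the model emitted, which matters once the budget is verified by re-tokenizing that text: trimming whitespace can change the token count. A small illustration of the effect; any whitespace-sensitive tokenizer shows it, and GPT-2 is used here purely as a small example model:

```python
from transformers import AutoTokenizer

# Illustration only: stripping trailing whitespace changes the token count,
# so the parser must return reasoning text verbatim for exact budget checks.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

reasoning = "Compare 9.11 and 9.8 digit by digit.\n\n"
print(len(tokenizer.encode(reasoning)))          # includes the newline tokens
print(len(tokenizer.encode(reasoning.strip())))  # smaller after stripping
```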

python/sglang/srt/sampling/sampling_batch_info.py

Lines changed: 54 additions & 2 deletions

@@ -30,8 +30,13 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool

+    # Use thinking_budget to truncate thinking
+    num_thinking_tokens: Optional[torch.Tensor] = None
+    think_end_ids: Optional[torch.Tensor] = None
+    thinking_budgets: Optional[torch.Tensor] = None
+
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int = 0
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
-
+        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
+            think_end_ids = torch.tensor(
+                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
+                device, non_blocking=True
+            )
+            thinking_budgets = torch.tensor(
+                [r.sampling_params.thinking_budget or -1 for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+        else:
+            think_end_ids = None
+            num_thinking_tokens = None
+            thinking_budgets = None
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -132,6 +152,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            think_end_ids=think_end_ids,
+            num_thinking_tokens=num_thinking_tokens,
+            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -146,6 +169,35 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
     def __len__(self):
         return len(self.temperatures)

+    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
+        has_budget = self.thinking_budgets > 0
+        if not has_budget.any():
+            return
+        torch.where(
+            has_budget,
+            self.num_thinking_tokens + 1,
+            self.num_thinking_tokens,
+            out=self.num_thinking_tokens,
+        )
+        should_stop = has_budget & (
+            self.num_thinking_tokens - 1 > self.thinking_budgets
+        )
+        next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
+        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
+        if len(batch_indices) > 0:
+            end_token_indices = self.think_end_ids[batch_indices]
+            next_token_logits[batch_indices, end_token_indices] = 0.0
+
+    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
+        if not torch.any(self.thinking_budgets > 0):
+            return
+        torch.where(
+            next_token_ids == self.think_end_ids,
+            torch.tensor(-1, device=self.thinking_budgets.device),
+            self.thinking_budgets,
+            out=self.thinking_budgets,
+        )
+
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
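The masking trick in `apply_thinking_budgets` is worth spelling out: rows whose budget is exhausted get every logit set to `-inf` and then the think-end token's logit set to `0.0`, so it is the only finite choice and any sampler must pick it; `update_thinking_budgets` then flips the budget to `-1` once that token appears, deactivating the row. A standalone toy with a batch of two and a five-token vocabulary; this re-implements the semantics for illustration (masking per row) rather than reusing the exact code above:

```python
import torch

# Toy demonstration of the forced-stop semantics (illustrative values).
think_end_ids = torch.tensor([4, 4])        # think-end token id per request
thinking_budgets = torch.tensor([1, -1])    # row 0 budgeted, row 1 inactive
num_thinking_tokens = torch.tensor([3, 0])  # row 0 already past its budget

logits = torch.randn(2, 5)                  # [batch, vocab]

# Rows that must stop: active budget and counter beyond it.
should_stop = (thinking_budgets > 0) & (num_thinking_tokens - 1 > thinking_budgets)

# Exhausted rows: all logits -> -inf, then think-end -> 0.0, so it is the
# only finite logit and greedy or stochastic sampling must select it.
logits[should_stop] = float("-inf")
rows = torch.nonzero(should_stop, as_tuple=True)[0]
logits[rows, think_end_ids[rows]] = 0.0

next_token_ids = logits.argmax(dim=-1)
print(next_token_ids)       # row 0 is forced to token 4; row 1 is unconstrained

# Once think-end is emitted, the budget flips to -1 and the row deactivates.
thinking_budgets = torch.where(
    next_token_ids == think_end_ids,
    torch.tensor(-1),
    thinking_budgets,
)
print(thinking_budgets)     # row 0 is now -1
```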

python/sglang/srt/sampling/sampling_params.py

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,7 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
+        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -57,6 +58,7 @@ def __init__(
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
+        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ class TestFile:
         TestFile("test_radix_attention.py", 167),
         TestFile("test_reasoning_content.py", 89),
         TestFile("test_enable_thinking.py", 70),
+        TestFile("test_thinking_budget.py", 60),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),

test/srt/test_thinking_budget.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+"""
+Usage:
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
+"""
+
+import unittest
+
+import requests
+from transformers import AutoTokenizer
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestThinkingBudget(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-8B"
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--reasoning-parser",
+                "qwen3",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_chat_completion_with_thinking_budget_20(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 20,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            20,
+            f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+    def test_chat_completion_with_thinking_budget_200(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 200,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            200,
+            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
