feat: add thinking_budget (version 2) #6208

Open · wants to merge 15 commits into base: main
27 changes: 27 additions & 0 deletions docs/backend/sampling_params.md
@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
| ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
| skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
| custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |

## Examples

@@ -296,3 +297,29 @@ response = requests.post(
)
print(response.json())
```

### Thinking Budget

Launch a server with `--reasoning-parser`.

```bash
python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
```

Send a request:

```python
import requests
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "9.11 and 9.8, which is greater?",
"sampling_params": {
"temperature": 0.3,
"max_new_tokens": 256,
"thinking_budget": 20,
},
},
)
print(response.json())
```
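
The same parameter is also accepted by the OpenAI-compatible chat completions endpoint (see the `openai_api` changes below), and it is dropped when thinking is disabled via `chat_template_kwargs`. A minimal sketch mirroring the new test in this PR, assuming the Qwen3 server launched above:

```python
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [
            {"role": "user", "content": "9.11 and 9.8, which is greater?"}
        ],
        "temperature": 0,
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        # Ignored if enable_thinking is False (see adapter.py below).
        "thinking_budget": 20,
    },
)
print(response.json()["choices"][0]["message"]["reasoning_content"])
```
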
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -1271,7 +1271,9 @@ def sample(
[self.sample(values, forward_batch) for values in logits_output],
axis=-1,
)

sampling_info = forward_batch.sampling_info
if sampling_info.thinking_budgets is not None:
sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
self._preprocess_logits(logits_output, forward_batch.sampling_info)

# Sample the next tokens
@@ -1282,6 +1284,8 @@
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
)
if sampling_info.thinking_budgets is not None:
sampling_info.update_thinking_budgets(next_token_ids)
return next_token_ids

@property
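
Taken together, the two hooks above bracket the sampling call: `apply_thinking_budgets` runs before logits preprocessing and forces the end-of-thinking token for any request whose budget is exhausted, while `update_thinking_budgets` runs after sampling and deactivates the budget once that token has actually been emitted. A schematic sketch of that ordering (placeholder names, not the real `ModelRunner` code):

```python
# Schematic only: compute_logits, sample_from, and decode_step are placeholders;
# the two thinking-budget hooks are the ones added in this PR.
def decode_step(model, batch, sampling_info):
    logits = compute_logits(model, batch)  # shape: [batch_size, vocab_size]

    if sampling_info.thinking_budgets is not None:
        # Force `</think>` for requests whose reasoning budget is used up.
        sampling_info.apply_thinking_budgets(logits)

    next_token_ids = sample_from(logits, sampling_info)

    if sampling_info.thinking_budgets is not None:
        # Deactivate the budget once `</think>` has actually been sampled.
        sampling_info.update_thinking_budgets(next_token_ids)

    return next_token_ids
```
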
6 changes: 6 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -1128,10 +1128,16 @@ def v1_chat_generate_request(
lora_paths.append(request.lora_path)
prompts.append(prompt)

thinking_budget = request.thinking_budget
enable_thinking = _get_enable_thinking_from_request(request)
if not enable_thinking:
thinking_budget = None

sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"thinking_budget": thinking_budget,
"stop": stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
1 change: 1 addition & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -383,6 +383,7 @@ def set_tool_choice_default(cls, values):
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
thinking_budget: Optional[int] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
4 changes: 2 additions & 2 deletions python/sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text = text.replace(self.think_start_token, "").strip()
text = text.replace(self.think_start_token, "")
if self.think_end_token not in text:
# Assume reasoning was truncated before `</think>` token
return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
normal_text = current_text[end_idx + len(self.think_end_token) :]

return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
normal_text=normal_text, reasoning_text=reasoning_text
)

# Continue with reasoning content
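
Dropping the `strip()`/`rstrip()` calls means the reasoning text is now returned verbatim, including trailing whitespace. That appears to matter for the thinking-budget test added later in this PR, which asserts that `reasoning_content` re-tokenizes to exactly the budgeted number of tokens. A small illustration of this assumption (not code from this PR), using the same Qwen3 tokenizer as the test:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
reasoning = "Okay, the user asks which number is greater, 9.11 or 9.8.\n\n"

# Trailing whitespace carries tokens, so stripping it would make the
# re-tokenized reasoning_content shorter than the enforced budget.
print(len(tok.encode(reasoning)))
print(len(tok.encode(reasoning.strip())))
```
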
65 changes: 63 additions & 2 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
# Whether any request needs min_p sampling
need_min_p_sampling: bool

# Use thinking_budget to truncate thinking
num_thinking_tokens: Optional[torch.Tensor] = None
think_end_ids: Optional[torch.Tensor] = None
thinking_budgets: Optional[torch.Tensor] = None

# Masking tensors for grammar-guided structured outputs
vocab_size: int
vocab_size: int = 0
grammars: Optional[List] = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,29 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)

if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
Collaborator:

For non-reasoning models, can we skip this check? Do we need to identify whether the model is a reasoning model from its architecture?

Contributor (Author):

I am thinking about it. Right now I don't know how to decide from a request whether a model is a reasoning model.

Contributor (Author):

As far as I can tell, there seems to be no better way. The information that can be obtained here is very limited: this code is just doing sampling and does not know the specific model architecture.

think_end_ids = torch.tensor(
[getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
dtype=torch.int64,
).to(device, non_blocking=True)
num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
device, non_blocking=True
)
thinking_budgets = torch.tensor(
[
(
r.sampling_params.thinking_budget
if r.sampling_params.thinking_budget is not None
else -1
)
for r in reqs
],
dtype=torch.int64,
).to(device, non_blocking=True)
else:
think_end_ids = None
num_thinking_tokens = None
thinking_budgets = None
# Check if any request has custom logit processor
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
@@ -132,6 +159,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
think_end_ids=think_end_ids,
num_thinking_tokens=num_thinking_tokens,
thinking_budgets=thinking_budgets,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
@@ -146,6 +176,37 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
def __len__(self):
return len(self.temperatures)

def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
if self.thinking_budgets is None:
return
has_budget = self.thinking_budgets >= 0
if not has_budget.any():
return
torch.where(
has_budget,
self.num_thinking_tokens + 1,
self.num_thinking_tokens,
out=self.num_thinking_tokens,
)
should_stop = has_budget & (
self.num_thinking_tokens - 1 > self.thinking_budgets
)
next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
if len(batch_indices) > 0:
end_token_indices = self.think_end_ids[batch_indices]
next_token_logits[batch_indices, end_token_indices] = 0.0

def update_thinking_budgets(self, next_token_ids: torch.Tensor):
if self.thinking_budgets is None or not torch.any(self.thinking_budgets >= 0):
return
torch.where(
next_token_ids == self.think_end_ids,
torch.tensor(-1, device=self.thinking_budgets.device),
self.thinking_budgets,
out=self.thinking_budgets,
)

def update_regex_vocab_mask(self):
if not self.grammars:
self.vocab_mask = None
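
To make the masking step in `apply_thinking_budgets` concrete, here is a self-contained toy version of the same tensor logic (hypothetical numbers; token id 5 stands in for the real `</think>` id):

```python
import torch

# Toy batch of 2 requests over a vocabulary of 8 tokens.
think_end_ids = torch.tensor([5, 5])        # per-request id of `</think>`
thinking_budgets = torch.tensor([2, -1])    # request 1 has no budget (-1)
num_thinking_tokens = torch.tensor([4, 0])  # reasoning tokens counted so far
logits = torch.randn(2, 8)

has_budget = thinking_budgets >= 0
should_stop = has_budget & (num_thinking_tokens - 1 > thinking_budgets)

# Request 0 has exceeded its budget: wipe its row and force `</think>`.
logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
rows = torch.nonzero(should_stop, as_tuple=True)[0]
logits[rows, think_end_ids[rows]] = 0.0

# Request 0 now argmaxes to token 5 (`</think>`); request 1 is untouched.
print(logits.argmax(dim=-1))
```
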
2 changes: 2 additions & 0 deletions python/sglang/srt/sampling/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
def __init__(
self,
max_new_tokens: int = 128,
thinking_budget: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
@@ -58,6 +59,7 @@ def __init__(
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.thinking_budget = thinking_budget
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -62,6 +62,7 @@ class TestFile:
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105),
TestFile("test_reasoning_content.py", 89),
TestFile("test_thinking_budget.py", 60),
TestFile("test_regex_constrained.py", 64),
TestFile("test_release_memory_occupation.py", 44),
TestFile("test_request_length_validation.py", 31),
95 changes: 95 additions & 0 deletions test/srt/test_thinking_budget.py
@@ -0,0 +1,95 @@
"""
Usage:
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
"""

import unittest

import requests
from transformers import AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestThinkingBudget(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-8B"
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--reasoning-parser",
"qwen3",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_chat_completion_with_thinking_budget_20(self):
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [
{"role": "user", "content": "9.11 and 9.8, which is greater?"}
],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
"thinking_budget": 20,
},
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()
reasoning_content = data["choices"][0]["message"]["reasoning_content"]
tokens = self.tokenizer.encode(reasoning_content)
self.assertEqual(
len(tokens),
20,
f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
)

def test_chat_completion_with_thinking_budget_200(self):
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [
{"role": "user", "content": "9.11 and 9.8, which is greater?"}
],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
"thinking_budget": 200,
},
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()
reasoning_content = data["choices"][0]["message"]["reasoning_content"]
tokens = self.tokenizer.encode(reasoning_content)
self.assertEqual(
len(tokens),
200,
f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
)


if __name__ == "__main__":
unittest.main()