diff --git a/docs/backend/sampling_params.md b/docs/backend/sampling_params.md
index e5d8219914e..3c51ae164dc 100644
--- a/docs/backend/sampling_params.md
+++ b/docs/backend/sampling_params.md
@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
 | ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
 | skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
 | custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
+| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |
 
 ## Examples
 
@@ -296,3 +297,29 @@ response = requests.post(
 )
 print(response.json())
 ```
+
+### Thinking Budget
+
+Launch a server with `--reasoning-parser`.
+
+```bash
+python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
+```
+
+Send a request:
+
+```python
+import requests
+response = requests.post(
+    "http://localhost:30000/generate",
+    json={
+        "text": "9.11 and 9.8, which is greater?",
+        "sampling_params": {
+            "temperature": 0.3,
+            "max_new_tokens": 256,
+            "thinking_budget": 20,
+        },
+    },
+)
+print(response.json())
+```
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a146df7a751..fe218206ae9 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1271,7 +1271,9 @@ def sample(
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
-
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1282,6 +1284,8 @@ def sample(
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 83ff70b3965..d9d6b0e94eb 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -1128,10 +1128,16 @@ def v1_chat_generate_request(
         lora_paths.append(request.lora_path)
         prompts.append(prompt)
 
+        thinking_budget = request.thinking_budget
+        enable_thinking = _get_enable_thinking_from_request(request)
+        if not enable_thinking:
+            thinking_budget = None
+
         sampling_params = {
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": thinking_budget,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 35c04b0542a..6116f6dd0bd 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -383,6 +383,7 @@ def set_tool_choice_default(cls, values):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
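
Note on the adapter and protocol changes above: `thinking_budget` is accepted at the top level of the OpenAI-compatible chat completion request and is forwarded into `sampling_params` only when thinking is enabled. A minimal client sketch, mirroring the payload used by the test added at the end of this diff (the host, port, and model name are assumptions taken from the docs example above):

```python
# Illustrative sketch only -- assumes a server launched as in the docs section above:
#   python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
        "temperature": 0,
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        "thinking_budget": 20,  # dropped by the adapter when enable_thinking is false
    },
)
print(response.json()["choices"][0]["message"]["reasoning_content"])
```
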
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index d8bf8f09cb6..06325691ef1 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
             normal_text = current_text[end_idx + len(self.think_end_token) :]
 
             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+                normal_text=normal_text, reasoning_text=reasoning_text
             )
 
         # Continue with reasoning content
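
The two `.strip()`/`.rstrip()` removals above keep the reasoning text exactly as generated, which matters once a budget is enforced: the tests below assert that re-encoding `reasoning_content` yields exactly the budgeted number of tokens, and stripping whitespace could change that count. A simplified, standalone sketch of the one-shot parsing rule (not the actual `BaseReasoningFormatDetector` code; the `<think>`/`</think>` markers are the ones used by the qwen3 parser):

```python
# Simplified illustration of detect_and_parse after this change: the opening tag is
# removed, the text is split on the closing tag, and nothing is stripped.
THINK_START, THINK_END = "<think>", "</think>"


def detect_and_parse(text: str) -> tuple[str, str | None]:
    """Return (reasoning_text, normal_text); normal_text is None if reasoning was truncated."""
    text = text.replace(THINK_START, "")
    if THINK_END not in text:
        # Reasoning was cut off before the closing tag, so everything is reasoning text.
        return text, None
    reasoning_text, _, normal_text = text.partition(THINK_END)
    return reasoning_text, normal_text


print(detect_and_parse("<think>compare the decimals</think>9.8 is greater than 9.11."))
```
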
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 7f169ef0417..1359c6e66f9 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
+    # Use thinking_budget to truncate thinking
+    num_thinking_tokens: Optional[torch.Tensor] = None
+    think_end_ids: Optional[torch.Tensor] = None
+    thinking_budgets: Optional[torch.Tensor] = None
+
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int = 0
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,29 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
-
+        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
+            think_end_ids = torch.tensor(
+                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
+                device, non_blocking=True
+            )
+            thinking_budgets = torch.tensor(
+                [
+                    (
+                        r.sampling_params.thinking_budget
+                        if r.sampling_params.thinking_budget is not None
+                        else -1
+                    )
+                    for r in reqs
+                ],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+        else:
+            think_end_ids = None
+            num_thinking_tokens = None
+            thinking_budgets = None
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -132,6 +159,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            think_end_ids=think_end_ids,
+            num_thinking_tokens=num_thinking_tokens,
+            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -146,6 +176,37 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
     def __len__(self):
         return len(self.temperatures)
 
+    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
+        if self.thinking_budgets is None:
+            return
+        has_budget = self.thinking_budgets >= 0
+        if not has_budget.any():
+            return
+        torch.where(
+            has_budget,
+            self.num_thinking_tokens + 1,
+            self.num_thinking_tokens,
+            out=self.num_thinking_tokens,
+        )
+        should_stop = has_budget & (
+            self.num_thinking_tokens - 1 > self.thinking_budgets
+        )
+        next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
+        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
+        if len(batch_indices) > 0:
+            end_token_indices = self.think_end_ids[batch_indices]
+            next_token_logits[batch_indices, end_token_indices] = 0.0
+
+    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
+        if self.thinking_budgets is None or not torch.any(self.thinking_budgets >= 0):
+            return
+        torch.where(
+            next_token_ids == self.think_end_ids,
+            torch.tensor(-1, device=self.thinking_budgets.device),
+            self.thinking_budgets,
+            out=self.thinking_budgets,
+        )
+
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 4c505fe7a27..f64f063d00d 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
+        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -58,6 +59,7 @@ def __init__(
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
+        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
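
How the budget is enforced, in plain terms: each request carries a `thinking_budgets` entry (-1 means no budget), and `num_thinking_tokens` counts reasoning tokens as they are sampled. Once the count exceeds the budget, the logit row is masked to `-inf` except for the think-end token id, which is set to `0.0`, so it is the only token that can be sampled; when that token is emitted, `update_thinking_budgets` resets the budget to -1 and the mask stops applying. Below is a standalone numeric sketch of that masking step, with made-up token ids and shapes and simplified counter bookkeeping; it is not the `SamplingBatchInfo` code itself.

```python
import torch

vocab_size = 8
think_end_ids = torch.tensor([5, 5])        # hypothetical id of the "</think>" token
thinking_budgets = torch.tensor([2, -1])    # request 0 has a budget of 2; -1 = no budget
num_thinking_tokens = torch.tensor([3, 1])  # reasoning tokens generated so far

# Next-token logits for a batch of 2; without intervention, request 0 would pick
# token 2 and request 1 would pick token 7.
logits = torch.zeros(2, vocab_size)
logits[0, 2] = 3.0
logits[1, 7] = 3.0

should_stop = (thinking_budgets >= 0) & (num_thinking_tokens > thinking_budgets)
logits = logits.masked_fill(should_stop.unsqueeze(1), float("-inf"))
rows = torch.nonzero(should_stop, as_tuple=True)[0]
logits[rows, think_end_ids[rows]] = 0.0     # the end-of-thinking token is the only finite logit

next_token_ids = torch.argmax(logits, dim=-1)
print(next_token_ids)                       # tensor([5, 7]): request 0 is forced to close thinking

# Once "</think>" is emitted, the budget is reset to -1 so the mask no longer applies.
thinking_budgets = torch.where(next_token_ids == think_end_ids, torch.tensor(-1), thinking_budgets)
print(thinking_budgets)                     # tensor([-1, -1])
```
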
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 323aeb1eb0a..d594d862c7c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -62,6 +62,7 @@ class TestFile:
         TestFile("test_pytorch_sampling_backend.py", 66),
         TestFile("test_radix_attention.py", 105),
         TestFile("test_reasoning_content.py", 89),
+        TestFile("test_thinking_budget.py", 60),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),
diff --git a/test/srt/test_thinking_budget.py b/test/srt/test_thinking_budget.py
new file mode 100644
index 00000000000..9d264c9c619
--- /dev/null
+++ b/test/srt/test_thinking_budget.py
@@ -0,0 +1,95 @@
+"""
+Usage:
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
+"""
+
+import unittest
+
+import requests
+from transformers import AutoTokenizer
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestThinkingBudget(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-8B"
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--reasoning-parser",
+                "qwen3",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_chat_completion_with_thinking_budget_20(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 20,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            20,
+            f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+    def test_chat_completion_with_thinking_budget_200(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 200,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            200,
+            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()