Commit 608668e

Slightly improve the sampler to skip unnecessary steps (#6956)
1 parent 6c0a482 commit 608668e

7 files changed, +109 −93 lines changed


python/sglang/srt/layers/sampler.py

Lines changed: 81 additions & 75 deletions
@@ -5,7 +5,7 @@
 import torch.distributed as dist
 from torch import nn
 
-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ def forward(
 
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)
 
         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,49 +81,39 @@ def forward(
         probs = logits
         del logits
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
-            if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                # https://github.com/flashinfer-ai/flashinfer/issues/708
-                # so we use the torch implementation.
-                # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                # same rule.
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, sampling_info.min_ps
-                )
-            else:
-                # Check Nan will throw exception, only check when crash_on_warnings is True
-                check_nan = self.use_nan_detection and crash_on_warnings()
-                batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                    probs.contiguous(),
+        if True:  # Keep this redundant check to simplify some internal code sync
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids = min_p_sampling_from_probs(
+                        probs, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
+                        probs,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                        check_nan=self.use_nan_detection,
+                    )
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
-                    check_nan=check_nan,
+                    sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # A slower fallback implementation with torch native operations.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs,
-                sampling_info.top_ks,
-                sampling_info.top_ps,
-                sampling_info.min_ps,
-                sampling_info.need_min_p_sampling,
-            )
-
-            if return_logprob:
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
+        if return_logprob:
+            # clamp to avoid -inf
+            logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
@@ -160,39 +150,6 @@ def forward(
 
         return batch_next_token_ids
 
-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )
-
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids
+
+
 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])
 
     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
+    tokens. By default, we assume each batch contains only 1 token.
+    """
+
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )

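Note on the relocated apply_custom_logit_processor: the new num_tokens_in_batch argument exists so that, under speculative decoding, the per-request processor mask can be stretched to per-token granularity before indexing the logits. A minimal standalone sketch of that mask expansion (tensor sizes are illustrative only, not taken from the code above):

```python
import torch

# 3 requests, each contributing 2 draft tokens -> 6 rows of logits
num_tokens_in_batch = 2
logits = torch.randn(3 * num_tokens_in_batch, 8)

# Per-request mask: only the second request has a custom logit processor attached
batch_mask = torch.tensor([False, True, False])

# Stretch the mask to per-token granularity, as the diff does with repeat_interleave
token_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
print(token_mask)                # tensor([False, False,  True,  True, False, False])
print(logits[token_mask].shape)  # torch.Size([2, 8]) -- both tokens of the second request
```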
python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 1 addition & 1 deletion
@@ -852,7 +852,7 @@ async def update_weights_from_disk(
         obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if True:
+        if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
             async with self.model_update_lock.writer_lock:

python/sglang/srt/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn

python/sglang/srt/sampling/sampling_batch_info.py

Lines changed: 13 additions & 1 deletion
@@ -9,10 +9,12 @@
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.sampling_params import TOP_K_ALL
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
     # Whether all requests use greedy sampling
     is_all_greedy: bool
 
+    # Whether any requests use top_p sampling
+    need_top_p_sampling: bool
+
+    # Whether any requests use top_k sampling
+    need_top_k_sampling: bool
+
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
@@ -133,6 +141,8 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
             top_ks=top_ks,
             min_ps=min_ps,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
+            need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
             penalizer_orchestrator=penalizer_orchestrator,
@@ -167,7 +177,7 @@ def update_regex_vocab_mask(self):
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar and not grammar.finished:
+            if grammar and not grammar.finished and not grammar.is_terminated():
                 grammar.fill_vocab_mask(self.vocab_mask, i)
 
         # Move the mask to the device if needed
@@ -308,4 +318,6 @@ def merge_batch(self, other: "SamplingBatchInfo"):
             setattr(self, item, torch.cat([self_val, other_val]))
 
         self.is_all_greedy &= other.is_all_greedy
+        self.need_top_p_sampling |= other.need_top_p_sampling
+        self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling

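The two new flags mirror the existing need_min_p_sampling bookkeeping: they are computed once per batch in from_schedule_batch and OR-combined in merge_batch, so downstream code can tell in O(1) whether any request in the batch actually asked for top-p or top-k filtering. A rough sketch of the kind of caller that could use them to skip the renormalization steps (a hypothetical consumer, not part of this commit; it assumes the top_k_renorm_prob / top_p_renorm_prob kernels already used by sampler.py are in scope):

```python
def maybe_renorm_probs(probs, sampling_info):
    # Hypothetical helper: only pay for top-k / top-p renormalization
    # when at least one request in the batch actually uses that filter.
    if sampling_info.need_top_k_sampling:
        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
    if sampling_info.need_top_p_sampling:
        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
    return probs
```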
python/sglang/srt/sampling/sampling_params.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 _SAMPLING_EPS = 1e-6
+TOP_K_ALL = 1 << 30
 
 
 class SamplingParams:
@@ -84,7 +85,7 @@ def __init__(
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
-            self.top_k = 1 << 30  # whole vocabulary
+            self.top_k = TOP_K_ALL  # whole vocabulary
 
     def verify(self):
         if self.temperature < 0.0:

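TOP_K_ALL only gives the existing 1 << 30 sentinel a name, so other code (for example the new need_top_k_sampling check in sampling_batch_info.py) can compare against the same constant instead of a magic number. A small illustration of the normalization, following the same defaulting logic as SamplingParams.__init__ (normalize_top_k is a hypothetical helper, used here just to show the sentinel):

```python
TOP_K_ALL = 1 << 30  # sentinel meaning "consider the whole vocabulary"


def normalize_top_k(top_k: int) -> int:
    # -1 is the user-facing way to disable top-k filtering
    return TOP_K_ALL if top_k == -1 else top_k


assert normalize_top_k(-1) == TOP_K_ALL  # disabled -> whole vocabulary
assert normalize_top_k(40) == 40         # a real cutoff is kept as-is
assert normalize_top_k(40) != TOP_K_ALL  # so need_top_k_sampling would be True
```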
python/sglang/srt/two_batch_overlap.py

Lines changed: 2 additions & 7 deletions
@@ -1,17 +1,15 @@
 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
-    CommunicateSimpleFn,
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -20,9 +18,6 @@
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
-if TYPE_CHECKING:
-    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -46,7 +41,7 @@ def compute_split_seq_index(
         assert num_tokens == 0
         return 0
     else:
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def _split_array_by_half_sum(arr: Sequence[int]) -> int:

python/sglang/srt/utils.py

Lines changed: 9 additions & 7 deletions
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
     setattr(triton, "next_power_of_2", next_power_of_2)
 
 
-@contextmanager
-def empty_context(*args, **kwargs):
-    try:
-        # Setup code goes here
-        yield
-    finally:
-        # Cleanup code goes here
+class EmptyContextManager:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         pass
 
 
+def empty_context(*args, **kwargs):
+    return EmptyContextManager()
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.

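Replacing the @contextmanager generator with a tiny class changes one behavior worth knowing about: a generator-based context manager is single-use, while an EmptyContextManager instance (and anything returned by empty_context()) can be entered repeatedly. A short sketch of the difference, using only what the diff defines plus the old decorator for comparison:

```python
from contextlib import contextmanager


@contextmanager
def empty_context_gen(*args, **kwargs):  # the old, generator-based version
    yield


class EmptyContextManager:  # the new, class-based version from the diff
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


ctx = EmptyContextManager()
with ctx:
    pass
with ctx:  # fine: the class-based manager can be re-entered
    pass

gen_ctx = empty_context_gen()
with gen_ctx:
    pass
try:
    with gen_ctx:  # RuntimeError: the generator version is single-use
        pass
except RuntimeError as e:
    print("generator-based context manager cannot be reused:", e)
```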