@@ -5,7 +5,7 @@
 import torch.distributed as dist
 from torch import nn
 
-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ def forward(
 
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)
 
         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,49 +81,39 @@ def forward(
         probs = logits
         del logits
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
-            if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                # https://github.com/flashinfer-ai/flashinfer/issues/708
-                # so we use the torch implementation.
-                # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                # same rule.
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, sampling_info.min_ps
-                )
-            else:
-                # Check Nan will throw exception, only check when crash_on_warnings is True
-                check_nan = self.use_nan_detection and crash_on_warnings()
-                batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                    probs.contiguous(),
+        if True:  # Keep this redundant check to simplify some internal code sync
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids = min_p_sampling_from_probs(
+                        probs, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
+                        probs,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                        check_nan=self.use_nan_detection,
+                    )
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
-                    check_nan=check_nan,
+                    sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # A slower fallback implementation with torch native operations.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs,
-                sampling_info.top_ks,
-                sampling_info.top_ps,
-                sampling_info.min_ps,
-                sampling_info.need_min_p_sampling,
-            )
-
-            if return_logprob:
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
+        if return_logprob:
+            # clamp to avoid -inf
+            logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
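Note on the dispatch above: the pytorch fallback calls `top_k_top_p_min_p_sampling_from_probs_torch`, which is defined further down in this file. As a rough standalone sketch of what that torch-native path does (hedged: the function name `sketch_top_k_top_p_min_p_sample` and the shapes noted in comments are illustrative, not SGLang's exact code):

import torch

def sketch_top_k_top_p_min_p_sample(probs, top_ks, top_ps, min_ps):
    # probs: [batch, vocab], softmax-normalized; top_ks: int tensor [batch];
    # top_ps, min_ps: float tensors [batch].
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    # top-p: zero out tokens once the mass before them already exceeds top_p
    sorted_probs[(cum - sorted_probs) > top_ps.view(-1, 1)] = 0.0
    # top-k: zero out everything ranked at or beyond position k
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    sorted_probs[ranks.view(1, -1) >= top_ks.view(-1, 1)] = 0.0
    # min-p: zero out tokens below min_p times the per-row maximum
    sorted_probs[sorted_probs < min_ps.view(-1, 1) * sorted_probs[:, :1]] = 0.0
    # multinomial accepts unnormalized non-negative weights
    sampled = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, sampled).view(-1).to(torch.int32)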
@@ -160,39 +150,6 @@ def forward(
 
         return batch_next_token_ids
 
-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )
-
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids
+
+
 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
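The new `sampling_from_probs_torch` helper above is a plain multinomial draw over already-filtered probabilities. A minimal hedged usage sketch (the batch and vocab sizes are made up for illustration):

import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # [batch, vocab]
# Equivalent to sampling_from_probs_torch(probs): one token id per request.
next_ids = torch.multinomial(probs, num_samples=1).view(-1).to(torch.int32)
print(next_ids.shape)  # torch.Size([4])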
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])
 
     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
+    tokens. By default, we assume each batch contains only 1 token.
+    """
+
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )