File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed
python/sglang/srt/model_executor Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change 39
39
import triton .language as tl
40
40
41
41
from sglang .srt .layers .rotary_embedding import MRotaryEmbedding
42
+ from sglang .srt .utils import get_compiler_backend
42
43
43
44
if TYPE_CHECKING :
44
45
from sglang .srt .layers .attention .base_attn_backend import AttentionBackend
@@ -299,7 +300,7 @@ def init_new(
299
300
# Init position information
300
301
if ret .forward_mode .is_decode ():
301
302
if ret .positions is None :
302
- ret .positions = torch . clamp (( batch .seq_lens - 1 ), min = 0 ). to ( torch . int64 )
303
+ ret .positions = clamp_position ( batch .seq_lens )
303
304
else :
304
305
ret .extend_seq_lens = torch .tensor (
305
306
batch .extend_seq_lens , dtype = torch .int32
@@ -519,3 +520,8 @@ def compute_position_torch(
519
520
extend_start_loc = torch .zeros_like (extend_seq_lens )
520
521
extend_start_loc [1 :] = torch .cumsum (extend_seq_lens [:- 1 ], dim = 0 )
521
522
return positions .to (torch .int64 ), extend_start_loc
523
+
524
+
525
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    """Return ``seq_lens - 1`` with negative entries raised to 0, as int64.

    Compiled with ``torch.compile(dynamic=True)`` using the backend returned
    by ``get_compiler_backend()``.

    Args:
        seq_lens: Value supporting ``- 1`` and ``torch.clamp``.
            NOTE(review): presumably an integer tensor of per-request
            sequence lengths -- confirm against the caller.

    Returns:
        The clamped result converted to ``torch.int64``.
    """
    # Last-token index is length minus one; an entry of 0 would yield -1,
    # so the lower bound is pinned at 0.
    last_index = seq_lens - 1
    return last_index.clamp(min=0).to(torch.int64)
You can’t perform that action at this time.
0 commit comments