
Commit acc9ae6

Fridge003 authored and thyecust committed
[Fix] Add torch compile for torch.clamp back (sgl-project#4936)
1 parent 8339e22 commit acc9ae6

1 file changed: 7 additions, 1 deletion


python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -39,6 +39,7 @@
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -299,7 +300,7 @@ def init_new(
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
+                ret.positions = clamp_position(batch.seq_lens)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -519,3 +520,8 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
```
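
For context, here is a minimal standalone sketch of the pattern this commit restores: wrapping a small positional-clamp helper in `torch.compile(dynamic=True)` so the subtract/clamp/cast sequence can be fused into a single compiled kernel, and so varying batch sizes do not trigger recompilation. In sglang the backend comes from `get_compiler_backend()`; the hard-coded `"inductor"` backend below is an assumption so the snippet runs outside the repo.

```python
import torch

# Sketch of the compiled clamp helper restored by this commit.
# Assumption: backend hard-coded to "inductor"; sglang instead
# resolves it at runtime via get_compiler_backend().
@torch.compile(dynamic=True, backend="inductor")
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Decode positions are the last token index of each sequence,
    # clamped at 0 so empty sequences don't produce negative positions.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

if __name__ == "__main__":
    seq_lens = torch.tensor([0, 1, 5, 128], dtype=torch.int32)
    print(clamp_position(seq_lens))  # tensor([  0,   0,   4, 127])
```

`dynamic=True` asks the compiler to treat tensor shapes symbolically, so the graph compiled for one `seq_lens` length is reused as the batch size changes. If no C++ toolchain is available, swapping in `backend="eager"` exercises the same code path without codegen.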
