Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion python/sglang/srt/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,25 @@
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.utils import apply_qk_norm
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_npu
from sglang.srt.utils import add_prefix, get_bool_env_var, is_cuda, is_hip, is_npu

# NOTE(review): bound to None at module level; how (or whether) it is
# replaced is not visible in this part of the file.
Qwen3Config = None

logger = logging.getLogger(__name__)
# Platform probes are evaluated once at import time; the cached results gate
# the optional kernel imports further down in this module.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()

# Optional fast path: aiter's fused QK-norm + 3D mRoPE + KV-cache-write kernel.
# It is only probed when the user opted in via SGLANG_USE_AITER and the HIP
# platform check passed. The flag stays False when the import fails so callers
# can fall back to the unfused implementation.
_has_fused_qk_norm_mrope = False
if get_bool_env_var("SGLANG_USE_AITER") and _is_hip:
    try:
        from aiter import fused_qk_norm_mrope_3d_cache_pts_quant_shuffle
    except ImportError as import_err:
        # Best-effort feature detection: do not fail the module import, but
        # make the fallback visible instead of swallowing it silently.
        logger.info(
            "aiter fused_qk_norm_mrope_3d kernel unavailable (%s); "
            "using the unfused QK-norm + mRoPE path",
            import_err,
        )
    else:
        _has_fused_qk_norm_mrope = True
        logger.info("aiter fused_qk_norm_mrope_3d kernel available")

if _is_npu:
from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope

Expand Down Expand Up @@ -138,6 +149,21 @@ def __init__(
)
self.alt_stream = alt_stream

from sglang.srt.layers.rotary_embedding.mrope import MRotaryEmbedding

self.use_fused_qk_norm_mrope = (
_has_fused_qk_norm_mrope
and isinstance(self.rotary_emb, MRotaryEmbedding)
and getattr(self.rotary_emb, "mrope_section", None) is not None
)
if self.use_fused_qk_norm_mrope:
# Scale tensors MUST stay on CPU: the C++ kernel uses .item<float>()
# which triggers hipMemcpy D2H + sync on CUDA tensors, breaking graph capture.
# Explicit device='cpu' is required because SGLang constructs models inside
# a `with torch.device('cuda'):` context that changes the default device.
self._fused_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")
self._fused_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")

def forward_prepare_native(self, positions, hidden_states):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
Expand Down Expand Up @@ -172,6 +198,73 @@ def forward_prepare_npu(self, positions, hidden_states, forward_batch):
)
return q, k, v

def _forward_fused_mrope_decode(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Decode-step attention using aiter's fused QK-norm/mRoPE/cache kernel.

    A single aiter call takes the place of the separate split, QK-norm,
    mRoPE and KV-cache-write steps. The kernel is handed this layer's K/V
    cache buffers and the batch's output cache locations, so the attention
    backend is subsequently invoked with ``save_kv_cache=False`` and without
    explicit K/V tensors.

    Args:
        positions: Position tensor forwarded unchanged to the fused kernel.
        hidden_states: Layer input fed to ``qkv_proj``.
        forward_batch: Supplies the KV pool and ``out_cache_loc`` and is
            passed on to the attention backend.

    Returns:
        The output of ``o_proj`` applied to the attention result.
    """
    packed_qkv, _ = self.qkv_proj(hidden_states)
    n_tokens = packed_qkv.shape[0]
    work_dtype = packed_qkv.dtype

    # View the projection as (tokens, heads, head_dim) for the kernel.
    packed_qkv_heads = packed_qkv.view(n_tokens, -1, self.head_dim)

    # Cache buffers of this layer plus the slots to be written.
    kv_pool = forward_batch.token_to_kv_pool
    key_cache, value_cache = kv_pool.get_kv_buffer(self.attn.layer_id)
    cache_slots = forward_batch.out_cache_loc

    # The rotary table is converted when its dtype differs from the
    # projection output.
    rope_table = self.rotary_emb.cos_sin_cache
    if rope_table.dtype != work_dtype:
        rope_table = rope_table.to(dtype=work_dtype)

    # Destination for the query, filled by the kernel.
    query_heads = torch.empty(
        (n_tokens, self.num_heads, self.head_dim),
        dtype=work_dtype,
        device=packed_qkv.device,
    )

    rope = self.rotary_emb
    # Argument order is positional and must match the aiter kernel signature.
    # NOTE(review): the meaning of the trailing None/False/0 arguments is not
    # visible in this file -- confirm against the aiter API before changing.
    fused_qk_norm_mrope_3d_cache_pts_quant_shuffle(
        packed_qkv_heads,
        self.q_norm.weight,
        self.k_norm.weight,
        rope_table,
        positions,
        n_tokens,
        self.num_heads,
        self.num_kv_heads,
        self.num_kv_heads,
        self.head_dim,
        rope.is_neox_style,
        rope.mrope_section,
        rope.mrope_interleaved,
        self.q_norm.variance_epsilon,
        query_heads,
        key_cache,
        value_cache,
        cache_slots,
        self._fused_k_scale,
        self._fused_v_scale,
        None,
        None,
        False,
        False,
        0,
        0,
    )

    # K/V were handed to the kernel above, so none are passed here and the
    # backend is told not to save them again.
    flat_query = query_heads.reshape(n_tokens, -1)
    attn_result = self.attn(
        flat_query, None, None, forward_batch, save_kv_cache=False
    )
    projected, _ = self.o_proj(attn_result)
    return projected

def forward(
self,
positions: torch.Tensor,
Expand All @@ -180,6 +273,10 @@ def forward(
) -> torch.Tensor:
if get_global_server_args().rl_on_policy_target is not None:
hidden_states = hidden_states.bfloat16()
elif self.use_fused_qk_norm_mrope and forward_batch.forward_mode.is_decode():
return self._forward_fused_mrope_decode(
positions, hidden_states, forward_batch
)

if (
not _is_npu
Expand Down
Loading