ROCm: update AITER #5816

Merged · 2 commits · Apr 28, 2025

12 changes: 6 additions & 6 deletions .github/workflows/pr-test-amd.yml
```diff
@@ -38,12 +38,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+          docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5.post3-rocm630
+            ghcr.io/saienduri/sglang-aiter-v0.1.1:428

       - name: Install dependencies
         run: |
@@ -82,12 +82,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+          docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5.post3-rocm630
+            ghcr.io/saienduri/sglang-aiter-v0.1.1:428

       - name: Install dependencies
         run: |
@@ -120,12 +120,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+          docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5.post3-rocm630
+            ghcr.io/saienduri/sglang-aiter-v0.1.1:428

       - name: Install dependencies
         run: |
```
2 changes: 1 addition & 1 deletion 3rdparty/amd/tuning/benchmark_moe_rocm.py
```diff
@@ -15,7 +15,7 @@
     get_config_file_name,
 )

-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


 def main(model, tp_size, dtype: str, batches):
```
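Note for anyone migrating tuning scripts: this is a rename only. The flag still gates the same fixed 128-element zero pad on the last weight dimension, applied with `F.pad` in the fp8.py hunks below to reduce memory-channel contention on ROCm. A minimal sketch of the behavior, with toy shapes:

```python
import os

import torch
import torch.nn.functional as F

# Same gate as the diff above, under the new name.
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0

# Toy MoE weight: (num_experts, 2 * intermediate_size, hidden_size).
w13 = torch.randn(8, 256, 512)
if padding_size:
    # Zero-pad the last dim; the kernels slice the pad back off at use time.
    w13 = F.pad(w13, (0, padding_size), "constant", 0)

print(w13.shape)  # (8, 256, 640) with SGLANG_MOE_PADDING=1, else (8, 256, 512)
```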
4 changes: 2 additions & 2 deletions docker/Dockerfile.rocm
```diff
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"


 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="testx"
+ARG AITER_COMMIT="v0.1.1"

 RUN git clone ${SGL_REPO} \
     && cd sglang \
@@ -74,7 +74,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112

-ENV MOE_PADDING=1
+ENV SGLANG_MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
```
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
```diff
@@ -45,7 +45,7 @@


 logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
@@ -1327,7 +1327,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
    ):
        padded_size = 0
```
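Unlike the raw `os.getenv` gate above, the renamed AITER gate goes through `sglang.srt.utils.get_bool_env_var`. For operators updating launch scripts, a rough equivalent of its semantics (my approximation, not the repo's code):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Approximation of sglang.srt.utils.get_bool_env_var: a flag counts as
    # enabled when its value is "1" or "true" (case-insensitive).
    return os.getenv(name, default).strip().lower() in ("1", "true")

# After this PR, the AITER MoE path on ROCm is enabled with
# SGLANG_AITER_MOE=1 (formerly CK_MOE=1); fused_experts_impl then sets
# padded_size = 0 and skips the weight-padding logic.
os.environ["SGLANG_AITER_MOE"] = "1"
assert get_bool_env_var("SGLANG_AITER_MOE")
```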
32 changes: 15 additions & 17 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
```diff
@@ -18,7 +18,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -30,7 +30,9 @@
 _is_hip = is_hip()

 if _is_hip:
-    from aiter import ck_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)

@@ -102,14 +104,14 @@ def create_weights(
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -182,21 +184,17 @@ def forward_cuda(
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
-            return ck_moe(
+            return ck_moe_2stages(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                None,
-                None,
-                None,
-                None,
-                32,
-                None,
-                activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         else:
             return fused_experts(
@@ -527,7 +525,7 @@ def weight_loader(
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0

             # this is needed for compressed-tensors only
@@ -569,7 +567,7 @@ def weight_loader(
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 0.5

                 self._load_per_channel_weight_scale(
@@ -592,7 +590,7 @@ def weight_loader(
             )
         elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0

             self._load_per_tensor_weight_scale(
```
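The old single-call `ck_moe`, with its positional tail of `None`s, is replaced by a pair of steps: a one-time 16x16 weight re-tiling at load (`shuffle_weight`, which supersedes sglang's `permute_weight`) plus the 2-stage kernel at run time. A condensed sketch of the pattern as it appears in this diff (the surrounding module and tensor shapes are elided):

```python
import torch
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight

def process_weights_after_loading(layer: torch.nn.Module) -> None:
    # One-time, at load: re-tile both MoE weights into the (16, 16) layout
    # the 2-stage CK kernel expects.
    for name in ("w13_weight", "w2_weight"):
        w = getattr(layer, name)
        setattr(layer, name, torch.nn.Parameter(
            shuffle_weight(w.data, (16, 16)), requires_grad=False
        ))
        torch.cuda.empty_cache()

def forward_moe(layer, x, topk_weights, topk_ids, activation: str = "silu"):
    # Per call: the activation is now an explicit enum keyword instead of a
    # positional string at the end of ck_moe's argument list.
    return ck_moe_2stages(
        x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
        activation=(ActivationType.Silu if activation == "silu"
                    else ActivationType.Gelu),
    )
```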
42 changes: 20 additions & 22 deletions python/sglang/srt/layers/quantization/fp8.py
```diff
@@ -72,8 +72,8 @@ def dummy_func(*args, **kwargs):
 _is_cuda = is_cuda()

 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:
@@ -484,7 +484,7 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
             tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@
             )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ def create_weights(

         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ def create_weights(
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return

@@ -675,7 +675,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )
             layer.w2_input_scale = None

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ def process_weights_hip_scale_padding(self, layer: Module):
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ def apply(
         )

         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ def apply(
                     ),
                 )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 block_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ def apply(
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
```
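On the ×2.0 and ×0.5 scale adjustments that now key off SGLANG_INT4_WEIGHT: AMD's float8_e4m3fnuz uses exponent bias 8 where OCP e4m3fn uses bias 7, so the same bit pattern decodes to half the value, and doubling the dequant scales compensates (per the diff comments, the INT4 column-wise scale is instead halved). A quick bit-reinterpretation check of the bias arithmetic (my sketch, not code from the PR; needs a PyTorch build with both fp8 dtypes):

```python
import torch

# Encode 1.0 as OCP e4m3fn, then reinterpret the identical byte as AMD
# e4m3fnuz: the extra unit of exponent bias halves the decoded value.
fn = torch.tensor([1.0]).to(torch.float8_e4m3fn)
fnuz = fn.view(torch.uint8).view(torch.float8_e4m3fnuz)
print(fn.float().item(), fnuz.float().item())  # 1.0 0.5
```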
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_utils.py
```diff
@@ -31,7 +31,7 @@
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale

 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
```
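For the renamed ROCm branch in `apply_w8a8_block_fp8_linear`: `per_token_group_quant_fp8` splits each input row into `block_size[1]`-wide groups, one FP8 scale per group, before aiter's `gemm_a8w8_blockscale` consumes the quantized activations. A pure-torch sketch of the quantizer's semantics (reference only; the real implementation is a fused kernel, and the fp8 target dtype here is an assumption):

```python
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int):
    # Reference semantics: each contiguous `group_size` chunk of a row is
    # scaled so its max magnitude maps to the fp8 max, then cast to fp8.
    dtype = torch.float8_e4m3fnuz  # assumed ROCm target dtype
    finfo = torch.finfo(dtype)
    g = x.view(x.shape[0], -1, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / finfo.max
    q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q.view(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)                       # (tokens, hidden)
q, s = per_token_group_quant_fp8_ref(x, 128)  # block_size[1] is typically 128
print(q.shape, s.shape)                       # (4, 256) and (4, 2)
```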