Skip to content

Commit 3fb9c71

Browse files
authored
update moe_smooth_per_token_scaled_quant dispatch; v2 now supports block_m being any multiple of 16 (#2333)
1 parent cb0b0c8 commit 3fb9c71

File tree

3 files changed

+57
-71
lines changed

3 files changed

+57
-71
lines changed

aiter/fused_moe_bf16_asm.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -198,35 +198,6 @@ def asm_moe(
198198
# aiter.moe_smoothquant_fwd(
199199
# a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale
200200
# )
201-
# aiter.smooth_per_token_scaled_quant(
202-
# a8.view(topk, M, model_dim).transpose(0, 1),
203-
# hidden_states.view(M, 1, model_dim).expand(-1, topk, -1),
204-
# a8_scale,
205-
# fc1_smooth_scale,
206-
# topk_ids,
207-
# smooth_scale_map_hash=local_expert_hash,
208-
# enable_ps=True,
209-
# )
210-
# aiter.moe_smooth_per_token_scaled_quant_v1(
211-
# a8,
212-
# hidden_states,
213-
# a8_scale,
214-
# fc1_smooth_scale,
215-
# topk_ids,
216-
# smooth_scale_map_hash=local_expert_hash,
217-
# transpose_out=True,
218-
# )
219-
# aiter.moe_smooth_per_token_scaled_quant_v2(
220-
# a8,
221-
# hidden_states,
222-
# a8_scale,
223-
# fc1_smooth_scale,
224-
# sorted_ids,
225-
# sorted_expert_ids,
226-
# num_valid_ids,
227-
# BLOCK_SIZE_M,
228-
# transpose_out=True,
229-
# )
230201
aiter.moe_smooth_per_token_scaled_quant(
231202
a8,
232203
hidden_states,

aiter/ops/quant.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -415,21 +415,35 @@ def moe_smooth_per_token_scaled_quant(
415415
local_expert_hash: Optional[torch.Tensor] = None,
416416
shuffle_scale: bool = False,
417417
transpose_out: bool = False,
418+
is_balanced: bool = False,
418419
) -> None:
419420
cu_num = get_cu_num()
420421
is_moe_stage1 = input.numel() != out.numel()
421-
token_num = input.shape[0]
422-
if is_moe_stage1 and local_expert_hash is not None and token_num < cu_num * 8:
423-
moe_smooth_per_token_scaled_quant_v1(
424-
out,
425-
input,
426-
scales,
427-
smooth_scale,
428-
topk_ids,
429-
shuffle_scale,
430-
local_expert_hash,
431-
transpose_out,
432-
)
422+
M = input.shape[0]
423+
if is_moe_stage1 and local_expert_hash is not None and M < cu_num * 8:
424+
if is_balanced:
425+
moe_smooth_per_token_scaled_quant_v1(
426+
out,
427+
input,
428+
scales,
429+
smooth_scale,
430+
topk_ids,
431+
shuffle_scale,
432+
local_expert_hash,
433+
transpose_out,
434+
)
435+
else:
436+
topk = topk_ids.shape[1]
437+
model_dim = input.shape[-1]
438+
smooth_per_token_scaled_quant(
439+
out.view(topk, M, model_dim).transpose(0, 1),
440+
input.view(M, 1, model_dim).expand(-1, topk, -1),
441+
scales,
442+
smooth_scale,
443+
topk_ids,
444+
smooth_scale_map_hash=local_expert_hash,
445+
enable_ps=True,
446+
)
433447
else:
434448
moe_smooth_per_token_scaled_quant_v2(
435449
out,

csrc/kernels/quant_kernels.cu

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,33 +1515,33 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v2(DTYPE_O* __restrict_
15151515
}
15161516

15171517

1518-
#define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL(quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE) \
1519-
AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] { \
1520-
using input_dtype = typename t2ck<scalar_t>::type; \
1521-
int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \
1522-
int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \
1523-
dim3 const grid(num_tg); \
1524-
aiter::quant_kernel<input_dtype, DTYPE_O, BLOCK_SIZE, THREAD_DATA> \
1525-
<<<grid, dim3(BLOCK_SIZE), 0, stream>>>( \
1526-
reinterpret_cast<DTYPE_O*>(out.data_ptr()), \
1527-
scales.data_ptr<float>(), \
1528-
reinterpret_cast<input_dtype*>(input.data_ptr()), \
1529-
smooth_scale.data_ptr<float>(), \
1530-
sorted_token_ids.data_ptr<int>(), \
1531-
sorted_expert_ids.data_ptr<int>(), \
1532-
num_valid_ids.data_ptr<int>(), \
1533-
num_experts, \
1534-
num_tokens, \
1535-
num_blocks, \
1536-
num_tg, \
1537-
cols, \
1538-
topk, \
1539-
block_m, \
1540-
block_m_log2split, \
1541-
input_stride0, \
1542-
input_stride1, \
1543-
shuffle_scale, \
1544-
transpose_out); \
1518+
#define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL(quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE) \
1519+
AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] { \
1520+
using input_dtype = typename t2ck<scalar_t>::type; \
1521+
int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \
1522+
int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \
1523+
dim3 const grid(num_tg); \
1524+
aiter::quant_kernel<input_dtype, DTYPE_O, BLOCK_SIZE, THREAD_DATA> \
1525+
<<<grid, dim3(BLOCK_SIZE), 0, stream>>>( \
1526+
reinterpret_cast<DTYPE_O*>(out.data_ptr()), \
1527+
scales.data_ptr<float>(), \
1528+
reinterpret_cast<input_dtype*>(input.data_ptr()), \
1529+
smooth_scale.data_ptr<float>(), \
1530+
sorted_token_ids.data_ptr<int>(), \
1531+
sorted_expert_ids.data_ptr<int>(), \
1532+
num_valid_ids.data_ptr<int>(), \
1533+
num_experts, \
1534+
num_tokens, \
1535+
num_blocks, \
1536+
num_tg, \
1537+
cols, \
1538+
topk, \
1539+
block_m, \
1540+
block_m_log2split, \
1541+
input_stride0, \
1542+
input_stride1, \
1543+
shuffle_scale, \
1544+
transpose_out); \
15451545
});
15461546

15471547

@@ -1589,10 +1589,11 @@ void moe_smooth_per_token_scaled_quant_v2(
15891589
int input_stride1= input.dim() == 2 ? 0 : input.stride(1);
15901590

15911591
const int num_cu = get_num_cu_func();
1592-
int sub_block_m = 2;
1593-
int num_blocks = sorted_expert_ids.size(0) * (block_m / sub_block_m);
1594-
int block_split = block_m / sub_block_m;
1592+
int block_split = 16;
15951593
int block_m_log2split = log2(block_split);
1594+
TORCH_CHECK(block_m % block_split == 0, __func__, " block_m is not divisible by block_split");
1595+
int sub_block_m = block_m >> block_m_log2split;
1596+
int num_blocks = sorted_expert_ids.size(0) * block_split;
15961597
const bool persistent_mode = true;
15971598

15981599
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));

0 commit comments

Comments
 (0)