@@ -1515,33 +1515,33 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v2(DTYPE_O* __restrict_
15151515}
15161516
15171517
1518- #define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL (quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE ) \
1519- AITER_DISPATCH_FLOATING16_TYPES (input.scalar_type(), " quant_kernel" , [&] { \
1520- using input_dtype = typename t2ck<scalar_t >::type; \
1521- int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \
1522- int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \
1523- dim3 const grid (num_tg); \
1524- aiter::quant_kernel<input_dtype, DTYPE_O, BLOCK_SIZE, THREAD_DATA> \
1525- <<<grid, dim3 (BLOCK_SIZE), 0 , stream>>> ( \
1526- reinterpret_cast <DTYPE_O*>(out.data_ptr ()), \
1527- scales.data_ptr <float >(), \
1528- reinterpret_cast <input_dtype*>(input.data_ptr ()), \
1529- smooth_scale.data_ptr <float >(), \
1530- sorted_token_ids.data_ptr <int >(), \
1531- sorted_expert_ids.data_ptr <int >(), \
1532- num_valid_ids.data_ptr <int >(), \
1533- num_experts, \
1534- num_tokens, \
1535- num_blocks, \
1536- num_tg, \
1537- cols, \
1538- topk, \
1539- block_m, \
1540- block_m_log2split, \
1541- input_stride0, \
1542- input_stride1, \
1543- shuffle_scale, \
1544- transpose_out); \
1518+ #define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL (quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE ) \
1519+ AITER_DISPATCH_FLOATING16_TYPES (input.scalar_type(), " quant_kernel" , [&] { \
1520+ using input_dtype = typename t2ck<scalar_t >::type; \
1521+ int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \
1522+ int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \
1523+ dim3 const grid (num_tg); \
1524+ aiter::quant_kernel<input_dtype, DTYPE_O, BLOCK_SIZE, THREAD_DATA> \
1525+ <<<grid, dim3 (BLOCK_SIZE), 0 , stream>>> ( \
1526+ reinterpret_cast <DTYPE_O*>(out.data_ptr ()), \
1527+ scales.data_ptr <float >(), \
1528+ reinterpret_cast <input_dtype*>(input.data_ptr ()), \
1529+ smooth_scale.data_ptr <float >(), \
1530+ sorted_token_ids.data_ptr <int >(), \
1531+ sorted_expert_ids.data_ptr <int >(), \
1532+ num_valid_ids.data_ptr <int >(), \
1533+ num_experts, \
1534+ num_tokens, \
1535+ num_blocks, \
1536+ num_tg, \
1537+ cols, \
1538+ topk, \
1539+ block_m, \
1540+ block_m_log2split, \
1541+ input_stride0, \
1542+ input_stride1, \
1543+ shuffle_scale, \
1544+ transpose_out); \
15451545 });
15461546
15471547
@@ -1589,10 +1589,11 @@ void moe_smooth_per_token_scaled_quant_v2(
15891589 int input_stride1= input.dim () == 2 ? 0 : input.stride (1 );
15901590
15911591 const int num_cu = get_num_cu_func ();
1592- int sub_block_m = 2 ;
1593- int num_blocks = sorted_expert_ids.size (0 ) * (block_m / sub_block_m);
1594- int block_split = block_m / sub_block_m;
1592+ int block_split = 16 ;
15951593 int block_m_log2split = log2 (block_split);
1594+ TORCH_CHECK (block_m % block_split == 0 , __func__, " block_m is not divisible by block_split" );
1595+ int sub_block_m = block_m >> block_m_log2split;
1596+ int num_blocks = sorted_expert_ids.size (0 ) * block_split;
15961597 const bool persistent_mode = true ;
15971598
15981599 const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard (device_of (input));
0 commit comments