Skip to content

Commit 817a6eb

Browse files
committed
update
1 parent 646f8a0 commit 817a6eb

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

3rdparty/composable_kernel

Submodule composable_kernel updated 347 files

csrc/pybind/moe_op_pybind.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
 /* SPDX-License-Identifier: MIT
    Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 */
-#include "rocm_ops.hpp"
 #include "moe_op.h"
+#include "rocm_ops.hpp"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
+      AITER_ENUM_PYBIND;
       MOE_OP_PYBIND;
 }

op_tests/test_moe_2stage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,8 @@ def weight_per_128x128_quant(weight, quant_dtype):
     torch.cuda.manual_seed_all(seed)
     l_dtype = ["bf16", "fp16"][:1]
     # l_dim = [(6144, 4096)]
-    l_dim = [(7168, 256)]
-    # l_dim = [(3072, 3072)]
+    # l_dim = [(7168, 256)]
+    l_dim = [(3072, 3072)]
     l_tokenNum = [
         # 1,
         # 2,
@@ -693,8 +693,8 @@ def weight_per_128x128_quant(weight, quant_dtype):
         # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8
         # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4
         # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4
-        (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8
-        # (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4
+        # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8
+        (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4
     ]
     l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1]
     l_doweight_stage1 = [False, True][:1]

0 commit comments

Comments (0)