@@ -668,8 +668,8 @@ def weight_per_128x128_quant(weight, quant_dtype):
 torch.cuda.manual_seed_all(seed)
 l_dtype = ["bf16", "fp16"][:1]
 # l_dim = [(6144, 4096)]
-l_dim = [(7168, 256)]
-# l_dim = [(3072, 3072)]
+# l_dim = [(7168, 256)]
+l_dim = [(3072, 3072)]
 l_tokenNum = [
     # 1,
     # 2,
@@ -693,8 +693,8 @@ def weight_per_128x128_quant(weight, quant_dtype):
     # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8),  # a8w8
     # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4),  # a8w4
     # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2),  # a4w4
-    (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8),  # a8w8
-    # (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2),  # a16w4
+    # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8),  # a8w8
+    (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2),  # a16w4
 ]
 l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1]
 l_doweight_stage1 = [False, True][:1]
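
The hunk context names weight_per_128x128_quant, and the diff toggles the per_128x128 a8w8 case off. For orientation, below is a minimal reference sketch of per-128x128 weight quantization: one fp8 scale per 128x128 tile, chosen so the tile's max magnitude maps to the e4m3 maximum (448). The function name, the e4m3 dtype choice, and the padded-shape assertion are assumptions for illustration, not aiter's implementation.

# Hypothetical reference sketch (not aiter's kernel): per-128x128 blockwise
# fp8 quantization, as exercised by the a8w8 case commented out above.
import torch

def ref_weight_per_128x128_quant(weight: torch.Tensor, block: int = 128):
    """Quantize a 2-D weight to fp8 with one scale per 128x128 tile."""
    rows, cols = weight.shape
    assert rows % block == 0 and cols % block == 0  # assumes padded shapes
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    # View as (row_blocks, 128, col_blocks, 128) so each tile reduces cleanly.
    tiles = weight.view(rows // block, block, cols // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-8)
    scale = amax / fp8_max                      # one scale per tile
    q = (tiles / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(rows, cols), scale.squeeze(1).squeeze(-1)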
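The newly enabled tuple (per_1x32, bf16, fp4x2) is the a16w4 path: bf16 activations with weights quantized per 1x32 block to 4-bit values packed two per byte (hence fp4x2). A hedged sketch of such quantize-and-pack logic follows; the e2m1 magnitude grid and every name here are assumptions for illustration, not aiter's implementation.

# Hypothetical sketch of per-1x32 fp4 weight quantization with nibble packing.
import torch

# Non-negative magnitudes representable in e2m1 "fp4"; sign is a separate bit.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def ref_weight_per_1x32_fp4_quant(weight: torch.Tensor, block: int = 32):
    """Quantize each 1x32 block to 4-bit codes; pack two codes per byte."""
    rows, cols = weight.shape
    assert cols % block == 0  # assumes padded columns
    grid = E2M1_GRID.to(weight.device)
    w = weight.float().view(rows, cols // block, block)
    # One scale per 1x32 block, mapping the block max to the e2m1 max (6.0).
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 6.0
    x = w / scale
    # Round each magnitude to the nearest e2m1 grid point (3-bit code) ...
    mag_code = (x.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    # ... then set bit 3 as the sign bit to form the full 4-bit code.
    code = (mag_code | ((x < 0).long() << 3)).view(rows, cols)
    # Pack adjacent codes two per byte: even columns low nibble, odd high.
    packed = (code[:, 0::2] | (code[:, 1::2] << 4)).to(torch.uint8)
    return packed, scale.squeeze(-1)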