Commit ef8ec07

Support tuning moe for llama 4 model (#6042)
1 parent f24fc5b commit ef8ec07

1 file changed: +7, -1 lines changed

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py

Lines changed: 7 additions & 1 deletion
@@ -408,6 +408,12 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "Llama4ForConditionalGeneration":
+        n_share_fusion_experts = args.n_share_experts_fusion
+        E = config.text_config.num_local_experts + n_share_fusion_experts
+        topk = config.text_config.num_experts_per_tok
+        intermediate_size = config.text_config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -424,7 +430,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
-    hidden_size = config.hidden_size
+    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
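
To make the new branch concrete, here is a small, self-contained sketch (not part of the commit) that mirrors the shape computation the patched script performs for Llama4ForConditionalGeneration. The SimpleNamespace stand-in and all numeric values are placeholders chosen for illustration, not taken from a real checkpoint; in the actual script these values come from the Hugging Face model config and the tuner's command-line arguments (args.tp_size, args.n_share_experts_fusion).

# Illustrative sketch only: mimics the Llama4ForConditionalGeneration branch above.
# Llama 4 nests its decoder settings under `text_config`, which is why the patch
# reads the MoE parameters from there and adds the hidden_size fallback below.
from types import SimpleNamespace

# Assumed stand-in for the HF config object; values are placeholders.
config = SimpleNamespace(
    architectures=["Llama4ForConditionalGeneration"],
    text_config=SimpleNamespace(
        num_local_experts=16,    # placeholder expert count
        num_experts_per_tok=1,   # placeholder top-k
        intermediate_size=8192,  # placeholder MoE FFN width
        hidden_size=5120,        # placeholder model width
    ),
)
tp_size = 8                 # stand-in for args.tp_size
n_share_experts_fusion = 0  # stand-in for args.n_share_experts_fusion

E = config.text_config.num_local_experts + n_share_experts_fusion
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
# Mirrors the fallback in the second hunk: use the top-level hidden_size if the
# config exposes one, otherwise fall back to text_config.hidden_size.
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size

print(E, topk, shard_intermediate_size, hidden_size)  # 16 1 2048 5120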
