File tree: 1 file changed, +7 -1 lines changed
benchmark/kernels/fused_moe_triton: 1 file changed, +7 -1 lines changed
Columns: original file line number | diff line number | diff line change
@@ -408,6 +408,12 @@ def main(args: argparse.Namespace):
408
408
topk = config .num_experts_per_tok
409
409
intermediate_size = config .moe_intermediate_size
410
410
shard_intermediate_size = 2 * intermediate_size // args .tp_size
411
+ elif config .architectures [0 ] == "Llama4ForConditionalGeneration" :
412
+ n_share_fusion_experts = args .n_share_experts_fusion
413
+ E = config .text_config .num_local_experts + n_share_fusion_experts
414
+ topk = config .text_config .num_experts_per_tok
415
+ intermediate_size = config .text_config .intermediate_size
416
+ shard_intermediate_size = 2 * intermediate_size // args .tp_size
411
417
elif config .architectures [0 ] in [
412
418
"Grok1ForCausalLM" ,
413
419
"Grok1ImgGen" ,
@@ -424,7 +430,7 @@ def main(args: argparse.Namespace):
424
430
intermediate_size = config .intermediate_size
425
431
shard_intermediate_size = 2 * intermediate_size // args .tp_size
426
432
427
- hidden_size = config .hidden_size
433
+ hidden_size = getattr ( config , "hidden_size" , None ) or config . text_config .hidden_size
428
434
dtype = config .torch_dtype
429
435
use_fp8_w8a8 = args .dtype == "fp8_w8a8"
430
436
use_int8_w8a8 = args .dtype == "int8_w8a8"
You can’t perform that action at this time.
0 commit comments