@@ -376,45 +376,50 @@ def rewrite(
376
376
)
377
377
378
378
379
- parameter_combinations = [
380
- {
381
- "double_transpose" : double_transpose ,
382
- "transpose_4d" : transpose_4d ,
383
- "pre_scale_q" : pre_scale_q ,
384
- "is_rotary" : is_rotary ,
385
- "use_mask" : use_mask ,
386
- "has_past_present" : has_past_present ,
387
- "is_cross_attention" : is_cross_attention ,
388
- }
389
- for double_transpose in [False , True ]
390
- for transpose_4d in (
391
- [False , True ] if double_transpose else [False ]
392
- ) # Only generate patterns when double_transpose is True
393
- for pre_scale_q in [True , False ]
394
- for is_rotary in [False , True ]
395
- for use_mask in [False , True ]
396
- for is_cross_attention in [False , True ]
397
- for has_past_present in ([False ] if is_cross_attention else [True , False ])
398
- # Skip if both has_past_present and is_cross_attention are True
399
- if not (has_past_present and is_cross_attention )
400
- ]
401
-
402
- # Dynamically create the rules
403
- mha_rules = pattern .RewriteRuleSet (
404
- [
405
- MultiHeadAttention .rule (
406
- f"MHA_{ '4D' if params ['transpose_4d' ] else '3D' } _Transpose"
407
- f"{ '_Twice' if params ['double_transpose' ] else '' } "
408
- f"{ '_PreScaleQ' if params ['pre_scale_q' ] else '' } "
409
- f"{ '_Rotary' if params ['is_rotary' ] else '' } "
410
- f"{ '_Masked' if params ['use_mask' ] else '' } "
411
- f"{ '_Past' if params ['has_past_present' ] else '' } "
412
- f"{ '_CrossAttention' if params ['is_cross_attention' ] else '' } " ,
413
- ** params ,
414
- )
415
- for params in parameter_combinations
379
+ def _make_rule_set (has_past_present : bool ):
380
+ parameter_combinations = [
381
+ {
382
+ "double_transpose" : double_transpose ,
383
+ "transpose_4d" : transpose_4d ,
384
+ "pre_scale_q" : pre_scale_q ,
385
+ "is_rotary" : is_rotary ,
386
+ "use_mask" : use_mask ,
387
+ "has_past_present" : has_past_present ,
388
+ "is_cross_attention" : is_cross_attention ,
389
+ }
390
+ for double_transpose in [False , True ]
391
+ for transpose_4d in (
392
+ [False , True ] if double_transpose else [False ]
393
+ ) # Only generate patterns when double_transpose is True
394
+ for pre_scale_q in [True , False ]
395
+ for is_rotary in [False , True ]
396
+ for use_mask in [False , True ]
397
+ for is_cross_attention in ([False ] if has_past_present else [False , True ])
416
398
]
417
- )
418
399
400
+ # Dynamically create the rules
401
+ mha_rules = pattern .RewriteRuleSet (
402
+ [
403
+ MultiHeadAttention .rule (
404
+ f"MHA_{ '4D' if params ['transpose_4d' ] else '3D' } _Transpose"
405
+ f"{ '_Twice' if params ['double_transpose' ] else '' } "
406
+ f"{ '_PreScaleQ' if params ['pre_scale_q' ] else '' } "
407
+ f"{ '_Rotary' if params ['is_rotary' ] else '' } "
408
+ f"{ '_Masked' if params ['use_mask' ] else '' } "
409
+ f"{ '_Past' if params ['has_past_present' ] else '' } "
410
+ f"{ '_CrossAttention' if params ['is_cross_attention' ] else '' } " ,
411
+ ** params ,
412
+ )
413
+ for params in parameter_combinations
414
+ ]
415
+ )
416
+
417
+ return mha_rules
418
+
419
+
420
+ mha_rules_no_past = _make_rule_set (has_past_present = False )
421
+ mha_rules_with_past = _make_rule_set (has_past_present = True )
419
422
420
- fuse_mha = _fusion_utils .apply_fusion_rules (mha_rules )
423
+ # Try rules with past first, and then rules without past.
424
+ fuse_mha1 = _fusion_utils .apply_fusion_rules (mha_rules_with_past )
425
+ fuse_mha2 = _fusion_utils .apply_fusion_rules (mha_rules_no_past )
0 commit comments