File tree 1 file changed +13
-1
lines changed 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -1222,10 +1222,22 @@ def _get_source_transforms( # noqa
1222
1222
if args .expand_rope_table :
1223
1223
transforms .append (materialze_broadcast_of_rope_freq_cis )
1224
1224
1225
+ use_attention_mask_for_custom_sdpa = False
1226
+ if isinstance (args , argparse .Namespace ):
1227
+ if getattr (args , "use_custom_sdpa_with_attention_mask" , None ):
1228
+ use_attention_mask_for_custom_sdpa = True
1229
+
1225
1230
if args .use_sdpa_with_kv_cache :
1226
1231
transforms .append (replace_kv_cache_with_custom_kv_cache )
1227
1232
# todo: do this optionally
1228
- transforms .append (replace_sdpa_with_custom_op )
1233
+ # If an explicit attention mask is requested instead of causal attention,
1234
+ # then create a partial function that sets use_attention_mask=True
1235
+ if use_attention_mask_for_custom_sdpa :
1236
+ transforms .append (
1237
+ partial (replace_sdpa_with_custom_op , use_attention_mask = True )
1238
+ )
1239
+ else :
1240
+ transforms .append (replace_sdpa_with_custom_op )
1229
1241
1230
1242
if args .quantize_kv_cache :
1231
1243
assert args .use_kv_cache , "quantize_kv_cache requires use_kv_cache=True"
You can’t perform that action at this time.
0 commit comments