
Commit 7c542ec

[Executorch][llama] Hook up use_attention_mask option in the source transforms inside llm manager

Differential Revision: [D73222734](https://our.internmc.facebook.com/intern/diff/D73222734/)
ghstack-source-id: 278833094
Pull Request resolved: #10286
1 parent e16d7c8 commit 7c542ec

File tree

1 file changed (+13, −1)


examples/models/llama/export_llama_lib.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -1222,10 +1222,22 @@ def _get_source_transforms( # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    use_attention_mask_for_custom_sdpa = False
+    if isinstance(args, argparse.Namespace):
+        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
+            use_attention_mask_for_custom_sdpa = True
+
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
-        transforms.append(replace_sdpa_with_custom_op)
+        # if use attention mask instead of causal attention
+        # then create partial function that sets use_attention_mask=True
+        if use_attention_mask_for_custom_sdpa:
+            transforms.append(
+                partial(replace_sdpa_with_custom_op, use_attention_mask=True)
+            )
+        else:
+            transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
```
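For context, here is a minimal, self-contained sketch of the pattern the diff relies on: `functools.partial` pre-binds `use_attention_mask=True`, so every entry in the transforms list stays a single-argument callable that only takes the module. This is not the ExecuTorch implementation; `replace_sdpa_with_custom_op` below is a stub, and `ToyModel` plus the `args` namespace are hypothetical stand-ins for the real export arguments.

```python
import argparse
from functools import partial


def replace_sdpa_with_custom_op(module, use_attention_mask: bool = False):
    # Stub standing in for the real ExecuTorch transform, which rewrites SDPA
    # modules (optionally mask-aware); here we only record the chosen variant.
    module.uses_attention_mask = use_attention_mask
    return module


class ToyModel:
    # Hypothetical placeholder for the model the transforms are applied to.
    pass


# Hypothetical args object; in export_llama_lib the flag arrives via argparse.
args = argparse.Namespace(use_custom_sdpa_with_attention_mask=True)

use_attention_mask_for_custom_sdpa = False
if isinstance(args, argparse.Namespace):
    if getattr(args, "use_custom_sdpa_with_attention_mask", None):
        use_attention_mask_for_custom_sdpa = True

transforms = []
if use_attention_mask_for_custom_sdpa:
    # partial pre-binds use_attention_mask=True, so the transforms list remains
    # a uniform list of callables invoked as transform(module).
    transforms.append(partial(replace_sdpa_with_custom_op, use_attention_mask=True))
else:
    transforms.append(replace_sdpa_with_custom_op)

model = ToyModel()
for transform in transforms:
    model = transform(model)

print(model.uses_attention_mask)  # True, because the flag was set above
```

The design choice is that callers applying the source transforms never need to know which options each transform takes; option binding happens once, where the arguments are parsed.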
