From 6908d14487bcc7bbae491c43b5e439b5fc2bdbdb Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 21 Apr 2025 08:11:44 -0700
Subject: [PATCH] [Executorch][llama] Allow custom sdpa op replacement pass to
 leverage attention mask

Pull Request resolved: https://github.com/pytorch/executorch/pull/10285

Previously we assumed that the custom sdpa always does causal attention.
This diff adds an option to this module swap pass to make the custom sdpa
leverage the attention mask instead of causal attention.
ghstack-source-id: 279292324
@exported-using-ghexport

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)
---
 .../llama/source_transformation/sdpa.py       | 64 +++++++++++++++----
 .../test_sdpa_with_quantized_kv_cache.py      |  4 +-
 2 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index a50c6aeea22..1bc54198fba 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
 
 
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index b2c93d7d93d..e5e278f8ce8 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim)
-        self.quantized_sdpa = SDPACustom(self.dim)
+        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
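
For readers who want to see what the new `use_attention_mask` path computes, below is a minimal standalone sketch; it is not part of the patch. It reproduces the two mask-slicing branches added to `SDPACustom.forward` and then runs a reference attention call with the sliced mask, using `torch.nn.functional.scaled_dot_product_attention` as a stand-in for `torch.ops.llama.custom_sdpa(..., is_causal=False)`. All sizes are invented for illustration.

```python
import torch
import torch.nn.functional as F

# Toy sizes, invented for illustration only.
max_context_len = 16
n_heads, head_dim = 4, 8
start_pos, seq_len = 5, 3  # decoding 3 new tokens starting at cache position 5

# Precomputed boolean causal mask over the full context window:
# row i is True for columns 0..i, i.e. token i attends to itself and earlier tokens.
causal_mask = torch.tril(
    torch.ones(max_context_len, max_context_len, dtype=torch.bool)
)

input_pos = torch.arange(start_pos, start_pos + seq_len)

# Static-shape branch of the patch: plain indexing, `mask = mask[input_pos]`.
mask_static = causal_mask[input_pos]

# Dynamic-shape branch of the patch: `mask.narrow(0, start_pos, seq_length)`,
# which (together with the torch._check/_check_is_size calls on start_pos)
# keeps the slice expressible under symbolic shapes at export time.
mask_dynamic = causal_mask.narrow(0, start_pos, seq_len)

assert torch.equal(mask_static, mask_dynamic)  # both branches select the same rows

# Reference attention call: q covers the new tokens, k/v cover the filled cache.
q = torch.randn(1, n_heads, seq_len, head_dim)
k = torch.randn(1, n_heads, start_pos + seq_len, head_dim)
v = torch.randn(1, n_heads, start_pos + seq_len, head_dim)

# Passing the sliced mask explicitly (is_causal=False) encodes the same causal
# structure the op previously got from is_causal=True; columns are trimmed to
# the populated part of the cache for the stock PyTorch op.
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask_static[:, : start_pos + seq_len]
)
print(out.shape)  # torch.Size([1, 4, 3, 8])
```

The `narrow` branch exists because slicing with a data-dependent `start_pos` needs explicit bounds (`torch._check_is_size`, `torch._check`) for the exporter to reason about dynamic shapes; in eager mode both branches produce identical masks, as the assert above shows.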