
Commit 682f291

Add a simple sdpa
Add a simple SDPA so that it is decomposed into simpler ops, instead of the decomposition of `F.scaled_dot_product_attention`, which produces 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605); q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2); aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1); aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1); aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0); aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default); aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default); aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default); aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]); k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605); aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]); aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]); aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]); aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]); aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1); aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]); aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self); aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False); aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]); aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]); aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]); v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]); aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4); aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]); aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```

Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/)

ghstack-source-id: 222465698
Pull Request resolved: #3037
1 parent a8f04ae commit 682f291
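
For context, here is a minimal, hedged sketch (not part of this commit; the module name, shapes, and tolerance are illustrative) of the explicit scale/matmul/mask/softmax/matmul formulation that a simple SDPA lowers to, checked against `F.scaled_dot_product_attention` with a causal mask:

```python
import math

import torch
import torch.nn.functional as F


class ExplicitSDPA(torch.nn.Module):
    """Attention spelled out as matmul/add/softmax/matmul instead of the fused op."""

    def forward(self, q, k, v, attn_mask):
        # q, k, v: (bs, n_heads, seqlen, head_dim); attn_mask is additive (0 or -inf).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        scores = scores + attn_mask
        attn = F.softmax(scores.float(), dim=-1).type_as(q)
        return torch.matmul(attn, v)


q, k, v = (torch.randn(1, 1, 8, 8) for _ in range(3))
# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
out = ExplicitSDPA()(q, k, v, causal)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)
```

Exporting a formulation like this avoids the mask-construction subgraph (`arange`/`le`/`logical_and`/`where`) that shows up in the decomposition quoted above.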

File tree

1 file changed: +57 −0 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 57 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 import argparse
 import copy
 import logging
+import math
 import os
 import shlex
 
@@ -33,6 +34,7 @@
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
 from sentencepiece import SentencePieceProcessor
+from torch.nn import functional as F
 
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
@@ -143,6 +145,61 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
 
 
+class SDPASimple(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        scores = F.softmax(scores.float(), dim=-1).type_as(q)
+        scores = scores + mask
+        output = torch.matmul(scores, v)  # (bs, n_local_heads, seqlen, head_dim)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return output
+
+
+def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_simple_sdpa(child)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
```

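The replacement pass above uses a standard recursive module-swap pattern. As a self-contained illustration (a toy sketch, not the commit's code; `Old` and `New` are made-up module names), the same `named_children()` plus `setattr` walk looks like this:

```python
import torch


class Old(torch.nn.Module):
    def forward(self, x):
        return x + 1


class New(torch.nn.Module):
    """Drop-in replacement with the same forward signature."""

    def forward(self, x):
        return x + 2


def replace_old_with_new(module: torch.nn.Module) -> torch.nn.Module:
    # Walk direct children: swap matches in place, recurse into everything else.
    for name, child in module.named_children():
        if isinstance(child, Old):
            setattr(module, name, New())
        else:
            replace_old_with_new(child)
    return module


model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Sequential(Old()))
replace_old_with_new(model)
print(model)  # the nested Old submodule has been rewritten to New
```

`replace_sdpa_with_simple_sdpa` follows the same shape, except the replacement (`SDPASimple`) is built from the attributes of the `SDPA` it replaces, so the existing KV cache is reused.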