|
9 | 9 | import argparse
|
10 | 10 | import copy
|
11 | 11 | import logging
|
| 12 | +import math |
12 | 13 | import os
|
13 | 14 | import shlex
|
14 | 15 |
|
|
33 | 34 | from executorch.sdk.etrecord import generate_etrecord
|
34 | 35 | from executorch.util.activation_memory_profiler import generate_memory_trace
|
35 | 36 | from sentencepiece import SentencePieceProcessor
|
| 37 | +from torch.nn import functional as F |
36 | 38 |
|
37 | 39 | from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
|
38 | 40 | from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
|
@@ -143,6 +145,61 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
|
143 | 145 | return module
|
144 | 146 |
|
145 | 147 |
|
| 148 | +class SDPASimple(torch.nn.Module): |
| 149 | + def __init__( |
| 150 | + self, |
| 151 | + kv_cache: KVCache, |
| 152 | + dim: int, |
| 153 | + head_dim: int, |
| 154 | + n_rep: int, |
| 155 | + ): |
| 156 | + super().__init__() |
| 157 | + self.kv_cache = kv_cache |
| 158 | + self.dim = dim |
| 159 | + self.head_dim = head_dim |
| 160 | + self.n_rep = n_rep |
| 161 | + |
| 162 | + def forward( |
| 163 | + self, |
| 164 | + input_pos: torch.Tensor, |
| 165 | + q: torch.Tensor, |
| 166 | + k: torch.Tensor, |
| 167 | + v: torch.Tensor, |
| 168 | + bsz, |
| 169 | + seqlen, |
| 170 | + mask, |
| 171 | + ): |
| 172 | + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) |
| 173 | + k = k.transpose(1, 2) |
| 174 | + v = v.transpose(1, 2) |
| 175 | + |
| 176 | + k, v = self.kv_cache.update(input_pos, k, v) |
| 177 | + mask = mask[None, None, input_pos] |
| 178 | + |
| 179 | + k = k.repeat_interleave(self.n_rep, dim=1) |
| 180 | + v = v.repeat_interleave(self.n_rep, dim=1) |
| 181 | + scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) |
| 182 | + scores = F.softmax(scores.float(), dim=-1).type_as(q) |
| 183 | + scores = scores + mask |
| 184 | + output = torch.matmul(scores, v) # (bs, n_local_heads, seqlen, head_dim) |
| 185 | + |
| 186 | + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 187 | + return output |
| 188 | + |
| 189 | + |
| 190 | +def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): |
| 191 | + for name, child in module.named_children(): |
| 192 | + if isinstance(child, SDPA): |
| 193 | + setattr( |
| 194 | + module, |
| 195 | + name, |
| 196 | + SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep), |
| 197 | + ) |
| 198 | + else: |
| 199 | + replace_sdpa_with_simple_sdpa(child) |
| 200 | + return module |
| 201 | + |
| 202 | + |
146 | 203 | def quantize(
|
147 | 204 | model: torch.nn.Module,
|
148 | 205 | qmode: str,
|
|
0 commit comments