|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from typing import Dict, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | +from executorch.examples.models.llama.attention import ( |
| 8 | + Attention, |
| 9 | + AttentionMHA, |
| 10 | + ForwardOptions, |
| 11 | + register_attention, |
| 12 | +) |
| 13 | +from executorch.examples.models.llama.model_args import ModelArgs |
| 14 | +from executorch.examples.models.llama.rope import Rope |
| 15 | + |
| 16 | + |
| 17 | +_CacheMap = Dict[str, torch.Tensor] |
| 18 | +# Key and value caches are kept separate so the key caches can be kept transposed. |
| 19 | +_InputCacheState = Tuple[_CacheMap, _CacheMap] |
| 20 | +_OutputCacheState = Tuple[_CacheMap, _CacheMap] |
| 21 | + |
| 22 | + |
| 23 | +class StaticKVCache(nn.Module, ABC): |
| 24 | + def __init__(self, layer_id: int, head_id: int): |
| 25 | + super().__init__() |
| 26 | + self.layer_id = layer_id |
| 27 | + self.head_id = head_id |
| 28 | + |
| 29 | + @abstractmethod |
| 30 | + def update( |
| 31 | + self, |
| 32 | + new_data: torch.Tensor, |
| 33 | + in_cache_state: Optional[_InputCacheState], |
| 34 | + out_cache_state: Optional[_OutputCacheState], |
| 35 | + ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]: |
| 36 | + """ |
| 37 | + Given input cache state and new keys/values, returns the combined keys/values |
| 38 | + and the updated the output cache state. |
| 39 | + """ |
| 40 | + pass |
| 41 | + |
| 42 | + def cache_key(self) -> str: |
| 43 | + return self.calculate_cache_key(self.layer_id, self.head_id) |
| 44 | + |
| 45 | + @staticmethod |
| 46 | + def calculate_cache_key(layer_id: int, head_id: int) -> str: |
| 47 | + return f"l{layer_id},h{head_id}" |
| 48 | + |
| 49 | + @staticmethod |
| 50 | + def apply_update(cache, update, transpose=False): |
| 51 | + """ |
| 52 | + After inference, update the cache state for next iteration. The runtime needs to |
| 53 | + implement the same operation. |
| 54 | + """ |
| 55 | + if transpose: |
| 56 | + update_len = update.size(-1) |
| 57 | + updated = torch.roll(cache, -update_len, -1) |
| 58 | + updated[:, :, -update_len:] = update |
| 59 | + else: |
| 60 | + update_len = update.size(-2) |
| 61 | + updated = torch.roll(cache, -update_len, -2) |
| 62 | + updated[:, -update_len:, :] = update |
| 63 | + |
| 64 | + return updated |
| 65 | + |
| 66 | + |
| 67 | +class StaticKCache(StaticKVCache): |
| 68 | + def __init__(self, layer_id: int, head_id: int, transpose=False): |
| 69 | + """ |
| 70 | + If transpose is True, key cache is kept in (batch, dim, seq_len), otherwise in |
| 71 | + (batch, seq_len, dim). |
| 72 | + """ |
| 73 | + super().__init__(layer_id, head_id) |
| 74 | + self.transpose = transpose |
| 75 | + |
| 76 | + def update( |
| 77 | + self, |
| 78 | + new_data: torch.Tensor, |
| 79 | + in_cache_state: Optional[_InputCacheState], |
| 80 | + out_cache_state: Optional[_OutputCacheState], |
| 81 | + ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]: |
| 82 | + seq_dim = -2 |
| 83 | + if self.transpose: |
| 84 | + seq_dim = -1 |
| 85 | + new_data = new_data.transpose(-1, -2) |
| 86 | + if in_cache_state is None: |
| 87 | + return new_data, None |
| 88 | + if out_cache_state is None: |
| 89 | + out_cache_state = ({}, {}) |
| 90 | + |
| 91 | + all_data = torch.cat( |
| 92 | + [in_cache_state[0][self.cache_key()], new_data], dim=seq_dim |
| 93 | + ) |
| 94 | + out_k_cache, out_v_cache = out_cache_state |
| 95 | + out_k_cache[self.cache_key()] = new_data |
| 96 | + return all_data, (out_k_cache, out_v_cache) |
| 97 | + |
| 98 | + |
| 99 | +class StaticVCache(StaticKVCache): |
| 100 | + def update( |
| 101 | + self, |
| 102 | + new_data: torch.Tensor, |
| 103 | + in_cache_state: Optional[_InputCacheState], |
| 104 | + out_cache_state: Optional[_OutputCacheState], |
| 105 | + ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]: |
| 106 | + if in_cache_state is None: |
| 107 | + return new_data, None |
| 108 | + if out_cache_state is None: |
| 109 | + out_cache_state = ({}, {}) |
| 110 | + |
| 111 | + all_data = torch.cat([in_cache_state[1][self.cache_key()], new_data], dim=-2) |
| 112 | + out_k_cache, out_v_cache = out_cache_state |
| 113 | + out_v_cache[self.cache_key()] = new_data |
| 114 | + return all_data, (out_k_cache, out_v_cache) |
| 115 | + |
| 116 | + |
| 117 | +def _apply_rotary_embedding( |
| 118 | + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor |
| 119 | +) -> torch.Tensor: |
| 120 | + x_r, x_i = x[..., ::2], x[..., 1::2] |
| 121 | + x_out_r = x_r * freqs_cos - x_i * freqs_sin |
| 122 | + x_out_i = x_r * freqs_sin + x_i * freqs_cos |
| 123 | + |
| 124 | + x_out = torch.cat([x_out_r, x_out_i], dim=-1) |
| 125 | + return x_out |
| 126 | + |
| 127 | + |
| 128 | +@register_attention("static") |
| 129 | +class StaticAttention(Attention): |
| 130 | + """ |
| 131 | + An attention implementation meant for NPUs that require static shapes and are not |
| 132 | + flexible with tensor operations needed to perform KV cache updates. MHA/GQA is |
| 133 | + implemented as multiple SHAs, and the KV caches keep valid data at the end so the |
| 134 | + model only needs to perform a concat to combine past and new data. |
| 135 | + """ |
| 136 | + |
| 137 | + def __init__(self, config: ModelArgs, layer_id: int, rope: Rope): |
| 138 | + super().__init__() |
| 139 | + self.n_heads = config.n_heads |
| 140 | + self.n_kv_heads = ( |
| 141 | + self.n_heads if config.n_kv_heads is None else config.n_kv_heads |
| 142 | + ) |
| 143 | + assert self.n_heads % self.n_kv_heads == 0 |
| 144 | + self.n_heads_per_kv_group = self.n_heads // self.n_kv_heads |
| 145 | + self.dim = config.dim |
| 146 | + self.head_dim = config.head_dim |
| 147 | + self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5) |
| 148 | + |
| 149 | + self.wqs = nn.ModuleList( |
| 150 | + [ |
| 151 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 152 | + for _ in range(self.n_heads) |
| 153 | + ] |
| 154 | + ) |
| 155 | + self.wks = nn.ModuleList( |
| 156 | + [ |
| 157 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 158 | + for _ in range(self.n_kv_heads) |
| 159 | + ] |
| 160 | + ) |
| 161 | + self.wvs = nn.ModuleList( |
| 162 | + [ |
| 163 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 164 | + for _ in range(self.n_kv_heads) |
| 165 | + ] |
| 166 | + ) |
| 167 | + |
| 168 | + self.k_caches = nn.ModuleList( |
| 169 | + [StaticKCache(layer_id, i) for i in range(self.n_kv_heads)] |
| 170 | + ) |
| 171 | + self.v_caches = nn.ModuleList( |
| 172 | + [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)] |
| 173 | + ) |
| 174 | + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) |
| 175 | + |
| 176 | + def forward( |
| 177 | + self, |
| 178 | + x: torch.Tensor, |
| 179 | + freqs_cos: torch.Tensor, |
| 180 | + freqs_sin: torch.Tensor, |
| 181 | + **kwargs: ForwardOptions, |
| 182 | + ): |
| 183 | + mask = kwargs.get("mask") |
| 184 | + if (freqs_cos_override := kwargs.get("freqs_cos_override")) is not None: |
| 185 | + freqs_cos = freqs_cos_override # pyre-ignore |
| 186 | + if (freqs_sin_override := kwargs.get("freqs_sin_override")) is not None: |
| 187 | + freqs_sin = freqs_sin_override # pyre-ignore |
| 188 | + in_cache_state = kwargs.get("in_cache_state") |
| 189 | + out_cache_state = kwargs.get("out_cache_state") |
| 190 | + |
| 191 | + new_qs = [self.wqs[i](x) for i in range(self.n_heads)] |
| 192 | + new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)] |
| 193 | + new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)] |
| 194 | + new_qs = [_apply_rotary_embedding(q, freqs_cos, freqs_sin) for q in new_qs] |
| 195 | + new_ks = [_apply_rotary_embedding(k, freqs_cos, freqs_sin) for k in new_ks] |
| 196 | + |
| 197 | + all_ks = [] |
| 198 | + all_vs = [] |
| 199 | + for i in range(self.n_kv_heads): |
| 200 | + ks, out_cache_state = self.k_caches[i].update( |
| 201 | + new_ks[i], in_cache_state, out_cache_state |
| 202 | + ) |
| 203 | + all_ks.append(ks) |
| 204 | + vs, out_cache_state = self.v_caches[i].update( |
| 205 | + new_vs[i], in_cache_state, out_cache_state |
| 206 | + ) |
| 207 | + all_vs.append(vs) |
| 208 | + |
| 209 | + heads = [] |
| 210 | + for i in range(self.n_heads): |
| 211 | + kv_idx = i // self.n_heads_per_kv_group |
| 212 | + attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1) |
| 213 | + attn = attn * self.inv_scale |
| 214 | + attn = attn + mask # pyre-ignore |
| 215 | + attn = F.softmax(attn, dim=-1) |
| 216 | + heads.append(attn @ all_vs[kv_idx]) |
| 217 | + |
| 218 | + y = torch.cat(heads, dim=-1) |
| 219 | + y = self.wo(y) |
| 220 | + return y, {"out_cache_state": out_cache_state} |
| 221 | + |
| 222 | + def load_weights_from_attention_mha(self, other: AttentionMHA): |
| 223 | + for i in range(self.n_heads): |
| 224 | + self.wqs[i].weight.data.copy_( |
| 225 | + other.wq.weight[i * self.head_dim : (i + 1) * self.head_dim, :] |
| 226 | + ) |
| 227 | + |
| 228 | + for i in range(self.n_kv_heads): |
| 229 | + self.wks[i].weight.data.copy_( |
| 230 | + other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :] |
| 231 | + ) |
| 232 | + self.wvs[i].weight.data.copy_( |
| 233 | + other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :] |
| 234 | + ) |
| 235 | + |
| 236 | + self.wo.weight.data.copy_(other.wo.weight) |
0 commit comments