
Commit e841162

sxu authored and facebook-github-bot committed
Static attention implementation (#8310)
Summary: Merge existing implementation for static-shaped NPUs, where valid data are kept at the end of the cache and the MHA -> SHA rewrite is done.

Reviewed By: iseeyuan

Differential Revision: D69080741
1 parent 16281ce commit e841162

File tree

6 files changed, +451 -2 lines changed


examples/models/llama/TARGETS

+17
@@ -28,6 +28,23 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "static_attention",
+    srcs = [
+        "static_attention.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":llama_transformer",
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_library(
     name = "llama2_model",
     srcs = [

examples/models/llama/attention.py

+2
@@ -13,6 +13,8 @@ class ForwardOptions(TypedDict, total=False):
 
     mask: Optional[torch.Tensor]
     input_pos: Optional[torch.Tensor]
+    freqs_cos_override: Optional[torch.Tensor]
+    freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
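These two new optional fields let a caller hand the attention layer precomputed RoPE cos/sin tables in place of the ones derived from input_pos; StaticAttention.forward in the new file below reads them from its kwargs. A minimal sketch of how such an options dict could be built (the sizes here are illustrative assumptions, not part of this commit):

import torch

from executorch.examples.models.llama.attention import ForwardOptions

# Illustrative sizes only: one decode step, head_dim of 64 -> 32 cos/sin pairs.
seq_len, half_head_dim = 1, 32
options: ForwardOptions = {
    "freqs_cos_override": torch.ones(seq_len, half_head_dim),
    "freqs_sin_override": torch.zeros(seq_len, half_head_dim),
}
# ForwardOptions is declared with total=False, so unrelated keys
# (mask, input_pos, cache state) can simply be omitted.
assert options.get("mask") is None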

examples/models/llama/llama_transformer.py

+4-2
@@ -205,11 +205,13 @@ def forward(
             attn_options.get("input_pos"), seqlen
         )
 
+        # Make a shallow copy so the updates don't get captured by export
+        attn_options_ = attn_options.copy() if attn_options is not None else {}
         attn_options_update = None
         for layer in self.layers:
-            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options_)
             if attn_options_update is not None:
-                attn_options.update(**attn_options_update)
+                attn_options_.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
examples/models/llama/static_attention.py (new file)

+236
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.attention import (
    Attention,
    AttentionMHA,
    ForwardOptions,
    register_attention,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


_CacheMap = Dict[str, torch.Tensor]
# Key and value caches are kept separate so the key caches can be kept transposed.
_InputCacheState = Tuple[_CacheMap, _CacheMap]
_OutputCacheState = Tuple[_CacheMap, _CacheMap]


class StaticKVCache(nn.Module, ABC):
    def __init__(self, layer_id: int, head_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.head_id = head_id

    @abstractmethod
    def update(
        self,
        new_data: torch.Tensor,
        in_cache_state: Optional[_InputCacheState],
        out_cache_state: Optional[_OutputCacheState],
    ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
        """
        Given input cache state and new keys/values, returns the combined keys/values
        and the updated output cache state.
        """
        pass

    def cache_key(self) -> str:
        return self.calculate_cache_key(self.layer_id, self.head_id)

    @staticmethod
    def calculate_cache_key(layer_id: int, head_id: int) -> str:
        return f"l{layer_id},h{head_id}"

    @staticmethod
    def apply_update(cache, update, transpose=False):
        """
        After inference, update the cache state for the next iteration. The runtime
        needs to implement the same operation.
        """
        if transpose:
            update_len = update.size(-1)
            updated = torch.roll(cache, -update_len, -1)
            updated[:, :, -update_len:] = update
        else:
            update_len = update.size(-2)
            updated = torch.roll(cache, -update_len, -2)
            updated[:, -update_len:, :] = update

        return updated


class StaticKCache(StaticKVCache):
    def __init__(self, layer_id: int, head_id: int, transpose=False):
        """
        If transpose is True, key cache is kept in (batch, dim, seq_len), otherwise in
        (batch, seq_len, dim).
        """
        super().__init__(layer_id, head_id)
        self.transpose = transpose

    def update(
        self,
        new_data: torch.Tensor,
        in_cache_state: Optional[_InputCacheState],
        out_cache_state: Optional[_OutputCacheState],
    ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
        seq_dim = -2
        if self.transpose:
            seq_dim = -1
            new_data = new_data.transpose(-1, -2)
        if in_cache_state is None:
            return new_data, None
        if out_cache_state is None:
            out_cache_state = ({}, {})

        all_data = torch.cat(
            [in_cache_state[0][self.cache_key()], new_data], dim=seq_dim
        )
        out_k_cache, out_v_cache = out_cache_state
        out_k_cache[self.cache_key()] = new_data
        return all_data, (out_k_cache, out_v_cache)


class StaticVCache(StaticKVCache):
    def update(
        self,
        new_data: torch.Tensor,
        in_cache_state: Optional[_InputCacheState],
        out_cache_state: Optional[_OutputCacheState],
    ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
        if in_cache_state is None:
            return new_data, None
        if out_cache_state is None:
            out_cache_state = ({}, {})

        all_data = torch.cat([in_cache_state[1][self.cache_key()], new_data], dim=-2)
        out_k_cache, out_v_cache = out_cache_state
        out_v_cache[self.cache_key()] = new_data
        return all_data, (out_k_cache, out_v_cache)


def _apply_rotary_embedding(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]
    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


@register_attention("static")
class StaticAttention(Attention):
    """
    An attention implementation meant for NPUs that require static shapes and are not
    flexible with tensor operations needed to perform KV cache updates. MHA/GQA is
    implemented as multiple SHAs, and the KV caches keep valid data at the end so the
    model only needs to perform a concat to combine past and new data.
    """

    def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
        super().__init__()
        self.n_heads = config.n_heads
        self.n_kv_heads = (
            self.n_heads if config.n_kv_heads is None else config.n_kv_heads
        )
        assert self.n_heads % self.n_kv_heads == 0
        self.n_heads_per_kv_group = self.n_heads // self.n_kv_heads
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)

        self.wqs = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wks = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wvs = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        self.k_caches = nn.ModuleList(
            [StaticKCache(layer_id, i) for i in range(self.n_kv_heads)]
        )
        self.v_caches = nn.ModuleList(
            [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
        )
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ):
        mask = kwargs.get("mask")
        if (freqs_cos_override := kwargs.get("freqs_cos_override")) is not None:
            freqs_cos = freqs_cos_override  # pyre-ignore
        if (freqs_sin_override := kwargs.get("freqs_sin_override")) is not None:
            freqs_sin = freqs_sin_override  # pyre-ignore
        in_cache_state = kwargs.get("in_cache_state")
        out_cache_state = kwargs.get("out_cache_state")

        new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
        new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
        new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
        new_qs = [_apply_rotary_embedding(q, freqs_cos, freqs_sin) for q in new_qs]
        new_ks = [_apply_rotary_embedding(k, freqs_cos, freqs_sin) for k in new_ks]

        all_ks = []
        all_vs = []
        for i in range(self.n_kv_heads):
            ks, out_cache_state = self.k_caches[i].update(
                new_ks[i], in_cache_state, out_cache_state
            )
            all_ks.append(ks)
            vs, out_cache_state = self.v_caches[i].update(
                new_vs[i], in_cache_state, out_cache_state
            )
            all_vs.append(vs)

        heads = []
        for i in range(self.n_heads):
            kv_idx = i // self.n_heads_per_kv_group
            attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1)
            attn = attn * self.inv_scale
            attn = attn + mask  # pyre-ignore
            attn = F.softmax(attn, dim=-1)
            heads.append(attn @ all_vs[kv_idx])

        y = torch.cat(heads, dim=-1)
        y = self.wo(y)
        return y, {"out_cache_state": out_cache_state}

    def load_weights_from_attention_mha(self, other: AttentionMHA):
        for i in range(self.n_heads):
            self.wqs[i].weight.data.copy_(
                other.wq.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
            )

        for i in range(self.n_kv_heads):
            self.wks[i].weight.data.copy_(
                other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
            )
            self.wvs[i].weight.data.copy_(
                other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
            )

        self.wo.weight.data.copy_(other.wo.weight)
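Taken together, the module defines a simple host-side protocol: keep one fixed-size cache tensor per layer/head keyed by calculate_cache_key, pass the tensors in via in_cache_state, and after each step fold the returned out_cache_state back in with StaticKVCache.apply_update, the same roll-and-overwrite the runtime is expected to implement. The sketch below walks one decode step through that loop; it is not part of the commit, and the tensor sizes, the zero-initialized caches, and the assumption that ModelArgs derives head_dim = dim // n_heads are illustrative only.

import torch

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import (
    StaticAttention,
    StaticKVCache,
)

# Illustrative sizes only; assumes ModelArgs fills head_dim = dim // n_heads.
config = ModelArgs(dim=32, n_heads=4, n_kv_heads=2)
attn = StaticAttention(config, layer_id=0, rope=None)  # rope is unused by __init__ here

batch, cache_len, step_len = 1, 8, 1
head_dim = config.head_dim

# Fixed-size per-head caches keyed "l{layer},h{head}"; zeros stand in for
# "no valid history yet" (a real mask would hide these slots).
k_caches = {
    StaticKVCache.calculate_cache_key(0, h): torch.zeros(batch, cache_len, head_dim)
    for h in range(config.n_kv_heads)
}
v_caches = {key: torch.zeros(batch, cache_len, head_dim) for key in k_caches}

x = torch.randn(batch, step_len, config.dim)
freqs_cos = torch.ones(step_len, head_dim // 2)    # placeholder RoPE tables
freqs_sin = torch.zeros(step_len, head_dim // 2)
mask = torch.zeros(step_len, cache_len + step_len)  # no positions masked

y, update = attn(
    x,
    freqs_cos,
    freqs_sin,
    mask=mask,
    in_cache_state=(k_caches, v_caches),
    out_cache_state=None,
)

# Host-side cache maintenance between steps: roll the fixed-size cache left and
# write the new entries at the end, mirroring what the runtime must implement.
out_k, out_v = update["out_cache_state"]
for key in k_caches:
    k_caches[key] = StaticKVCache.apply_update(k_caches[key], out_k[key])
    v_caches[key] = StaticKVCache.apply_update(v_caches[key], out_v[key])

print(y.shape)  # (1, 1, 32)

Note that out_cache_state only carries the newly produced keys and values; the graph itself never mutates a cache tensor, which is what keeps the exported model friendly to static-shape NPUs.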

examples/models/llama/tests/TARGETS

+12
@@ -26,3 +26,15 @@ python_unittest(
         "//pytorch/ao:torchao",
     ],
 )
+
+python_unittest(
+    name = "test_static_attention",
+    srcs = [
+        "test_static_attention.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:static_attention",
+    ],
+)
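The test source test_static_attention.py is not shown on this page, so the following is only a hypothetical sketch of the kind of parity check such a test could run: copy weights from a stock AttentionMHA with load_weights_from_attention_mha and compare outputs on the same inputs. The AttentionMHA and Rope constructor signatures, the return-value unpacking, the causal-mask handling, and the tolerances are assumptions, not the actual test.

import unittest

import torch

from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention


class StaticAttentionParitySketch(unittest.TestCase):
    def test_matches_mha_without_cache(self):
        # Small, illustrative configuration; constructor signatures assumed.
        config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=8)
        rope = Rope(config)
        mha = AttentionMHA(config, layer_id=0, rope=rope).eval()
        static_attn = StaticAttention(config, layer_id=0, rope=rope).eval()
        static_attn.load_weights_from_attention_mha(mha)

        x = torch.rand(1, config.max_seq_len, config.dim)
        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
        expected, _ = mha(x, freqs_cos, freqs_sin)

        # StaticAttention takes an explicit additive mask; build a causal one,
        # assuming AttentionMHA applies causal masking internally.
        mask = torch.triu(
            torch.full((config.max_seq_len, config.max_seq_len), float("-inf")),
            diagonal=1,
        )
        actual, _ = static_attn(x, freqs_cos, freqs_sin, mask=mask)
        torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)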
