Commit c1ed48f

sxu authored and facebook-github-bot committed
Static attention implementation (pytorch#8310)
Summary: Merge the existing attention implementation for static-shaped NPUs, where valid data is kept at the end of the cache and the MHA -> SHA rewrite is done.

Reviewed By: iseeyuan

Differential Revision: D69080741
1 parent 77f18b2 commit c1ed48f

File tree

6 files changed: +448 -0 lines changed

examples/models/llama/TARGETS

Lines changed: 17 additions & 0 deletions
@@ -28,6 +28,23 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "static_attention",
+    srcs = [
+        "static_attention.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":llama_transformer",
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_library(
     name = "llama2_model",
     srcs = [

examples/models/llama/attention.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ class ForwardOptions(TypedDict, total=False):
 
     mask: Optional[torch.Tensor]
     input_pos: Optional[torch.Tensor]
+    freqs_cos_override: Optional[torch.Tensor]
+    freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
 
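The two new ForwardOptions fields let a caller substitute precomputed RoPE frequencies for the ones derived from input_pos; StaticAttention below reads them via kwargs.get("freqs_cos_override", freqs_cos). A minimal sketch of how such options might be assembled (the shapes and values here are illustrative assumptions, not something this commit prescribes):

import torch

from executorch.examples.models.llama.attention import ForwardOptions

# Hypothetical decode step: 1 new token, head_dim of 64, so 32 cos/sin pairs.
freqs_cos = torch.randn(1, 32)
freqs_sin = torch.randn(1, 32)

options: ForwardOptions = {
    "mask": torch.zeros(1, 1),          # nothing masked in this toy example
    "freqs_cos_override": freqs_cos,    # used instead of the positionally derived freqs
    "freqs_sin_override": freqs_sin,
    "in_cache_state": None,             # prefill-style call: no past keys/values yet
    "out_cache_state": None,
}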

examples/models/llama/llama_transformer.py

Lines changed: 2 additions & 0 deletions
@@ -205,6 +205,8 @@ def forward(
             attn_options.get("input_pos"), seqlen
         )
 
+        # Make a shallow copy so the updates don't get captured by export
+        attn_options = attn_options.copy()
         attn_options_update = None
         for layer in self.layers:
             h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
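The shallow copy guards the caller's kwargs during export: any entries a later step writes into attn_options land in the copy instead of mutating an input that torch.export would capture. A trivial illustration of that behaviour (the values are placeholders, not code from this commit):

caller_options = {"input_pos": 7}

attn_options = caller_options.copy()        # shallow copy, as in the diff above
attn_options["freqs_cos_override"] = "cos"  # hypothetical per-call update
attn_options["freqs_sin_override"] = "sin"

assert "freqs_cos_override" not in caller_options  # the caller's dict is untouched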

examples/models/llama/static_attention.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, Optional, Tuple
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from executorch.examples.models.llama.attention import (
8+
Attention,
9+
AttentionMHA,
10+
ForwardOptions,
11+
register_attention,
12+
)
13+
from executorch.examples.models.llama.model_args import ModelArgs
14+
from executorch.examples.models.llama.rope import Rope
15+
16+
17+
# Dict key has the form (layer ID, KV head ID).
18+
_CacheMap = Dict[Tuple[int, int], torch.Tensor]
19+
# Key and value caches are kept separate so the key caches can be kept transposed.
20+
_InputCacheState = Tuple[_CacheMap, _CacheMap]
21+
_OutputCacheState = Tuple[_CacheMap, _CacheMap]
22+
23+
24+
class StaticKVCache(nn.Module, ABC):
25+
def __init__(self, layer_id: int, head_id: int):
26+
super().__init__()
27+
self.layer_id = layer_id
28+
self.head_id = head_id
29+
30+
@abstractmethod
31+
def update(
32+
self,
33+
new_data: torch.Tensor,
34+
in_cache_state: Optional[_InputCacheState],
35+
out_cache_state: Optional[_OutputCacheState],
36+
) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
37+
"""
38+
Given input cache state and new keys/values, returns the combined keys/values
39+
and the updated the output cache state.
40+
"""
41+
pass
42+
43+
def cache_key(self) -> str:
44+
return self.calculate_cache_key(self.layer_id, self.head_id)
45+
46+
@staticmethod
47+
def calculate_cache_key(layer_id: int, head_id: int) -> str:
48+
return f"l{layer_id},h{head_id}"
49+
50+
@staticmethod
51+
def apply_update(cache, update, transpose=False):
52+
"""
53+
After inference, update the cache state for next iteration. The runtime needs to
54+
implement the same operation.
55+
"""
56+
if transpose:
57+
update_len = update.size(-1)
58+
updated = torch.roll(cache, -update_len, -1)
59+
updated[:, :, -update_len:] = update
60+
else:
61+
update_len = update.size(-2)
62+
updated = torch.roll(cache, -update_len, -2)
63+
updated[:, -update_len:, :] = update
64+
65+
return updated
66+
67+
68+
class StaticKCache(StaticKVCache):
69+
def __init__(self, layer_id: int, head_id: int, transpose=False):
70+
"""
71+
If transpose is True, key cache is kept in (batch, dim, seq_len), otherwise in
72+
(batch, seq_len, dim).
73+
"""
74+
super().__init__(layer_id, head_id)
75+
self.transpose = transpose
76+
77+
def update(
78+
self,
79+
new_data: torch.Tensor,
80+
in_cache_state: Optional[_InputCacheState],
81+
out_cache_state: Optional[_OutputCacheState],
82+
) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
83+
seq_dim = -2
84+
if self.transpose:
85+
seq_dim = -1
86+
new_data = new_data.transpose(-1, -2)
87+
if in_cache_state is None:
88+
return new_data, None
89+
if out_cache_state is None:
90+
out_cache_state = ({}, {})
91+
92+
all_data = torch.cat(
93+
[in_cache_state[0][self.cache_key()], new_data], dim=seq_dim
94+
)
95+
out_k_cache, out_v_cache = out_cache_state
96+
out_k_cache[self.cache_key()] = new_data
97+
return all_data, (out_k_cache, out_v_cache)
98+
99+
100+
class StaticVCache(StaticKVCache):
101+
def update(
102+
self,
103+
new_data: torch.Tensor,
104+
in_cache_state: Optional[_InputCacheState],
105+
out_cache_state: Optional[_OutputCacheState],
106+
) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]:
107+
if in_cache_state is None:
108+
return new_data, None
109+
if out_cache_state is None:
110+
out_cache_state = ({}, {})
111+
112+
all_data = torch.cat([in_cache_state[1][self.cache_key()], new_data], dim=-2)
113+
out_k_cache, out_v_cache = out_cache_state
114+
out_v_cache[self.cache_key()] = new_data
115+
return all_data, (out_k_cache, out_v_cache)
116+
117+
118+
def _apply_rotary_embedding(
119+
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
120+
) -> torch.Tensor:
121+
x_r, x_i = x[..., ::2], x[..., 1::2]
122+
x_out_r = x_r * freqs_cos - x_i * freqs_sin
123+
x_out_i = x_r * freqs_sin + x_i * freqs_cos
124+
125+
x_out = torch.cat([x_out_r, x_out_i], dim=-1)
126+
return x_out
127+
128+
129+
@register_attention("static")
130+
class StaticAttention(Attention):
131+
"""
132+
An attention implementation meant for NPUs that require static shapes and are not
133+
flexible with tensor operations needed to perform KV cache updates. MHA/GQA is
134+
implemented as multiple SHAs, and the KV caches keep valid data at the end so the
135+
model only needs to perform a concat to combine past and new data.
136+
"""
137+
138+
def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
139+
super().__init__()
140+
self.n_heads = config.n_heads
141+
self.n_kv_heads = (
142+
self.n_heads if config.n_kv_heads is None else config.n_kv_heads
143+
)
144+
assert self.n_heads % self.n_kv_heads == 0
145+
self.n_heads_per_kv_group = self.n_heads // self.n_kv_heads
146+
self.dim = config.dim
147+
self.head_dim = config.head_dim
148+
self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
149+
150+
self.wqs = nn.ModuleList(
151+
[
152+
nn.Linear(self.dim, self.head_dim, bias=False)
153+
for _ in range(self.n_heads)
154+
]
155+
)
156+
self.wks = nn.ModuleList(
157+
[
158+
nn.Linear(self.dim, self.head_dim, bias=False)
159+
for _ in range(self.n_kv_heads)
160+
]
161+
)
162+
self.wvs = nn.ModuleList(
163+
[
164+
nn.Linear(self.dim, self.head_dim, bias=False)
165+
for _ in range(self.n_kv_heads)
166+
]
167+
)
168+
169+
self.k_caches = nn.ModuleList(
170+
[StaticKCache(layer_id, i) for i in range(self.n_kv_heads)]
171+
)
172+
self.v_caches = nn.ModuleList(
173+
[StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
174+
)
175+
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
176+
177+
def forward(
178+
self,
179+
x: torch.Tensor,
180+
freqs_cos: torch.Tensor,
181+
freqs_sin: torch.Tensor,
182+
**kwargs: ForwardOptions,
183+
):
184+
mask = kwargs.get("mask")
185+
freqs_cos = kwargs.get("freqs_cos_override", freqs_cos)
186+
freqs_sin = kwargs.get("freqs_sin_override", freqs_sin)
187+
in_cache_state = kwargs.get("in_cache_state")
188+
out_cache_state = kwargs.get("out_cache_state")
189+
190+
new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
191+
new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
192+
new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
193+
new_qs = [_apply_rotary_embedding(q, freqs_cos, freqs_sin) for q in new_qs]
194+
new_ks = [_apply_rotary_embedding(k, freqs_cos, freqs_sin) for k in new_ks]
195+
196+
all_ks = []
197+
all_vs = []
198+
for i in range(self.n_kv_heads):
199+
ks, out_cache_state = self.k_caches[i].update(
200+
new_ks[i], in_cache_state, out_cache_state
201+
)
202+
all_ks.append(ks)
203+
vs, out_cache_state = self.v_caches[i].update(
204+
new_vs[i], in_cache_state, out_cache_state
205+
)
206+
all_vs.append(vs)
207+
208+
heads = []
209+
for i in range(self.n_heads):
210+
kv_idx = i // self.n_heads_per_kv_group
211+
attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1)
212+
attn = attn * self.inv_scale
213+
attn = attn + mask
214+
attn = F.softmax(attn, dim=-1)
215+
heads.append(attn @ all_vs[kv_idx])
216+
217+
y = torch.cat(heads, dim=-1)
218+
y = self.wo(y)
219+
return y, {"out_cache_state": out_cache_state}
220+
221+
def load_weights_from_attention_mha(self, other: AttentionMHA):
222+
for i in range(self.n_heads):
223+
self.wqs[i].weight.data.copy_(
224+
other.wq.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
225+
)
226+
227+
for i in range(self.n_kv_heads):
228+
self.wks[i].weight.data.copy_(
229+
other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
230+
)
231+
self.wvs[i].weight.data.copy_(
232+
other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :]
233+
)
234+
235+
self.wo.weight.data.copy_(other.wo.weight)
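
StaticKVCache.apply_update is the piece the docstring asks the runtime to mirror between inference steps: the fixed-size buffer is rolled so valid entries stay packed at the end, and the freshly produced rows overwrite the tail. A small sketch of that roll-and-overwrite behaviour (buffer sizes and fill values are arbitrary illustrations, not taken from the commit):

import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

# A fixed-size cache of length 8 for one head with head_dim 4, laid out
# (batch, seq_len, dim); the last two positions hold "valid" data (2.0),
# everything before them is stale padding.
cache = torch.zeros(1, 8, 4)
cache[:, -2:, :] = 2.0

# One decode step produces a single new row of keys or values.
update = torch.full((1, 1, 4), 3.0)

new_cache = StaticKVCache.apply_update(cache, update)

# The buffer keeps its static shape: everything shifts left by one position and
# the new row lands at the end, so valid data stays packed at the tail.
assert new_cache.shape == cache.shape
assert torch.all(new_cache[:, -1, :] == 3.0)
assert torch.all(new_cache[:, -3:-1, :] == 2.0)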

examples/models/llama/tests/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -26,3 +26,15 @@ python_unittest(
         "//pytorch/ao:torchao",
     ],
 )
+
+python_unittest(
+    name = "test_static_attention",
+    srcs = [
+        "test_static_attention.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:static_attention",
+    ],
+)
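
The target above points at test_static_attention.py, the sixth file in this commit, which is not reproduced on this page. A check in the spirit of that target might look like the sketch below, which feeds the same sequence through StaticAttention once in full and once token by token via the cache protocol; the ModelArgs/Rope constructor arguments, shapes, and tolerances are illustrative assumptions rather than the commit's actual test.

import unittest

import torch

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import (
    StaticAttention,
    StaticKVCache,
)


class StaticAttentionCacheTest(unittest.TestCase):
    def test_chunked_forward_matches_full_forward(self):
        # Illustrative config; not taken from the commit's actual test file.
        config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=8)
        layer_id = 0
        attn = StaticAttention(config, layer_id, Rope(config)).eval()

        seq_len = 8
        head_dim = config.head_dim or config.dim // config.n_heads
        x = torch.randn(1, seq_len, config.dim)
        # Placeholder RoPE frequencies; in the model they come from Rope.
        freqs_cos = torch.randn(seq_len, head_dim // 2)
        freqs_sin = torch.randn(seq_len, head_dim // 2)
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf")), diagonal=1
        )

        # Reference: the whole sequence in one pass, no cache.
        expected, _ = attn(x, freqs_cos, freqs_sin, mask=causal)

        # Same sequence, one token at a time, feeding K/V back through the cache.
        k_caches = {
            StaticKVCache.calculate_cache_key(layer_id, h): torch.zeros(1, 0, head_dim)
            for h in range(config.n_kv_heads)
        }
        v_caches = {k: torch.zeros(1, 0, head_dim) for k in k_caches}
        outputs = []
        for pos in range(seq_len):
            y, update = attn(
                x[:, pos : pos + 1, :],
                freqs_cos[pos : pos + 1],
                freqs_sin[pos : pos + 1],
                mask=torch.zeros(1, pos + 1),
                in_cache_state=(k_caches, v_caches),
                out_cache_state=None,
            )
            outputs.append(y)
            new_ks, new_vs = update["out_cache_state"]
            for key, new_k in new_ks.items():
                k_caches[key] = torch.cat([k_caches[key], new_k], dim=-2)
            for key, new_v in new_vs.items():
                v_caches[key] = torch.cat([v_caches[key], new_v], dim=-2)

        torch.testing.assert_close(
            torch.cat(outputs, dim=1), expected, atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":
    unittest.main()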
