
Commit 2fee7ce

Martin Yuan authored and facebook-github-bot committed
Enable SDPA without kv cache (#8950)
Summary: The SDPA custom op has been decoupled from the KV cache by kimishpatel. Update the llama model definition so that the SDPA op is applied both with and without a KV cache.

Reviewed By: kimishpatel

Differential Revision: D70593177
1 parent f6c5959 commit 2fee7ce
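For orientation, below is a minimal, self-contained sketch of the idea behind the change: the same attention kernel call serves both the uncached full-sequence pass and the cached single-token decode step. torch.nn.functional.scaled_dot_product_attention stands in here for the ExecuTorch custom SDPA op, and the tensor names and sizes are illustrative assumptions, not taken from the library.

import torch
import torch.nn.functional as F

bsz, n_heads, head_dim, max_ctx = 1, 8, 64, 16

# Causal mask of the form used in the llama example: -inf above the diagonal.
causal = torch.triu(torch.full((max_ctx, max_ctx), float("-inf")), diagonal=1)

def attend(q, k, v, mask):
    # One shared kernel call, independent of whether a KV cache is involved.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# Path 1: no KV cache -- attend over the whole sequence at once.
seqlen = 4
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)
out_no_cache = attend(q, k, v, causal[:seqlen, :seqlen])

# Path 2: KV cache -- a single new query attends over all cached positions.
cached_len = 5
q_step = torch.randn(bsz, n_heads, 1, head_dim)
k_cache = torch.randn(bsz, n_heads, cached_len, head_dim)
v_cache = torch.randn(bsz, n_heads, cached_len, head_dim)
out_cached = attend(q_step, k_cache, v_cache, causal[cached_len - 1 : cached_len, :cached_len])

print(out_no_cache.shape)  # torch.Size([1, 8, 4, 64])
print(out_cached.shape)    # torch.Size([1, 8, 1, 64])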

File tree: 2 files changed (+105 −25 lines)

examples/models/llama/attention.py (+17 −25)
@@ -133,7 +133,7 @@ def __init__(
 
     def forward(
         self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
         q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
         k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
         v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
@@ -218,13 +218,17 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        else:
+            # Use a constant state to avoid export error
+            self.zero_pos = torch.tensor([0])
+
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
 
     def forward(
         self,
@@ -258,20 +262,8 @@ def forward(
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        else:
+            mask = self.mask[:seqlen, :seqlen]
+            # No kv cache. Pass 0 input_pos
+            output = self.SDPA(self.zero_pos, q, k, v, bsz, seqlen, mask)
+        return self.wo(output), None
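The removed no-cache branch above expanded grouped-query keys and values by hand before calling torch's SDPA; that work is presumably handled inside the shared SDPA module now, which is constructed with n_rep=self.n_rep in the diff. A condensed, runnable equivalent of the removed lines, with toy sizes matching the test below and purely for illustration, is:

import torch
import torch.nn.functional as F

# Toy sizes: 8 query heads, 4 KV heads, head_dim 64.
bsz, seqlen, n_heads, n_kv_heads, head_dim = 2, 4, 8, 4, 64
n_rep = n_heads // n_kv_heads

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Grouped multiquery attention: expand keys and values to match the query
# heads, as the removed lines did with repeat_interleave.
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

# Causal mask limited to the current sequence length.
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
print(output.shape)  # torch.Size([2, 4, 512])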
New test file (+88 −0)

@@ -0,0 +1,88 @@
+import unittest
+
+import torch
+from executorch.examples.models.llama.attention import (
+    AttentionMHA,
+    KVCache,
+    ModelArgs,
+    Rope,
+    SDPA,
+)
+
+
+class TestAttentionMHA(unittest.TestCase):
+
+    def create_mock_args(self):
+        return ModelArgs(
+            use_kv_cache=True,
+            n_heads=8,
+            n_kv_heads=4,
+            head_dim=64,
+            max_batch_size=2,
+            max_context_len=16,
+            dim=512,
+            attention_qkv_bias=False,
+            enable_dynamic_shape=False,
+        )
+
+    def test_attentionmha_init(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        self.assertEqual(attn.n_heads, 8)
+        self.assertEqual(attn.n_kv_heads, 4)
+        self.assertEqual(attn.n_local_heads, 8)
+        self.assertEqual(attn.n_local_kv_heads, 4)
+        self.assertEqual(attn.head_dim, 64)
+        self.assertEqual(attn.dim, 512)
+        self.assertEqual(attn.mask.shape, (16, 16))  # Causal mask shape check
+        self.assertTrue(attn.use_kv_cache)
+
+        if attn.use_kv_cache:
+            self.assertIsInstance(attn.kv_cache, KVCache)
+        self.assertIsInstance(attn.SDPA, SDPA)
+
+    def test_attentionmha_forward(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+        input_pos = torch.tensor([0, 1, 2, 3])
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin, input_pos=input_pos)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_forward_no_kv_cache(self):
+        args = self.create_mock_args()
+        args.use_kv_cache = False  # Disable KV cache for this test
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_invalid_kv_cache(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        # No input_pos provided, should raise assertion error
+        with self.assertRaises(AssertionError):
+            attn.forward(x, freqs_cos, freqs_sin)
