
Commit 60bf405

Decouple custom ops in llama_transformer.py Part 1/N (#3005) (#3052)
Summary: This is a no-op.

Pull Request resolved: #3005

Test Plan: CI. Run with `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv --use_sdpa_with_kv_cache -X` and with `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv -X`. Make sure both work.

Reviewed By: cccclai

Differential Revision: D56048177

Pulled By: mergennachin

fbshipit-source-id: 3ac9ac5c34f6fe215de1cfe8b5ddc7aae3635359
(cherry picked from commit 488afc5)

Co-authored-by: Mergen Nachin <[email protected]>
1 parent 638433f commit 60bf405

File tree: 4 files changed, +104 -56 lines

examples/models/llama2/builder.py

Lines changed: 1 addition & 5 deletions
@@ -202,11 +202,7 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
         else:
             return ({1: dim},)
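For readers unfamiliar with the spec being returned here: `({1: dim},)` tells `torch.export` that dimension 1 of the first positional input (the token/sequence dimension) may vary up to `max_seq_len - 1`, while returning `None` keeps every shape static, which is why the KV-cache path above now always returns `None`. A minimal, self-contained sketch of the same mechanism, using a toy module and made-up sizes rather than anything from this repo:

```python
import torch
import torch.nn as nn


class TinyEmbedder(nn.Module):
    """Toy stand-in for the llama model, only here to illustrate dynamic shapes."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(32000, 64)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.emb(tokens)


model = TinyEmbedder()
example_tokens = torch.randint(0, 32000, (1, 8))  # (batch, seqlen)

# Dimension 1 (the token dimension) may range up to a chosen max,
# mirroring the `({1: dim},)` spec built in _get_dynamic_shape().
token_dim = torch.export.Dim("token_dim", max=127)
exported = torch.export.export(
    model, (example_tokens,), dynamic_shapes=({1: token_dim},)
)
print(exported)
```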

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 0 deletions
@@ -482,6 +482,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
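The `pass` above only reserves the spot where a source transform will be appended in a follow-up diff. Purely as a hedged illustration of what such a transform tends to look like (none of these names exist in this commit; `replace_sdpa_with_custom_op`, `EagerSDPA`, and `CustomOpSDPA` are hypothetical), a module-swapping transform could be written as:

```python
import torch.nn as nn


class EagerSDPA(nn.Module):
    """Toy stand-in for the eager SDPA module introduced in llama_transformer.py."""

    def forward(self, *args):
        raise NotImplementedError


class CustomOpSDPA(nn.Module):
    """Toy stand-in for a custom-op-backed SDPA variant (not part of this commit)."""

    def forward(self, *args):
        raise NotImplementedError


def replace_sdpa_with_custom_op(module: nn.Module) -> nn.Module:
    """Hypothetical source transform: recursively swap every eager SDPA submodule
    for the custom-op variant, leaving the rest of the model untouched."""
    for name, child in module.named_children():
        if isinstance(child, EagerSDPA):
            setattr(module, name, CustomOpSDPA())
        else:
            replace_sdpa_with_custom_op(child)
    return module


# Hooking it in would then mirror the existing transform pattern:
#     if args.use_sdpa_with_kv_cache:
#         transforms.append(replace_sdpa_with_custom_op)
```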

examples/models/llama2/llama_transformer.py

Lines changed: 98 additions & 36 deletions
@@ -193,6 +193,95 @@ def update(
         return k_out, v_out
 
 
+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -213,7 +302,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
         causal_mask = torch.tril(
@@ -234,6 +322,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )
 
     def forward(
         self,
@@ -256,41 +351,8 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
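The net effect of this diff is that `Attention.forward` no longer branches on the backend: the whole SDPA strategy sits behind one `nn.Module` attribute, so a later diff (or a source transform) can swap implementations without touching the attention plumbing. A small sketch of why that interface is convenient, using a simplified `(q, k, v)` signature and toy shapes (the real `SDPA` module above also threads `input_pos`, `bsz`, and `seqlen` and owns the KV cache):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveSDPA(nn.Module):
    """Toy drop-in attention backend: same math as F.scaled_dot_product_attention
    with no mask and no dropout, written out explicitly."""

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        return F.softmax(scores, dim=-1) @ v


# (bs, n_heads, seqlen, head_dim)
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
reference = F.scaled_dot_product_attention(q, k, v)
candidate = NaiveSDPA()(q, k, v)

# Because the backend is an nn.Module with a fixed signature, swapping it is a
# one-line attribute replacement instead of an if/else inside Attention.forward.
torch.testing.assert_close(candidate, reference, rtol=1e-4, atol=1e-5)
```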

examples/models/llama2/model.py

Lines changed: 1 addition & 15 deletions
@@ -173,11 +173,7 @@ def get_eager_model(self):
 
     def get_example_inputs(self):
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
                 torch.tensor(
@@ -195,13 +191,3 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.)
         )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ),  # start_pos, what token of output are we on.
-        )
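With the dead branch and the unused helper gone, every KV-cache export path now goes through `get_example_inputs_kvcache_sdpa()`. As a reminder of the shape contract those example inputs encode (the token tensor below is inferred from the helper's intent, since only its `start_pos` half is visible in this hunk):

```python
import torch

# With a KV cache, each decode step feeds exactly one new token, so the export
# example inputs are a (1, 1) token id tensor plus a start position tensor.
example_inputs = (
    torch.tensor([[1]], dtype=torch.long),  # tokens: batch of 1, seqlen of 1
    torch.tensor([0], dtype=torch.long),    # start_pos: which output token we are on
)
tokens, start_pos = example_inputs
assert tokens.shape == (1, 1) and start_pos.shape == (1,)
```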
