
Commit 1ffbc72

mergennachin authored and cccclai committed
Decouple custom ops in llama_transformer.py Part 2/N (pytorch#3007)
Summary:
Pull Request resolved: pytorch#3007

Keep llama_transformer.py looking like the stock implementation so that it can be reused everywhere. Instead of branching inside the model, do a module swap at export time.

Reviewed By: cccclai

Differential Revision: D56048640

fbshipit-source-id: 76de1b09b7f5d79422bb3b32bc830a9a7ecd935c
(cherry picked from commit 74eb8b3)
1 parent 60bf405 commit 1ffbc72
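
The "module swap" mentioned in the summary is the usual torch.nn idiom: recursively walk named_children() and reassign matching submodules with setattr, leaving the model definition untouched. A minimal, self-contained sketch of that idiom using toy stand-in modules (Stock, Custom, and Block below are illustrative names, not part of this commit):

import torch


class Stock(torch.nn.Module):
    """Stand-in for the stock submodule (SDPA in this commit)."""

    def forward(self, x):
        return torch.relu(x)


class Custom(torch.nn.Module):
    """Stand-in for the swapped-in replacement (SDPACustom in this commit)."""

    def forward(self, x):
        # Same semantics; in the real case this would call the custom op.
        return torch.relu(x)


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Stock()
        self.proj = torch.nn.Linear(4, 4)


def swap(module: torch.nn.Module) -> torch.nn.Module:
    # Recursively replace every Stock child in place with a Custom instance.
    for name, child in module.named_children():
        if isinstance(child, Stock):
            setattr(module, name, Custom())
        else:
            swap(child)
    return module


model = swap(Block())
assert isinstance(model.attn, Custom)  # the swap happened in place

The diff below applies the same pattern with SDPA and SDPACustom.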

File tree

3 files changed, +59 -57 lines

examples/models/llama2/TARGETS
examples/models/llama2/export_llama_lib.py
examples/models/llama2/llama_transformer.py


examples/models/llama2/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
     ],
 )

@@ -85,6 +84,7 @@ runtime.python_library(
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
         "//executorch/examples/portable:utils",
         "//executorch/exir:lib",
         "//executorch/sdk/etrecord:etrecord",

examples/models/llama2/export_llama_lib.py

Lines changed: 58 additions & 3 deletions
@@ -23,7 +23,11 @@
     XnnpackDynamicallyQuantizedPartitioner,
 )

-from executorch.examples.models.llama2.llama_transformer import Transformer
+from executorch.examples.models.llama2.llama_transformer import (
+    KVCache,
+    SDPA,
+    Transformer,
+)
 from executorch.exir.backend.backend_details import CompileSpec

 from executorch.sdk.etrecord import generate_etrecord
@@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
     return module


+class SDPACustom(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(child.kv_cache, child.mask, child.dim),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa
+
+    _replace_sdpa_with_custom_op(module)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -483,8 +539,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

     if args.use_sdpa_with_kv_cache:
-        pass
-        # TODO: Next diff transforms.append()
+        transforms.append(replace_sdpa_with_custom_op)

     return (
         load_llama_model(
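
The last hunk above registers replace_sdpa_with_custom_op in a list of module-to-module transforms that the export flow applies in order. A small runnable sketch of that transform-pipeline shape, with toy transforms and an illustrative apply_transforms helper (the real application happens inside load_llama_model / LlamaEdgeManager, whose signatures are not shown in this diff):

from typing import Callable, List

import torch


# Toy transforms: each takes a module and returns a (possibly modified) module,
# mirroring how replace_sdpa_with_custom_op is appended to `transforms`.
def add_dropout(module: torch.nn.Module) -> torch.nn.Module:
    module.dropout = torch.nn.Dropout(p=0.1)
    return module


def freeze_params(module: torch.nn.Module) -> torch.nn.Module:
    for p in module.parameters():
        p.requires_grad = False
    return module


def apply_transforms(
    module: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Illustrative helper (not the actual export API): apply each registered
    # transform in order.
    for t in transforms:
        module = t(module)
    return module


model = torch.nn.Linear(8, 8)
model = apply_transforms(model, [add_dropout, freeze_params])

Each entry takes a torch.nn.Module and returns one, which is why replace_sdpa_with_custom_op returns the module after mutating it in place.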

examples/models/llama2/llama_transformer.py

Lines changed: 0 additions & 53 deletions
@@ -198,14 +198,12 @@ def __init__(
         self,
         kv_cache: KVCache,
         mask,
-        use_sdpa_with_kv_cache_op: bool,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.mask = mask
-        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.dim = dim
         self.n_rep = n_rep

@@ -217,56 +215,6 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
-    ) -> torch.Tensor:
-        if not self.use_sdpa_with_kv_cache_op:
-            return self._forward_default(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-        else:
-            return self._forward_custom(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-
-    def _forward_custom(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ):
-        from .custom_ops import sdpa_with_kv_cache  # noqa
-
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
-            input_pos[-1].item(),
-            seqlen,
-        )
-        return output.view(bsz, seqlen, self.dim)
-
-    def _forward_default(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -325,7 +273,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.SDPA = SDPA(
             self.kv_cache,
             self.mask,
-            args.use_sdpa_with_kv_cache_op,
             self.dim,
             self.n_rep,
         )