Commit 337d73d
Add Buck rules for the Core ML llama transformer example
Differential Revision: D70415647
Pull Request resolved: #9017
Parent: 1145892

5 files changed: +83 −14 lines

examples/apple/coreml/llama/TARGETS

+66 (new file)
@@ -0,0 +1,66 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+runtime.python_library(
+    name = "llama_transformer",
+    srcs = [
+        "llama_transformer.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.apple.coreml.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
+
+runtime.python_library(
+    name = "utils",
+    srcs = [
+        "utils.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.apple.coreml.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_binary(
+    name = "export",
+    srcs = [
+        "export.py",
+    ],
+    main_function = "executorch.examples.apple.coreml.llama.export.main",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
+        ":llama_transformer",
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/backends/apple/coreml:backend",
+        "//executorch/backends/apple/coreml:partitioner",
+        "//executorch/examples/models/llama:source_transformation",
+        "//executorch/exir/backend:utils",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+        "//executorch/exir/passes:quant_fusion_pass",
+        "//executorch/exir/passes:sym_shape_eval_pass",
+        "//executorch/exir/program:program",
+        "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/export:export_lib",
+    ],
+)
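
With these rules in place, the export entry point can be run through Buck instead of invoking the script from its own directory. A minimal usage sketch, assuming a standard Buck2 setup (export.py's actual flags come from its argparse configuration; --help lists them):

    buck2 run //executorch/examples/apple/coreml/llama:export -- --help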

examples/apple/coreml/llama/export.py

+8 −6
@@ -6,12 +6,18 @@
 
 import argparse
 
-import sys
-
 import coremltools as ct
 import torch
 from executorch.backends.apple.coreml.compiler import CoreMLBackend  # pyre-ignore
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # pyre-ignore
+
+from executorch.examples.apple.coreml.llama.llama_transformer import (
+    InputManager,
+    load_model,
+)
+from executorch.examples.apple.coreml.llama.utils import (
+    replace_linear_with_split_linear,
+)
 from executorch.examples.models.llama.source_transformation.quantize import (
     EmbeddingQuantHandler,
 )
@@ -24,10 +30,6 @@
 from executorch.exir.program._program import to_edge_with_preserved_ops
 from executorch.extension.export_util.utils import save_pte_program
 
-sys.path.insert(0, ".")
-from llama_transformer import InputManager, load_model
-from utils import replace_linear_with_split_linear
-
 
 def main() -> None:
     parser = argparse.ArgumentParser()
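
The removed sys.path.insert(0, ".") hack only resolved llama_transformer and utils when the script was launched from inside examples/apple/coreml/llama; the package-qualified imports resolve from any working directory and match the base_module set in the new TARGETS file. A small sanity-check sketch, assuming the executorch package (with the pyproject.toml mapping below) is installed:

    # These now resolve regardless of the current working directory.
    from executorch.examples.apple.coreml.llama.llama_transformer import (
        InputManager,
        load_model,
    )
    from executorch.examples.apple.coreml.llama.utils import (
        replace_linear_with_split_linear,
    )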

examples/apple/coreml/llama/llama_transformer.py

+2 −2
@@ -443,7 +443,7 @@ def forward(
         if not self.use_cache_list:
             k_out = torch.stack(k_out, dim=0)
             v_out = torch.stack(v_out, dim=0)
-        return logits, k_out, v_out
+        return logits, k_out, v_out  # pyre-ignore[7]
 
 
 def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
@@ -614,7 +614,7 @@ def get_inputs(self, tokens: List[int]):
                 torch.tensor(tokens, dtype=torch.int64),
                 torch.zeros(self.seq_length - input_length, dtype=torch.int64),
             ],
-            axis=-1,
+            dim=-1,
         ).reshape(1, -1),
         # input_pos
         torch.tensor([self.input_pos], dtype=torch.long),
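
dim is the canonical keyword for torch.cat; axis works in eager mode as a NumPy-compatibility alias, but the canonical spelling is the safer choice for code that gets traced and exported. A standalone sketch of the padding pattern this call implements, with made-up token ids and sequence length:

    import torch

    tokens = [101, 42, 7]  # hypothetical prompt token ids
    seq_length = 8         # hypothetical fixed sequence length

    # Right-pad the prompt with zeros to seq_length, then add a batch dimension.
    padded = torch.cat(
        [
            torch.tensor(tokens, dtype=torch.int64),
            torch.zeros(seq_length - len(tokens), dtype=torch.int64),
        ],
        dim=-1,
    ).reshape(1, -1)

    print(padded.shape)  # torch.Size([1, 8]): 3 real tokens followed by 5 zeros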

examples/apple/coreml/llama/run.py

+6 −6
@@ -5,19 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-import sys
 
 import sentencepiece as spm
 
 import torch
+from executorch.examples.apple.coreml.llama.llama_transformer import (
+    InputManager,
+    load_model,
+)
 
-from executorch.runtime import Runtime
-
-
-sys.path.insert(0, ".")
 from executorch.examples.models.llama.runner.generation import next_token
 from executorch.examples.models.llama.tokenizer import tiktoken
-from llama_transformer import InputManager, load_model
+
+from executorch.runtime import Runtime
 
 
 class Tokenizer:

pyproject.toml

+1
@@ -92,6 +92,7 @@ flatc = "executorch.data.bin:flatc"
 # TODO(mnachin T180504136): Do not put examples/models
 # into core pip packages. Refactor out the necessary utils
 # or core models files into a separate package.
+"executorch.examples.apple.coreml.llama" = "examples/apple/coreml/llama"
 "executorch.examples.models" = "examples/models"
 "executorch.exir" = "exir"
 "executorch.extension" = "extension"
