Skip to content

Commit 88886b8

Browse files
authored
Refactor and add Llama Python library build (#8107)
* Refactor and add Llama Python library build (#8107) Summary: As title. To use static llama export outside QC dir. Reviewed By: cccclai Differential Revision: D68937637 * Address lint issue
1 parent 1d43d91 commit 88886b8

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ python_library(
1515
],
1616
)
1717

18+
python_library(
19+
name = "llama_lib",
20+
srcs = ["llama.py"],
21+
deps = [
22+
"//caffe2:torch",
23+
"//executorch/backends/qualcomm/partition:partition",
24+
"//executorch/backends/qualcomm/quantizer:quantizer",
25+
"//executorch/devtools:lib",
26+
"//executorch/examples/models:models",
27+
"//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
28+
"//executorch/examples/qualcomm:utils",
29+
"//executorch/extension/export_util:export_util",
30+
"//executorch/extension/llm/custom_ops:model_sharding_py",
31+
"//executorch/extension/llm/export:export_lib",
32+
"//executorch/extension/pybindings:aten_lib",
33+
],
34+
)
35+
1836
python_binary(
1937
name = "llama",
2038
srcs = ["llama.py"],

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ def post_process():
847847
logging.info(f"Results[{idx}]:\n{output}")
848848

849849

850-
def main():
850+
def _build_parser():
851851
parser = setup_common_args_and_variables()
852852
parser.add_argument(
853853
"-a",
@@ -980,7 +980,13 @@ def main():
980980
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
981981
)
982982

983-
args = parser.parse_args()
983+
return parser
984+
985+
986+
def main(args) -> None:
987+
parser = _build_parser()
988+
989+
args = parser.parse_args(args)
984990
if args.compile_only and args.pre_gen_pte:
985991
exit("Cannot set both compile_only and pre_gen_pte as true")
986992

@@ -1071,4 +1077,4 @@ def main():
10711077

10721078
# flake8: noqa: C901
10731079
if __name__ == "__main__":
1074-
main()
1080+
main(sys.argv[1:])

0 commit comments

Comments
 (0)