diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 810b7f550df..489d42c29c4 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -96,6 +96,9 @@ runtime.command_alias( runtime.python_library( name = "source_transformation", + visibility = [ + "//executorch/examples/...", + ], srcs = [ "source_transformation/apply_spin_quant_r1_r2.py", "source_transformation/attention.py", diff --git a/examples/qualcomm/oss_scripts/llama/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS index d49253c5668..e4bad10a234 100644 --- a/examples/qualcomm/oss_scripts/llama/TARGETS +++ b/examples/qualcomm/oss_scripts/llama/TARGETS @@ -19,6 +19,7 @@ python_library( name = "llama_lib", srcs = ["llama.py"], deps = [ + "//executorch/examples/models/llama:source_transformation", "//caffe2:torch", "//executorch/backends/qualcomm/partition:partition", "//executorch/backends/qualcomm/quantizer:quantizer", diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 48353d3ee6b..e853812a949 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -1039,10 +1039,7 @@ def _build_parser(): return parser -def main(args) -> None: - parser = _build_parser() - - args = parser.parse_args(args) +def export_llama(args) -> None: if args.compile_only and args.pre_gen_pte: exit("Cannot set both compile_only and pre_gen_pte as true") @@ -1143,6 +1140,12 @@ def main(args) -> None: raise Exception(e) +def main(): + parser = _build_parser() + args = parser.parse_args() + export_llama(args) + + # flake8: noqa: C901 if __name__ == "__main__": - main(sys.argv[1:]) + main()