From 35387f6de7c731e8d3f52ce504c2abd912c6f096 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Thu, 3 Oct 2024 14:50:55 -0700
Subject: [PATCH] prefill model (#5807)

Summary:
python -m executorch.examples.models.llama2.export_llama --disable_dynamic_shape --qnn --pt2e_quantize qnn_16a4w

Segfault error stacktrace
```
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[WARNING] [Qnn ExecuTorch]: Qnn API version 2.19.0 is used. The version is tested against 2.18.0.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
AddressSanitizer:DEADLYSIGNAL
=================================================================
==1523599==ERROR: AddressSanitizer: SEGV on unknown address 0x000000000020 (pc 0x7f1585ee38e2 bp 0x7f16d5ab8800 sp 0x7ffed19ab8b0 T0)
==1523599==The signal is caused by a READ memory access.
==1523599==Hint: address points to the zero page.
SCARINESS: 10 (null-deref)
    #0 0x7f1585ee38e2 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2ce38e2) (BuildId: bc3ab8ddc89a0e65)
    #1 0x7f1585dd8926 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2bd8926) (BuildId: bc3ab8ddc89a0e65)
    #2 0x7f15844d1161 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12d1161) (BuildId: bc3ab8ddc89a0e65)
    #3 0x7f15844dcac6 (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12dcac6) (BuildId: bc3ab8ddc89a0e65)
    #4 0x7f15844d245b (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x12d245b) (BuildId: bc3ab8ddc89a0e65)
    #5 0x7f15b9bc7b21 in auto torch::executor::qnn::QnnInterface::qnn_backend_validate_op_config(void*, Qnn_OpConfig_t) const fbcode/executorch/backends/qualcomm/runtime/backends/QnnFunctionInterface.h:39
    #6 0x7f15b9bc7682 in torch::executor::qnn::QnnBackend::BackendValidateOpConfig(Qnn_OpConfig_t const&) fbcode/executorch/backends/qualcomm/runtime/backends/QnnBackendCommon.h:41
    #7 0x7f15b9bc7115 in torch::executor::qnn::QnnManager::IsNodeSupportedByBackend(std::vector, std::allocator>>&) fbcode/executorch/backends/qualcomm/runtime/QnnManager.cpp:450
    #8 0x7f15b9dd44ee in torch::executor::qnn::PyQnnManager::IsNodeSupportedByBackend(std::vector, std::allocator>>&) fbcode/executorch/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h:57
    #9 0x7f15b9e5b986 in pybind11::cpp_function::cpp_function, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)::operator()(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&) const fbsource/pybind11/pybind11.h:84
    #10 0x7f15b9e5b8b5 in bool pybind11::detail::argument_loader, std::allocator>>&>::call_impl, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&, 0ul, 1ul, pybind11::detail::void_type>(torch::executor::qnn::PyQnnManager&&, std::integer_sequence, pybind11::detail::void_type&&) && fbsource/pybind11/cast.h:2042
    #11 0x7f15b9e53831 in
std::enable_if::value, bool>::type pybind11::detail::argument_loader, std::allocator>>&>::call, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&>(pybind11::cpp_function::cpp_function, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&)&) && fbsource/pybind11/cast.h:2014 #12 0x7f15b9e53454 in void pybind11::cpp_function::initialize, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&), bool, torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool&&, torch::executor::qnn::PyQnnManager (*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const fbsource/pybind11/pybind11.h:193 #13 0x7f15b9e530d3 in void pybind11::cpp_function::initialize, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool (torch::executor::qnn::PyQnnManager::*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&), bool, torch::executor::qnn::PyQnnManager*, std::vector, std::allocator>>&, pybind11::name, pybind11::is_method, pybind11::sibling>(bool&&, torch::executor::qnn::PyQnnManager (*)(std::vector, std::allocator>>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) fbsource/pybind11/pybind11.h:170 #14 0x7f15b9d8f707 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) fbsource/pybind11/pybind11.h:767 #15 0x327141 in cfunction_call(_object*, _object*, _object*) (.__uniq.281047882695835599676768160755749362799) (/usr/local/fbcode/platform010/bin/python3.10+0x327141) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #16 0x349630 in _PyObject_MakeTpCall (/usr/local/fbcode/platform010/bin/python3.10+0x349630) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #17 0x5897d4 in method_vectorcall(_object*, _object* const*, unsigned long, _object*) (.__uniq.243338978568352371442406765225626566013.llvm.6236606370933165261) (/usr/local/fbcode/platform010/bin/python3.10+0x5897d4) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #18 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #19 0x331421 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331421) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #20 0x327547 in _PyFunction_Vectorcall 
(/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #21 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #22 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #23 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #24 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #25 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #26 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #27 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #28 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #29 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #30 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #31 0x331577 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331577) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #32 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #33 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #34 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #35 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #36 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #37 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #38 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #39 0x39ad7d in _PyObject_FastCallDictTstate (/usr/local/fbcode/platform010/bin/python3.10+0x39ad7d) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #40 0x3c8b72 in slot_tp_call(_object*, _object*, _object*) (.__uniq.235726554139783955843240177532338160225) 
(/usr/local/fbcode/platform010/bin/python3.10+0x3c8b72) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #41 0x392ca8 in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x392ca8) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #42 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #43 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #44 0x331b18 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x331b18) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #45 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #46 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #47 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #48 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #49 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #50 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #51 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #52 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #53 0x3313f2 in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3313f2) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #54 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #55 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #56 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #57 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #58 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #59 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #60 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: 
a620038add613fd8585eb50983ca8e455d54738e) #61 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #62 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #63 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #64 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #65 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #66 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #67 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #68 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #69 0x327547 in _PyFunction_Vectorcall (/usr/local/fbcode/platform010/bin/python3.10+0x327547) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #70 0x3928df in call_function(_ts*, PyTraceInfo*, _object***, long, _object*) (.__uniq.79849310599369217189729546442812793949) (/usr/local/fbcode/platform010/bin/python3.10+0x3928df) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #71 0x3314ca in _PyEval_EvalFrameDefault (/usr/local/fbcode/platform010/bin/python3.10+0x3314ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #72 0x39b8ca in _PyEval_Vector (/usr/local/fbcode/platform010/bin/python3.10+0x39b8ca) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #73 0x431565 in PyEval_EvalCode (/usr/local/fbcode/platform010/bin/python3.10+0x431565) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #74 0x431447 in run_mod(_mod*, _object*, _object*, _object*, PyCompilerFlags*, _arena*) (.__uniq.251861886623903963524397139660542440724.llvm.17622910512627074885) (/usr/local/fbcode/platform010/bin/python3.10+0x431447) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #75 0x4e3054 in pyrun_file(_IO_FILE*, _object*, int, _object*, _object*, int, PyCompilerFlags*) (.__uniq.251861886623903963524397139660542440724) (/usr/local/fbcode/platform010/bin/python3.10+0x4e3054) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #76 0x4e2b54 in _PyRun_SimpleFileObject (/usr/local/fbcode/platform010/bin/python3.10+0x4e2b54) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #77 0x4e28f1 in _PyRun_AnyFileObject (/usr/local/fbcode/platform010/bin/python3.10+0x4e28f1) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #78 0x4d4a54 in Py_RunMain (/usr/local/fbcode/platform010/bin/python3.10+0x4d4a54) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #79 0x4d286b in pymain_main(_PyArgv*) (.__uniq.297908980262787110426434251325078884054) (/usr/local/fbcode/platform010/bin/python3.10+0x4d286b) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) #80 0x4d2759 in Py_BytesMain (/usr/local/fbcode/platform010/bin/python3.10+0x4d2759) (BuildId: 
a620038add613fd8585eb50983ca8e455d54738e) #81 0x7f19e282c656 in __libc_start_call_main (/usr/local/fbcode/platform010/lib/libc.so.6+0x2c656) (BuildId: 93cdceeb8322234c38e1f2c93ad0ff10c7632fa6) #82 0x7f19e282c717 in __libc_start_main@GLIBC_2.2.5 (/usr/local/fbcode/platform010/lib/libc.so.6+0x2c717) (BuildId: 93cdceeb8322234c38e1f2c93ad0ff10c7632fa6) #83 0x553d90 in _start (/usr/local/fbcode/platform010/bin/python3.10+0x553d90) (BuildId: a620038add613fd8585eb50983ca8e455d54738e) AddressSanitizer can not provide additional info. AddressSanitizer: SEGV (/home/chenlai/fbsource/third-party/qualcomm/qnn/qnn-2.26/lib/x86_64-linux-clang/libQnnHtp.so+0x2ce38e2) (BuildId: bc3ab8ddc89a0e65) ==1523599==ABORTING ``` Differential Revision: D63736779 --- .../qualcomm/quantizer/custom_annotation.py | 26 + .../serialization/qnn_compile_spec_schema.py | 2 + backends/qualcomm/serialization/schema.fbs | 3 +- backends/qualcomm/tests/utils.py | 1 + examples/models/llama2/export_llama_lib.py | 55 +- examples/models/llama2/llama_transformer.py | 564 +++++------------- examples/models/llama2/main.cpp | 515 ++++++++++++++-- examples/models/llama2/model.py | 70 ++- .../models/llama2/params/demo_config.json | 2 +- examples/models/llama2/runner/targets.bzl | 1 + examples/models/model_factory.py | 8 +- examples/qualcomm/utils.py | 1 + extension/export_util/utils.py | 6 +- extension/llm/export/builder.py | 3 + extension/llm/export/partitioner_lib.py | 9 +- extension/llm/export/quantizer_lib.py | 14 +- 16 files changed, 758 insertions(+), 522 deletions(-) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 9cde50b9c70..881d24bbb5e 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -118,3 +118,29 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): if "SDPA" in full_qualified_name: annotate_matmul(node, quantization_config_16a8w) annotate_matmul_input1(node.args[1], quantization_config_8a8w) + + +def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule): + """ + Annotate matmul op with 16a8w quantization config + """ + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + # Annotate 16a8w for matmul op to get better performance + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + annotate_matmul(node, quantization_config_16a8w) diff --git a/backends/qualcomm/serialization/qnn_compile_spec_schema.py b/backends/qualcomm/serialization/qnn_compile_spec_schema.py index 8471aad982d..c376dd6e476 100644 --- a/backends/qualcomm/serialization/qnn_compile_spec_schema.py +++ b/backends/qualcomm/serialization/qnn_compile_spec_schema.py @@ -34,6 +34,7 @@ class QcomChipset(IntEnum): SM8475 = 42 # v69 SM8550 = 43 # v73 SM8650 = 57 # v75 + SSG2115P = 46 # v73... I wish I can know where the number comes from... 
@dataclass @@ -47,6 +48,7 @@ class SocInfo: QcomChipset.SM8475: SocInfo(QcomChipset.SM8475, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8550: SocInfo(QcomChipset.SM8550, HtpInfo(HtpArch.V73, 8)), QcomChipset.SM8650: SocInfo(QcomChipset.SM8650, HtpInfo(HtpArch.V75, 8)), + QcomChipset.SSG2115P: SocInfo(QcomChipset.SSG2115P, HtpInfo(HtpArch.V73, 2)), } diff --git a/backends/qualcomm/serialization/schema.fbs b/backends/qualcomm/serialization/schema.fbs index 4e7fdb56e89..f2275377f7b 100644 --- a/backends/qualcomm/serialization/schema.fbs +++ b/backends/qualcomm/serialization/schema.fbs @@ -32,6 +32,7 @@ enum QcomChipset: int { SM8450 = 36, SM8475 = 42, SM8550 = 43, + SSG2115P = 46, SM8650 = 57, } @@ -170,7 +171,7 @@ table QnnExecuTorchOptions { /// Profiling level of the delegate and the backend. Default is off. profile_level:QnnExecuTorchProfileLevel; - + /// Enables usage of shared buffer between application and backend for graph I/O. shared_buffer:bool; diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 7209b0a2678..52ffac46eee 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -118,6 +118,7 @@ class TestQNN(unittest.TestCase): model: QcomChipset = None compiler_specs: List[CompileSpec] = None arch_table = { + "SSG2115P": QcomChipset.SSG2115P, "SM8650": QcomChipset.SM8650, "SM8550": QcomChipset.SM8550, "SM8475": QcomChipset.SM8475, diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index cf8d221c8e5..fff29185413 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -53,21 +53,23 @@ get_quant_embedding_transform, get_quant_weight_transform, ) -from .source_transformation.quantized_kv_cache import ( - replace_kv_cache_with_quantized_kv_cache, -) + +# from .source_transformation.quantized_kv_cache import ( +# replace_kv_cache_with_quantized_kv_cache, +# ) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis -from .source_transformation.sdpa import ( - replace_causal_mask, - replace_kv_cache_with_coreml_kv_cache, - replace_kv_cache_with_simple_kv_cache, - replace_sdpa_with_coreml_sdpa, - replace_sdpa_with_custom_op, - replace_sdpa_with_flex_sdpa, - replace_sdpa_with_simple_sdpa, -) + +# from .source_transformation.sdpa import ( +# replace_causal_mask, +# replace_kv_cache_with_coreml_kv_cache, +# replace_kv_cache_with_simple_kv_cache, +# replace_sdpa_with_coreml_sdpa, +# replace_sdpa_with_custom_op, +# replace_sdpa_with_flex_sdpa, +# replace_sdpa_with_simple_sdpa, +# ) IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -910,23 +912,20 @@ def _get_source_transforms( # noqa assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" transforms.append(replace_kv_cache_with_quantized_kv_cache) + if args.qnn: + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d + + # transforms.append(replace_kv_cache_with_simple_kv_cache) + # transforms.append(replace_sdpa_with_flex_sdpa) + # transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + 
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) + transforms.append(convert_linear_to_conv2d) if args.use_kv_cache: - if args.qnn: - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` - from executorch.backends.qualcomm.utils.utils import ( - convert_linear_to_conv2d, - ) - - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - transforms.append(convert_linear_to_conv2d) - - elif args.mps: + if args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition # to get free perf gain. transforms.append(replace_sdpa_with_simple_sdpa) diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 8e17013ae3d..b2453488785 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -1,27 +1,17 @@ -# @lint-ignore-every LICENSELINT -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# Llama 2 is licensed under the LLAMA 2 Community License, -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. - -# Please refer to README.md in the same folder for more information. +#!/usr/bin/env python3 +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import logging +import math from dataclasses import dataclass -from functools import partial -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch -import torch.nn.functional as F -from executorch.examples.models.llama2.rope import ( - apply_rotary_emb, - hf_apply_rotary_emb, - hf_precompute_freqs_cis, - precompute_freqs_cis, -) +from torch.nn import functional as F + -from torch import nn +logger: logging.Logger = logging.getLogger() class RMSNorm(torch.nn.Module): @@ -39,9 +29,8 @@ def __init__(self, dim: int, eps: float = 1e-6): """ super().__init__() - self.dim = dim self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) + self.weight = torch.nn.Parameter(torch.ones(dim)) def _norm(self, x): """ @@ -54,7 +43,7 @@ def _norm(self, x): torch.Tensor: The normalized tensor. 
""" - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ @@ -71,12 +60,6 @@ def forward(self, x): return output * self.weight -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - @dataclass class ModelArgs: dim: int = 4096 @@ -84,182 +67,58 @@ class ModelArgs: n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer - hidden_dim: Optional[int] = None + invocation_vocab_size: int = -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 max_batch_size: int = 32 max_seq_len: int = 2048 - moe: bool = False # True to enable the MoE (Mixture of Experts) - num_experts: int = 8 # Number of experts - num_activated_experts: int = 2 # Number of experts to activate - use_kv_cache: bool = False # Use key/value cache - use_sdpa_with_kv_cache_op: bool = ( - False # Use custom sdpa op that updates kv cache in-place - ) - # Generate logits for all inputs. When it's True, it would take big memory usage - # at runtime. Enable it only necessary (e.g., use perplexity tools that requires - # logits for all input tokens.) - generate_full_logits: bool = False - enable_dynamic_shape: bool = False # export model with dynamic shape support - # A dictionary mapping from pruned token-id to original token-id - output_prune_map: Optional[Dict[int, int]] = None - use_hf_rope: bool = False # Use HuggingFace's RoPE implementation - rope_theta: Optional[float] = ( - None # The official name to override self.rope_freq_base. - ) - rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. - use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1. - # Additional Model Metadata needed at runtime - bos_idx: int = 1 - eos_idx: int = 3 - bos_count: int = -1 # i.e., a single EOS is used as BOS - eos_count: int = 2 - - def __post_init__(self): - if self.n_kv_heads is None: - self.n_kv_heads = self.n_heads - - # rope_theta overrides rope_freq_base since it's the official name. 
- if self.rope_theta is not None: - self.rope_freq_base = self.rope_theta - - if self.use_sdpa_with_kv_cache_op: - assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache" - - if self.hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = self.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - if self.ffn_dim_multiplier is not None: - hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) - self.hidden_dim = find_multiple(hidden_dim, multiple_of) + use_layer_norm_op: bool = False + use_rms_norm_op: bool = False + hidden_dim: Optional[int] = None -class KVCache(nn.Module): - def __init__( - self, - max_batch_size: int, - max_seq_length: int, - n_heads: int, - head_dim: int, - transpose_cache: bool, - enable_dynamic_shape: bool, - dtype=torch.float32, - ): - super().__init__() - self.max_seq_length = max_seq_length - self.is_tranposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - - self.max_batch_size = max_batch_size - self.n_heads = n_heads - self.head_dim = head_dim - self.transpose_cache = transpose_cache - self.enable_dynamic_shape = enable_dynamic_shape - self.register_buffer( - "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) - self.register_buffer( - "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # pyre-ignore + freqs = torch.outer(t, freqs).float() # pyre-ignore + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin - def update( - self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 - seq_length = k_val.size(dim_to_slice) - # Replace the entry in the cache for this token - # The following lines are equivalent to: - # cache_k[:bsz, start_pos : start_pos + seqlen] = xk - # cache_v[:bsz, start_pos : start_pos + seqlen] = xv - # when dim_to_slice is 1 - # We use .narrow() here to make the compiler happy - # pyre-ignore: Incompatible parameter type [6] - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - # pyre-ignore: Incompatible parameter type [6] - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - - narrowed_k.copy_(k_val) - narrowed_v.copy_(v_val) - return self.k_cache, self.v_cache - else: - k_out = self.k_cache - v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val - return k_out, v_out +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(shape) -class SDPA(nn.Module): - def __init__( - self, 
- kv_cache: KVCache, - dim: int, - head_dim: int, - n_rep: int, - max_seq_len: int, - enable_dynamic_shape: bool, - ): - super().__init__() - self.kv_cache = kv_cache - self.dim = dim - self.head_dim = head_dim - self.n_rep = n_rep - self.max_seq_len = max_seq_len - self.enable_dynamic_shape = enable_dynamic_shape +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: - def forward( - self, - input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) - v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) - bsz, - seqlen, - mask: torch.Tensor, - ) -> torch.Tensor: - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_seq_len) - seq_length = q.size(2) - # pyre-ignore: Incompatible parameter type [6] - attn_mask = mask.narrow(0, start_pos, seq_length) - else: - attn_mask = mask[None, None, input_pos] + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) + + freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) + freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos - return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): + +class Attention(torch.nn.Module): + def __init__(self, args: ModelArgs): super().__init__() - self.use_kv_cache = args.use_kv_cache self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads assert args.n_heads % self.n_kv_heads == 0 model_parallel_size = 1 @@ -267,295 +126,158 @@ def __init__(self, args: ModelArgs, layer_id: int): self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads - self.max_batch_size = args.max_batch_size - self.max_seq_len = args.max_seq_len - self.dim = args.dim - # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125 - self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) - - self.layer_id = layer_id - - causal_mask = torch.tril( - torch.ones( - self.max_seq_len, - self.max_seq_len, - dtype=torch.bool, - device="cpu", - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - - if self.use_kv_cache: - self.kv_cache = KVCache( - args.max_batch_size, - 
args.max_seq_len, - self.n_kv_heads, - self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v - args.enable_dynamic_shape, - ) - self.SDPA = SDPA( - kv_cache=self.kv_cache, - dim=self.dim, - head_dim=self.head_dim, - n_rep=self.n_rep, - max_seq_len=self.max_seq_len, - enable_dynamic_shape=args.enable_dynamic_shape, - ) - if args.use_hf_rope: - self.apply_rotary_emb = hf_apply_rotary_emb - else: - self.apply_rotary_emb = apply_rotary_emb + self.wq = torch.nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) + self.wk = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = torch.nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + # set large value of -inf (or -32768 with int16) when we want to + # ignore correspnding values in the mask + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-32768")) + mask = torch.triu(mask, diagonal=1) + self.register_buffer("mask", mask) def forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, ): bsz, seqlen, _ = x.shape # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings - q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) - - if self.use_kv_cache: - assert input_pos is not None - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) - - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) # grouped multiquery attention: expand out keys and values - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - + xk = [ + torch.cat([xk[:, :, i : i + 1, :]] * self.n_rep, dim=2) + for i in range(xk.size(2)) + ] + xk = torch.cat(xk, dim=2) + + xv = [ + torch.cat([xv[:, :, i : i + 1, :]] * self.n_rep, dim=2) + for i in range(xv.size(2)) + ] + xv = torch.cat(xv, dim=2) + + # make heads into a batch dimension + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + + scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) assert hasattr(self, "mask") - - mask = self.mask[:seqlen, :seqlen] - - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + scores = ( + scores + self.mask[:, :, :seqlen, :seqlen] + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output = self.wo(output) - return output -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int): super().__init__() - 
assert args.hidden_dim is not None - hidden_dim: int = args.hidden_dim - self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) + self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False) + self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False) + self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + x = F.silu(self.w1(x)) * self.w3(x) + x = self.w2(x) + return x -class ConditionalFeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.dim = args.dim - hidden_dim = args.hidden_dim - if hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = args.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) - self.num_experts = args.num_experts - - def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor: - w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D] - w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D] - w2_weights = self.w2[expert_indices] # [T, A, D, D] - x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights)) - x3 = torch.einsum("ti, taio -> tao", x, w3_weights) - expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights) - return expert_outs - - -class MOEFeedForward(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.gate = nn.Linear(config.dim, config.num_experts, bias=False) - self.cond_ffn = ConditionalFeedForward(config) - self.dim = config.dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.view(-1, self.dim) - # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts - # x: [T, D] - scores = self.gate(x) # [T, E] - expert_weights, expert_indices = torch.topk(scores, 2, dim=-1) # [T, A], [T, A] - expert_weights = expert_weights.softmax(dim=-1) # [T, A] - expert_outs = self.cond_ffn(x, expert_indices) - return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) - - -class TransformerBlock(nn.Module): +class TransformerBlock(torch.nn.Module): def __init__(self, layer_id: int, args: ModelArgs): super().__init__() - self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args, layer_id) - if args.moe: - self.block_sparse_moe = MOEFeedForward(args) + self.attention = Attention(args) + if args.hidden_dim is None: + hidden_dim = 4 * args.dim + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = args.multiple_of * ( + (hidden_dim + args.multiple_of - 1) // args.multiple_of + ) else: - self.feed_forward = FeedForward(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN - h = self.attention.forward( - self.attention_norm(x), freqs_cos, freqs_sin, input_pos + hidden_dim = args.hidden_dim + self.feed_forward = FeedForward( + dim=args.dim, + 
hidden_dim=hidden_dim, + multiple_of=args.multiple_of, ) - - h = x + h - if hasattr(self, "block_sparse_moe"): - out = h + self.block_sparse_moe(self.ffn_norm(h)) + self.layer_id = layer_id + if args.use_layer_norm_op: + self.attention_norm = torch.nn.LayerNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = torch.nn.LayerNorm(args.dim, eps=args.norm_eps) + elif args.use_rms_norm_op: + self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) else: - out = h + self.feed_forward(self.ffn_norm(h)) + self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + + def forward(self, x, freqs_cos, freqs_sin): + h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin) + out = h + self.feed_forward.forward(self.ffn_norm(h)) return out -class Transformer(nn.Module): +class LastTimeStepPool(torch.nn.Module): + def forward(self, logits: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor: + bsz, _, dim = logits.shape + idx = seq_lens.unsqueeze(1).expand(bsz, dim).unsqueeze(1) + return logits.gather(1, idx - 1).squeeze(1) + + +class Transformer(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params self.vocab_size = params.vocab_size self.n_layers = params.n_layers - self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim) self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = nn.Linear(params.dim, params.vocab_size, bias=False) - self.use_kv_cache = params.use_kv_cache - self.generate_full_logits = params.generate_full_logits - self.max_seq_len = params.max_seq_len - self.output_prune_map = params.output_prune_map - if params.use_hf_rope: - self.precompute_freqs_cis = hf_precompute_freqs_cis + if params.use_layer_norm_op: + self.norm = torch.nn.LayerNorm(params.dim, eps=params.norm_eps) + elif params.use_rms_norm_op: + self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps) else: - self.precompute_freqs_cis = partial( - precompute_freqs_cis, use_scaled=params.use_scaled_rope - ) - freqs_cos, freqs_sin = self.precompute_freqs_cis( - params.dim // params.n_heads, - ( - params.max_seq_len # Normal llama2. - if params.ffn_dim_multiplier is None - else params.max_seq_len * 2 # Sharded checkpoint. 
- ), - params.rope_freq_base, + self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps) + self.out = torch.nn.Linear(params.dim, params.vocab_size, bias=False) + + freqs_cos, freqs_sin = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) - def forward( - self, - tokens: Optional[torch.LongTensor] = None, # tokens - input_pos: Optional[ - torch.LongTensor - ] = None, # Scalar tensor indicating size of window of the caches - h: Optional[torch.FloatTensor] = None, # embeddings - ) -> torch.Tensor: - if (tokens is None) ^ (h is not None): - raise ValueError( - "You cannot specify both tokens and h at the same time, and must specify either one" - ) - if tokens is not None and h is None: - h = self.tok_embeddings(tokens) - seqlen = h.shape[1] - - if self.use_kv_cache: - assert ( - input_pos is not None - ), "input_pos must be provided when use_kv_cache is True" - - if self.params.enable_dynamic_shape: - # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. - input_pos_item = input_pos[-1].item() - torch._check_is_size(input_pos_item) - torch._check(input_pos_item < self.params.max_seq_len) - # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor - freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen) - # pyre-ignore: Incompatible parameter type [6] - freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen) - else: - # When not using dynamic shape, use of the .item results in - # symints, due to querying the data from tensor. - # this path avoids that for mps backend, although probably mps backend - # can support dynamic shape? - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - - else: - assert input_pos is None, "input_pos is unused when use_kv_cache is False" - freqs_cos = self.freqs_cos[:seqlen] - freqs_sin = self.freqs_sin[:seqlen] + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + freqs_cos = self.freqs_cos[:seqlen] + freqs_sin = self.freqs_sin[:seqlen] for layer in self.layers: - h = layer( - h, - freqs_cos, - freqs_sin, - input_pos, - ) - - if not self.generate_full_logits: - # Only the last logit is used for the new generated token - h = h[:, -1, :] + h = layer(h, freqs_cos, freqs_sin) h = self.norm(h) - logits = self.output(h) - - if self.output_prune_map is not None: - # expand to original size so that downstream applications can use the logits as-is. 
- if self.generate_full_logits: - # (1, seq_len, pruned_size) -> (1, seq_len, original_size) - expanded_logits = torch.full( - [logits.shape[0], logits.shape[1], self.vocab_size], - float("-inf"), - device=logits.device, - dtype=logits.dtype, - ) - expanded_logits[:, :, list(self.output_prune_map.values())] = logits - else: - # (1, pruned_size) -> (1, original_size) - expanded_logits = torch.full( - [logits.shape[0], self.vocab_size], - float("-inf"), - device=logits.device, - dtype=logits.dtype, - ) - expanded_logits[:, list(self.output_prune_map.values())] = logits - logits = expanded_logits - - return logits + invocation_logits = self.out(h) + + return invocation_logits diff --git a/examples/models/llama2/main.cpp b/examples/models/llama2/main.cpp index 339b2abfdb4..61d4564a1c1 100644 --- a/examples/models/llama2/main.cpp +++ b/examples/models/llama2/main.cpp @@ -1,19 +1,467 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +#include +#include #include -#include +// #if defined(ET_USE_THREADPOOL) +// #include +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +/* + +The end to end flow to run this cria is as follows: +1. Build the cria model using the following command: + +Get model checkpoint and tokenizer +``` +manifold get +assistant_nlu/tree/users/shreyd/cria_arbitration/HTP/llama3_386M_LD12_ckpt.pt +/tmp/llama3_386M_LD12_ckpt.pt --threads 20 + +manifold get executorch/tree/models/llama/llama3/tokenizer.model +/tmp/tokenizer.model --threads 20 +``` +Generate the model given the checkpoint and params +``` +buck run @mode/dev-nosan //bolt/nn/executorch/export:export_cria_model +``` +It will generate a model file in the tmp directory, as described by the log + +2. Build the runtime: +``` +buck build @arvr/mode/android/linux/dev +//arvr/projects/bolt/bolt/nn/apps:cria_prefill_runner_app +--out /tmp +``` + +3. 
Push models and binary to device +``` +adb push /tmp/cria_prefill_runner_app /vendor/bin +adb push /tmp/on_device_model.pte /data/local/tmp +adb push /tmp/tokenizer.model /data/local/tmp +``` +run the binary on device +``` +adb shell LD_LIBRARY_PATH=/vendor/lib64 cria_prefill_runner_app --model_path +/data/local/tmp/on_device_model.pte --tokenizer_path +/data/local/tmp/tokenizer.model +``` +*/ + +double get_interval( + const std::chrono::time_point& end, + const std::chrono::time_point& start) { + auto duration = + std::chrono::duration_cast(end - start); + return static_cast(duration.count()); +} + +namespace torch::executor { +using Stats = ::executorch::llm::Stats; + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_path, + float temperature = 0.8f); + + [[nodiscard]] bool is_loaded() const; + Error load(); + Error generate( + const std::string& prompt, + int32_t seq_len = 128, + const std::function& token_callback = {}, + const std::function& stats_callback = {}); + void stop(); + + private: + // metadata + template + T getMetadataHelper(const std::string& method_name, T default_val); + int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); + Result prefill( + const std::vector& tokens, + executorch::extension::TensorPtr& managed_tokens, + executorch::extension::TensorPtr& managed_start_pos, + const std::function& token_callback); + Result run_model_step( + int64_t input_token, + executorch::extension::TensorPtr& tokens, + executorch::extension::TensorPtr& start_pos, + size_t max_seq_len); + // metadata + int32_t vocab_size_{}; + int32_t bos_id_{}; + int32_t eos_id_{}; + int32_t n_bos_{}; + int32_t n_eos_{}; + int32_t max_seq_len_{}; + bool append_eos_{}; + std::unordered_set model_methods_; + std::string model_path_; + std::unique_ptr module_; + std::string tokenizer_path_; + float temperature_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + bool shouldStop_{false}; + Stats stats_; + bool enable_parallel_prefill_{}; +}; + +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + // NOTE: we observed ~2x loading performance increase on iPhone 15 + // and a ~5% improvement on Galaxy S22 by switching to + // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors. 
+    : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
+      tokenizer_path_(tokenizer_path),
+      temperature_(temperature) {
+  ET_LOG(
+      Info,
+      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
+      model_path.c_str(),
+      tokenizer_path.c_str());
+}
+
+bool Runner::is_loaded() const {
+  return module_->is_loaded() && tokenizer_ && sampler_;
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
+  // Read out metadata: vocab_size (expected by the model), BOS, EOS, n_BOS,
+  // n_EOS, max_seq_len from the model
+  ET_LOG(Info, "Reading metadata from model");
+  const auto method_names = module_->method_names();
+  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
+  model_methods_ = method_names.get();
+  n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
+  n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
+  // max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 33);
+  max_seq_len_ = 17;
+  append_eos_ = getMetadataHelper<bool>("append_eos_to_prompt", false);
+  enable_parallel_prefill_ = getMetadataHelper<bool>("enable_dynamic_shape", false);
-#if defined(ET_USE_THREADPOOL)
-#include <executorch/extension/threadpool/cpuinfo_utils.h>
-#include <executorch/extension/threadpool/threadpool.h>
-#endif
+  tokenizer_ = example::get_tiktoken_for_llama();
+  tokenizer_->load(tokenizer_path_);
+
+  vocab_size_ =
+      getMetadataHelper<int64_t>("get_vocab_size", tokenizer_->vocab_size());
+  bos_id_ = getMetadataHelper<int64_t>("get_bos_id", tokenizer_->bos_tok());
+  eos_id_ = getMetadataHelper<int64_t>("get_eos_id", tokenizer_->eos_tok());
+
+  // Create sampler
+  sampler_ = std::make_unique<Sampler>(
+      vocab_size_,
+      temperature_,
+      ::executorch::llm::kTopp,
+      static_cast<unsigned long long>(std::time(nullptr)));
+
+  return Error::Ok;
+}
+
+template <typename T>
+T Runner::getMetadataHelper(const std::string& method_name, T default_val) {
+  T res = default_val;
+  if (model_methods_.count(method_name)) {
+    Result<std::vector<EValue>> outputs = module_->execute(method_name);
+    if (outputs.ok()) {
+      std::vector<EValue> outs = outputs.get();
+      if (!outs.empty()) {
+        res = outs[0].to<T>();
+      }
+    }
+  } else {
+    ET_LOG(
+        Info,
+        "The model does not contain a %s method, using default value %lld",
+        method_name.c_str(),
+        (long long)default_val);
+  }
+  ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
+  return res;
+}
+
+int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
+  ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
+  auto num_tokens = logits_tensor.size(1);
+
+  switch (logits_tensor.scalar_type()) {
+    case ScalarType::Float: {
+      auto* logits = logits_tensor.mutable_data_ptr<float>();
+      float* logits_last = logits;
+      logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
+      return sampler_->sample(logits_last);
+    }
+    case ScalarType::Half: {
+      auto* logits = logits_tensor.mutable_data_ptr<exec_aten::Half>();
+      exec_aten::Half* logits_last = logits;
+      logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
+      return sampler_->sample(logits_last);
+    }
+    default:
+      ET_CHECK_MSG(
+          false,
+          "Unsupported dtype output %hhd",
+          static_cast<int8_t>(logits_tensor.scalar_type()));
+  }
+}
+
+Result<exec_aten::Tensor> Runner::prefill(
+    const std::vector<uint64_t>& tokens,
+    executorch::extension::TensorPtr& managed_tokens,
+    executorch::extension::TensorPtr& /*managed_start_pos*/,
+    const std::function<void(const std::string&)>& token_callback) {
+  // enable_parallel_prefill_ may be set even when not using kv cache
+  // When kv cache is not used, start pos is ignored
+  int32_t num_tokens = tokens.size();
+  ET_LOG(Info, "Prefilling %d tokens", num_tokens);
+
+  ET_CHECK_OK_OR_RETURN_ERROR(executorch::extension::resize_tensor_ptr(
+      managed_tokens, {1, num_tokens}));
+  auto* tokens_ptr = managed_tokens->mutable_data_ptr<int64_t>();
+  for (int i = 0; i < num_tokens; i++) {
+    // The following assumes batch size = 1
+    tokens_ptr[i] = tokens[i];
+  }
+  std::vector<EValue> inputs;
+
+  // inputs:[tokens, start_pos]
+  inputs.emplace_back(managed_tokens);
+  // inputs.push_back(start_pos);
+
+  auto before_exec = std::chrono::high_resolution_clock::now();
+  Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+  auto after_exec = std::chrono::high_resolution_clock::now();
+  ET_LOG(Info, "execute took %f ms", get_interval(after_exec, before_exec));
+
+  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+  ET_CHECK_MSG(
+      outputs_res.get()[0].isTensor(),
+      "Non Tensor Output returned from executing LLM");
+  ET_CHECK_MSG(
+      outputs_res.get()[0].toTensor().size(1) == num_tokens,
+      "Expected number of output tokens %d does not match returned value %zu.",
+      num_tokens,
+      outputs_res.get()[0].toTensor().size(1));
+
+  // start_pos.mutable_data_ptr()[0] = num_tokens;
+
+  uint64_t prev = tokens[0];
+  uint64_t cur = 0;
+  for (int i = 1; i < num_tokens; i++) {
+    cur = tokens[i];
+    auto piece_res = tokenizer_->decode(prev, cur);
+    ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
+    util::safe_printf(piece_res.get().c_str());
+    fflush(stdout);
+    prev = cur;
+    if (token_callback) {
+      token_callback(piece_res.get());
+    }
+  }
+  cur = logitsToToken(outputs_res.get()[0].toTensor());
+  auto piece_res = tokenizer_->decode(prev, cur);
+  ET_CHECK(piece_res.ok());
+  const char* piece = piece_res.get().c_str();
+  util::safe_printf(piece);
+  fflush(stdout);
+  if (token_callback) {
+    token_callback(piece_res.get());
+  }
+
+  stats_.first_token_ms = util::time_in_ms();
+  stats_.prompt_eval_end_ms = util::time_in_ms();
+  // Return the logits tensor
+  return outputs_res.get()[0].toTensor();
+}
+
+// Given an input token, set up the inputs for the model and execute a single
+// step, returning the logits tensor.
+Result<exec_aten::Tensor> Runner::run_model_step(
+    int64_t input_token,
+    executorch::extension::TensorPtr& tokens,
+    executorch::extension::TensorPtr& start_pos,
+    size_t max_seq_len) {
+  std::vector<EValue> inputs;
+  (void)start_pos; // unused
+
+  // When not using kv-cache, our input is the entire history of tokens we have
+  // seen, so resize input to be 1 larger and append the new token to the end.
+  // TODO: does this work in ATen mode?
+  tokens->mutable_data_ptr<int64_t>()[tokens->size(1) - 1] = input_token;
+
+  // inputs:[tokens]
+  inputs.emplace_back(tokens);
+
+  auto outputs_res = module_->forward(inputs);
+
+  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+  ET_CHECK_MSG(
+      outputs_res.get().size() == 1,
+      "More than one output returned from executing LLM.");
+  ET_CHECK_MSG(
+      outputs_res.get()[0].isTensor(),
+      "Non Tensor Output returned from executing LLM");
+
+  if (tokens->size(1) < max_seq_len) {
+    // Resize the tokens tensor to be 1 larger for the next step.
+    // Note that this relies on the fact that the underlying memory is the same
+    // such that previous tokens stored there will still exist.
+    // Not a good thing to rely upon.
+    ET_CHECK_OK_OR_RETURN_ERROR(executorch::extension::resize_tensor_ptr(
+        tokens, {1, static_cast<exec_aten::SizesType>(tokens->size(1) + 1)}));
+  }
+
+  // Return the logits tensor
+  return outputs_res.get()[0].toTensor();
+}
+
+Error Runner::generate(
+    const std::string& prompt,
+    int32_t seq_len,
+    const std::function<void(const std::string&)>& token_callback,
+    const std::function<void(const Stats&)>& stats_callback) {
+  // Prepare the inputs.
+  // Use ones-initialized inputs.
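+  // Flow: lazily load the model and tokenizer if needed, encode the prompt,
+  // run a single prefill pass over all prompt tokens, then report timing
+  // stats. This prefill-only runner does not enter an autoregressive decode
+  // loop after prefill.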
+  // auto generate_start = std::chrono::high_resolution_clock::now();
+  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
+  if (!is_loaded()) {
+    stats_.model_load_start_ms = util::time_in_ms();
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+    stats_.model_load_end_ms = util::time_in_ms();
+  }
+
+  // First token time only measures the time it takes to encode the prompt and
+  // return a response token.
+
+  stats_.inference_start_ms = util::time_in_ms();
+  shouldStop_ = false;
+
+  // Set the sequence length to the max seq length if not provided
+  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
+
+  // auto encode_start = std::chrono::high_resolution_clock::now();
+
+  Result<std::vector<uint64_t>> encode_res =
+      tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
+
+  // auto encode_finish = std::chrono::high_resolution_clock::now();
+  // ET_LOG(Info, "encoder took %f ms", get_interval(encode_finish,
+  // encode_start));
+
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // encode the (string) prompt into a token sequence
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  ET_LOG(Info, "Prompt tokens: %zu", prompt_tokens.size());
+  int num_prompt_tokens = prompt_tokens.size();
+
+  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
+  ET_CHECK_MSG(
+      num_prompt_tokens < max_seq_len_,
+      "Max seq length exceeded - please increase max seq len value in .../llama2/model.py num_prompt_tokens: %d, max_seq_len: %d",
+      num_prompt_tokens,
+      max_seq_len_);
+
+  ET_CHECK_MSG(
+      num_prompt_tokens < seq_len,
+      "Sequence length exceeded - please increase the seq_len value passed to generate()");
+
+  // start the main loop
+  int64_t pos = 0; // position in the sequence
+
+  std::vector<int64_t> token_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+  std::vector<int64_t> start_pos_data; // allocate space for the start pos
+  std::vector<exec_aten::SizesType> start_pos_shape = {1};
+
+  token_data.resize(seq_len);
+
+  // initialize tensor wrappers
+  auto tokens = executorch::extension::from_blob(
+      token_data.data(), token_shape, ScalarType::Long);
+  // Create with the max shape to appropriately set the capacity of this
+  // tensor, then resize back to 1 for the first input.
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      executorch::extension::resize_tensor_ptr(tokens, {1, 1}));
+
+  auto start_pos = executorch::extension::from_blob(
+      start_pos_data.data(), start_pos_shape, ScalarType::Long);
+
+  int64_t cur_token = prompt_tokens[0];
+
+  // Prefill first
+  // Here feed all tokens to the model and get the next predicted token
+  // after the prompt. After that we will enter generate loop.
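+  // For example, with a 3-token prompt [t0, t1, t2] the tokens tensor is
+  // resized to shape {1, 3}, one forward() call runs over the whole window,
+  // and the logits at the last position are sampled to give the first
+  // generated token.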
+ auto prefill_res = prefill(prompt_tokens, tokens, start_pos, token_callback); + + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + exec_aten::Tensor& prefill_res_tensor = prefill_res.get(); + cur_token = logitsToToken(prefill_res_tensor); + ET_LOG(Info, "Prefill result: %ld", cur_token); + + ET_CHECK_OK_OR_RETURN_ERROR(executorch::extension::resize_tensor_ptr( + tokens, {1, num_prompt_tokens + 1})); + // tokens_managed.resize({1, num_prompt_tokens + 1}); + pos = num_prompt_tokens; + + stats_.inference_end_ms = util::time_in_ms(); + printf("\n"); + + if (pos == seq_len) { + ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len); + } + + stats_.num_prompt_tokens = num_prompt_tokens; + stats_.num_generated_tokens = pos - num_prompt_tokens; + ::executorch::llm::print_report(stats_); + if (stats_callback) { + stats_callback(stats_); + } + + return Error::Ok; +} + +void Runner::stop() { + shouldStop_ = true; +} + +// explicit instantiation of template methods +template int64_t Runner::getMetadataHelper( + const std::string& method_name, + int64_t default_val); +template bool Runner::getMetadataHelper( + const std::string& method_name, + bool default_val); +} // namespace torch::executor DEFINE_string( model_path, @@ -22,16 +470,19 @@ DEFINE_string( DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); -DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); +DEFINE_string( + prompt, + "How does artificial intelligence redefine the role of human creativity in the next decade?", + "Prompt."); DEFINE_double( temperature, - 0.8f, - "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); + 0.0f, + "Temperature; Default is 0.0f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); DEFINE_int32( seq_len, - 128, + 17, "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); DEFINE_int32( @@ -39,8 +490,6 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); -DEFINE_bool(warmup, false, "Whether to run a warmup run."); - int32_t main(int32_t argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -52,32 +501,26 @@ int32_t main(int32_t argc, char** argv) { const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); const char* prompt = FLAGS_prompt.c_str(); + ET_LOG(Info, "Prompt: %s", prompt); double temperature = FLAGS_temperature; int32_t seq_len = FLAGS_seq_len; - int32_t cpu_threads = FLAGS_cpu_threads; - - bool warmup = FLAGS_warmup; + // [[maybe_unused]] int32_t cpu_threads = FLAGS_cpu_threads; -#if defined(ET_USE_THREADPOOL) - uint32_t num_performant_cores = cpu_threads == -1 - ? ::executorch::extension::cpuinfo::get_num_performant_cores() - : static_cast(cpu_threads); - ET_LOG( - Info, "Resetting threadpool with num threads = %d", num_performant_cores); - if (num_performant_cores > 0) { - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } -#endif + // #if defined(ET_USE_THREADPOOL) + // uint32_t num_performant_cores = cpu_threads == -1 + // ? 
torch::executorch::cpuinfo::get_num_performant_cores() + // : static_cast(cpu_threads); + // ET_LOG(Info, "Resetting threadpool with num threads = %d", + // num_performant_cores); if (num_performant_cores > 0) { + // torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(num_performant_cores); + // } + // #endif // create llama runner - example::Runner runner(model_path, tokenizer_path, temperature); + ::torch::executor::Runner runner(model_path, tokenizer_path, temperature); - if (warmup) { - runner.warmup(prompt, seq_len); - } // generate runner.generate(prompt, seq_len); diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index d8d0ff00ffa..e2364af2b62 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -150,16 +150,32 @@ def __init__(self, **kwargs): output_prune_map = {int(k): v for (k, v) in output_prune_map.items()} max_seq_len = self.max_seq_len max_batch_size = 1 + print("params: ", params) + params.pop("rope_theta", None) model_args: ModelArgs = ModelArgs( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - use_kv_cache=self.use_kv_cache, - use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, - generate_full_logits=self.generate_full_logits, - output_prune_map=output_prune_map, - enable_dynamic_shape=self.enable_dynamic_shape, + # input_vocab_size=params["input_vocab_size"], + # use_kv_cache=self.use_kv_cache, + # use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, + # generate_full_logits=self.generate_full_logits, + # output_prune_map=output_prune_map, + # enable_dynamic_shape=self.enable_dynamic_shape, + use_layer_norm_op=True, **params, ) + # model_args: ModelArgs = ( + # ModelArgs( + # dim=512, + # hidden_dim=1536, + # n_heads=8, + # n_kv_heads=2, + # n_layers=19, + # vocab_size=128256, + # invocation_vocab_size=8, + # use_layer_norm_op=True, + # ), + # ) if kwargs.get("fairseq2", False): print("Using fairseq2 checkpoint") checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint) @@ -170,10 +186,24 @@ def __init__(self, **kwargs): print(f"{key} : {weights.numel()} : {weights.size()}") print("============= /weights ================") - # Within the device="meta" context, tensors that are created do not carry data. - # They possess all other metadata a tensor carries such as size, stride, requires_grad. - with torch.device("meta"): - self.model_ = Transformer(model_args) + # Within the device="meta" context, tensors that are created do not carry data. + # They possess all other metadata a tensor carries such as size, stride, requires_grad. + # with torch.device("meta"): + # self.model_ = Transformer(model_args) + # self.model_ = Transformer( + # ModelArgs( + # dim=512, + # hidden_dim=1536, + # n_heads=8, + # n_kv_heads=2, + # n_layers=19, + # vocab_size=128256, + # invocation_vocab_size=8, + # use_layer_norm_op=True, + # ), + # ) + self.model_ = Transformer(model_args) + print("model: ", self.model_) if "int8" in str(checkpoint_path): print("Using int8 weight-only quantization!") @@ -221,11 +251,11 @@ def __init__(self, **kwargs): # assign=True: load params/buffers by assignment instead of performing an in-place copy. # Because we are using device="meta", tensors do not have memory associated with them # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario. 
- missing, unexpected = self.model_.load_state_dict( - checkpoint, - strict=False, - assign=True, - ) # self.model_ = Transformer(gptconf) + # missing, unexpected = self.model_.load_state_dict( + # checkpoint, + # strict=False, + # assign=True, + # ) # self.model_ = Transformer(gptconf) if kwargs.get("verbose", False): print("============= missing keys ================") print(missing) @@ -254,11 +284,13 @@ def get_example_inputs(self): if self.use_kv_cache: return self.get_example_inputs_kvcache_sdpa() else: - return ( - torch.tensor( - [[1, 2, 3]], dtype=torch.long - ), # tokens, with kv cache our input token length is always just 1 token. - ) + # return ( + # torch.tensor( + # [[1, 2, 3]], dtype=torch.long + # ), # tokens, with kv cache our input token length is always just 1 token. + # ) + b = torch.ones(1, 16, dtype=torch.long) + return (b,) # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working def get_example_inputs_kvcache_sdpa(self): diff --git a/examples/models/llama2/params/demo_config.json b/examples/models/llama2/params/demo_config.json index 13287f117e9..754d09b5ca2 100644 --- a/examples/models/llama2/params/demo_config.json +++ b/examples/models/llama2/params/demo_config.json @@ -1 +1 @@ -{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512} \ No newline at end of file +{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 1, "norm_eps": 1e-05, "vocab_size": 512} diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index 96d47ffce21..207f820d0f2 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -29,6 +29,7 @@ def define_common_targets(): ], # qnn_executorch_backend can be added below //executorch/backends/qualcomm:qnn_executorch_backend exported_deps = [ + # "//executorch/backends/qualcomm:qnn_executorch_backend", "//executorch/backends/xnnpack:xnnpack_backend", "//executorch/extension/llm/runner:stats", "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix, diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py index fb317e3bca3..8913bd50484 100644 --- a/examples/models/model_factory.py +++ b/examples/models/model_factory.py @@ -35,9 +35,11 @@ def create_model( ValueError: If the provided model class is not found in the module. """ package_prefix = "executorch." 
if not os.getcwd().endswith("executorch") else "" - module = importlib.import_module( - f"{package_prefix}examples.models.{module_name}" - ) + print(f"package_prefix: {package_prefix}") + # module = importlib.import_module( + # f"{package_prefix}examples.models.{module_name}" + # ) + module = importlib.import_module(f"executorch.examples.models.{module_name}") if hasattr(module, model_class_name): model_class = getattr(module, model_class_name) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 9c4cd4453f0..ac16343a8e8 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -83,6 +83,7 @@ def __init__( self.debug_output_path = f"{self.workspace}/debug_output.bin" self.output_folder = f"{self.workspace}/outputs" self.arch_table = { + "SSG2115P": "73", "SM8650": "75", "SM8550": "73", "SM8475": "69", diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py index 40ceb6ffec2..eb7c50aa8a1 100644 --- a/extension/export_util/utils.py +++ b/extension/export_util/utils.py @@ -41,7 +41,7 @@ def _to_core_aten( model, example_inputs, dynamic_shapes=dynamic_shapes, strict=strict ) if verbose: - logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") + print(f"Core ATen graph:\n{core_aten_ep.graph}") return core_aten_ep @@ -62,7 +62,7 @@ def _core_aten_to_edge( compile_config=edge_compile_config, ) if verbose: - logging.info(f"Exported graph:\n{edge_manager.exported_program()}") + print(f"Exported graph:\n{edge_manager.exported_program()}") return edge_manager @@ -117,7 +117,7 @@ def save_pte_program( try: with open(filename, "wb") as file: prog.write_to_file(file) - logging.info(f"Saved exported program to {filename}") + print(f"Saved exported program to {filename}") except Exception as e: logging.error(f"Error while saving to {filename}: {e}") diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ae0ca6df757..16f77668839 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -9,6 +9,7 @@ # ExecuTorch. 
import logging +import os from enum import Enum from typing import Any, Callable, List, Optional @@ -34,6 +35,7 @@ from torch.ao.quantization.quantizer import Quantizer from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.nn.attention import SDPBackend +from tqdm import tqdm FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -150,6 +152,7 @@ def source_transform( return self def _get_dynamic_shape(self) -> Any: + return None if self.dynamic_shapes: return self.dynamic_shapes diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 37b215a51ff..78aa2719cae 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -154,9 +154,9 @@ def get_qnn_partitioner( num_sharding: int = 0, soc_model: str = "SM8650", # default to SM8650 ): - assert ( - use_kv_cache is True - ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" + # assert ( + # use_kv_cache is True + # ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" try: # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.partition.qnn_partitioner` from executorch.backends.qualcomm.partition.qnn_partitioner import ( @@ -179,7 +179,8 @@ def get_qnn_partitioner( ) use_fp16 = True - skip_node_op_set = {"llama.fallback.default", "aten.embedding.default"} + # skip_node_op_set = {"llama.fallback.default", "aten.embedding.default"} + skip_node_op_set = {"llama.fallback.default"} if pt2e_quantize is not None: use_fp16 = False diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 45d9932724e..377fa56e7dd 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -148,6 +148,7 @@ def get_qnn_quantizer( try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] custom_annotate_llama_matmul_16a8w, + custom_annotate_matmul_16a8w, ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` @@ -199,7 +200,8 @@ def get_qnn_quantizer( ) qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - custom_annotations = (custom_annotate_llama_matmul_16a8w,) + # custom_annotations = (custom_annotate_llama_matmul_16a8w,) + custom_annotations = (custom_annotate_matmul_16a8w,) else: raise AssertionError( f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w." @@ -209,11 +211,11 @@ def get_qnn_quantizer( quantization_mode is None ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" qnn_quantizer.add_custom_quant_annotations(custom_annotations) - qnn_quantizer.add_discard_ops( - [ - torch.ops.aten.embedding.default, - ] - ) + # qnn_quantizer.add_discard_ops( + # [ + # torch.ops.aten.embedding.default, + # ] + # ) return qnn_quantizer, quant_dtype