diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py new file mode 100644 index 0000000000..de4b7ce2cc --- /dev/null +++ b/backends/qualcomm/_passes/__init__.py @@ -0,0 +1,34 @@ +from .annotate_and_quant_scalar import AnnotateAndQuantScalar +from .annotate_decomposed import AnnotateDecomposed +from .annotate_quant_attrs import AnnotateQuantAttrs +from .convert_bmm_to_matmul import ConvertBmmToMatmul +from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D +from .convert_prelu import ConvertPReLU +from .convert_to_linear import ConvertToLinear +from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape +from .fold_qdq import FoldQDQ +from .i64_to_i32 import I64toI32 +from .layout_transform import LayoutTransform +from .recompose_pixel_unshuffle import RecomposePixelUnshuffle +from .recompose_rms_norm import RecomposeRmsNorm +from .remove_redundancy import RemoveRedundancy +from .replace_index_put_input import ReplaceIndexPutInput + + +__all__ = [ + AnnotateAndQuantScalar, + AnnotateDecomposed, + AnnotateQuantAttrs, + ConvertBmmToMatmul, + ConvertInterpolateWithUpsample2D, + ConvertPReLU, + ConvertToLinear, + ExpandBroadcastTensorShape, + FoldQDQ, + I64toI32, + LayoutTransform, + RecomposePixelUnshuffle, + RecomposeRmsNorm, + RemoveRedundancy, + ReplaceIndexPutInput, +] diff --git a/backends/qualcomm/_passes/annotate_and_quant_scalar.py b/backends/qualcomm/_passes/annotate_and_quant_scalar.py index 1db50694ec..86475c39b1 100644 --- a/backends/qualcomm/_passes/annotate_and_quant_scalar.py +++ b/backends/qualcomm/_passes/annotate_and_quant_scalar.py @@ -53,7 +53,9 @@ def _get_source_scalar_node(self, node: torch.fx.Node) -> torch.fx.Node: if node.op == "placeholder": if not (shape := node.meta["val"].size()): return node - assert f"The output of node {node} is not a scalar, but a tensor with shape {shape}" + assert ( + not shape + ), f"The output of node {node} is not a scalar, but a tensor with shape {shape}" return self._get_source_scalar_node(node.args[0]) def _update_scalar_node_attrs(self, node: torch.fx.Node, quant_attrs: Dict) -> Dict: diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py index 1d2171cc37..29c747d1a1 100644 --- a/backends/qualcomm/_passes/i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import FrozenSet + import torch from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant from executorch.exir.dialects._ops import ops as exir_ops @@ -15,9 +17,14 @@ class I64toI32(ExportPass): Cast unsupported int64 datatype into int32. 
""" - def __init__(self, edge_program: torch.export.ExportedProgram): + def __init__( + self, + edge_program: torch.export.ExportedProgram, + skip_node: FrozenSet[str] = frozenset(), + ): super(I64toI32, self).__init__() self.edge_program = edge_program + self.skip_node = skip_node # pyre-ignore[4] self.copy_op = exir_ops.edge.aten._to_copy.default @@ -42,6 +49,8 @@ def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: def _cast_to_int32(self, graph_module: torch.fx.GraphModule): for n in graph_module.graph.nodes: + if n.target in self.skip_node: + continue if is_constant(n, self.edge_program): param = get_parameter(n, self.edge_program) if param.dtype == torch.int64: diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index ac6525ae76..a606a21c62 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -43,3 +43,63 @@ def get_quant_attrs( quant_attrs[QCOM_ENCODING] = quant_node.target return quant_attrs + + +def get_passes_dependency_for_capture_program(): + """ + This function records the dependencies for passes used in the capture_program. + + It returns a dictionary where the keys are pass classes and the values are lists of + dependencies required by each pass. This helps in managing and organizing the sequence + of passes needed for the capture_program to function correctly. + + Returns: + dict: A dictionary mapping each pass to its corresponding list of dependencies. + """ + from executorch.backends.qualcomm._passes import ( + AnnotateAndQuantScalar, + AnnotateDecomposed, + AnnotateQuantAttrs, + ConvertBmmToMatmul, + ConvertInterpolateWithUpsample2D, + ConvertPReLU, + ConvertToLinear, + ExpandBroadcastTensorShape, + FoldQDQ, + I64toI32, + LayoutTransform, + RecomposePixelUnshuffle, + RecomposeRmsNorm, + RemoveRedundancy, + ReplaceIndexPutInput, + ) + + return { + RecomposePixelUnshuffle: [RemoveRedundancy], + RecomposeRmsNorm: [RemoveRedundancy], + ConvertToLinear: [RecomposePixelUnshuffle], + ConvertPReLU: [RemoveRedundancy], + ConvertBmmToMatmul: [ConvertToLinear], + ConvertInterpolateWithUpsample2D: [RemoveRedundancy], + I64toI32: [RemoveRedundancy], + AnnotateQuantAttrs: [ + RecomposePixelUnshuffle, + RecomposeRmsNorm, + ConvertToLinear, + ConvertPReLU, + ConvertBmmToMatmul, + ConvertInterpolateWithUpsample2D, + ], + AnnotateAndQuantScalar: [ + AnnotateQuantAttrs, + ], + AnnotateDecomposed: [RemoveRedundancy], + FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed], + ExpandBroadcastTensorShape: [RemoveRedundancy], + LayoutTransform: [ + AnnotateQuantAttrs, + AnnotateAndQuantScalar, + ExpandBroadcastTensorShape, + ], + ReplaceIndexPutInput: [LayoutTransform], + } diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index ed77a87351..506bb92752 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -87,6 +87,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \ -DANDROID_ABI='arm64-v8a' \ -DANDROID_NATIVE_API_LEVEL=23 \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -B$BUILD_ROOT @@ -101,6 +102,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DANDROID_ABI='arm64-v8a' \ -DANDROID_NATIVE_API_LEVEL=23 \ -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -B$EXAMPLE_ROOT @@ -125,6 +127,7 @@ if [ 
"$BUILD_X86_64" = true ]; then -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index 1cc51690ff..4f73d331ad 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -26,8 +26,8 @@ QCOM_SCALE_OFFSET = "scale_offset" QCOM_ZERO_POINT = "zero_point" QCOM_ZERO_POINTS = "zero_points" -QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape" -QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant" +QCOM_PASS_ACTIVATE_KEY = "activate" +QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY = "args_kwargs_defaults" # constants in backends/qualcomm/tests QCOM_ANNOTATION = "annotation" diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index e15050fe4c..1bcfa3a6f6 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -3,13 +3,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +import inspect import operator import re import time import warnings from collections import OrderedDict -from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor @@ -46,6 +46,9 @@ from executorch.backends.qualcomm._passes.replace_index_put_input import ( ReplaceIndexPutInput, ) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) from executorch.backends.qualcomm.builders.node_visitor import ( QNN_QUANT_TYPE_MAP, @@ -74,8 +77,8 @@ option_to_flatbuffer, ) from executorch.backends.qualcomm.utils.constants import ( - QCOM_PASS_EXPAND_BROADCAST_SHAPE, - QCOM_PASS_SKIP_ADVANCED_REQUANT, + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, QCOM_QNN_COMPILE_SPEC, QCOM_QUANTIZED_IO, ) @@ -89,10 +92,12 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.capture import ExecutorchBackendConfig from executorch.exir.lowered_backend_module import LoweredBackendModule +from executorch.exir.passes import PassManager from executorch.exir.program._program import _get_updated_graph_signature from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes +from torch.fx.passes.infra.pass_manager import this_before_that_pass_constraint from torch.fx.passes.operator_support import OperatorSupportBase from torch.library import Library @@ -299,33 +304,87 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: return source_decompositions +def get_capture_program_passes(): + """ + Defines and returns the default ordered passes for the capture program. + This function creates an OrderedDict containing a series of default passes. + + Returns: + OrderedDict: An ordered dictionary containing all default passes along with their activation status and initialization parameters. + """ + + # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default. + # If a pass is activated, it will be executed by default. 
+ default_passes_and_setting = [ + (RemoveRedundancy, True), + (RecomposePixelUnshuffle, True), + (RecomposeRmsNorm, True), + (ConvertToLinear, True), + (ConvertPReLU, True), + (ConvertBmmToMatmul, True), + (ConvertInterpolateWithUpsample2D, True), + (I64toI32, True), + (AnnotateQuantAttrs, True), + (AnnotateAndQuantScalar, True), + (AnnotateDecomposed, True), + (FoldQDQ, True), + (ExpandBroadcastTensorShape, False), + (LayoutTransform, True), + (ReplaceIndexPutInput, True), + ] + + passes = OrderedDict() + for p, act in default_passes_and_setting: + init_signature = inspect.signature(p.__init__) + + args_kwargs_defaults = { + k: v.default if v.default is not inspect.Parameter.empty else None + for k, v in init_signature.parameters.items() + if k != "self" + } + + passes[p] = { + QCOM_PASS_ACTIVATE_KEY: act, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: args_kwargs_defaults, + } + + return passes + + +def _topological_sort_passes(passes: OrderedDict): + dep_table = get_passes_dependency_for_capture_program() + pm = PassManager() + for p in passes: + pm.add_pass(p) + + for that, these in dep_table.items(): + for this in these: + pm.add_constraint(this_before_that_pass_constraint(this, that)) + + pm.solve_constraints() + sorted_passes = OrderedDict() + for p in pm.passes: + sorted_passes[p] = passes[p] + return sorted_passes + + def _transform( - edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset() + edge_program: ExportedProgram, passes_job: OrderedDict = None ) -> ExportedProgram: # currently ExirExportedProgram.transform does not accept # changes of input number which was caused by FoldQDQ # apply passes one by one here to avoid IR capture failure graph_module = edge_program.graph_module - RemoveRedundancy()(graph_module) - RecomposePixelUnshuffle()(graph_module) - RecomposeRmsNorm()(graph_module) - ConvertToLinear()(graph_module) - ConvertPReLU(edge_program)(graph_module) - ConvertBmmToMatmul()(graph_module) - ConvertInterpolateWithUpsample2D()(graph_module) - I64toI32(edge_program)(graph_module) - AnnotateQuantAttrs( - edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config - )(graph_module) - AnnotateAndQuantScalar(edge_program)(graph_module) - AnnotateDecomposed(edge_program)(graph_module) - FoldQDQ()(graph_module) - # this pass is not necessary for network without layout-sensitive ops - # enable defaultly will introduce overhead from extra view_copy nodes - if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config: - ExpandBroadcastTensorShape()(graph_module) - LayoutTransform(edge_program)(graph_module) - ReplaceIndexPutInput(edge_program)(graph_module) + passes_job = passes_job if passes_job is not None else get_capture_program_passes() + passes_job = _topological_sort_passes(passes_job) + for p in passes_job: + if not passes_job[p][QCOM_PASS_ACTIVATE_KEY]: + continue + + kwargs = passes_job[p][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY] + if "edge_program" in kwargs: + kwargs["edge_program"] = edge_program + p(**kwargs)(graph_module) # Since QDQ nodes are stripped, update graph signature again to validate program edge_program._graph_signature = _get_updated_graph_signature( @@ -339,7 +398,7 @@ def _transform( def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], - custom_pass_config: FrozenSet[str] = frozenset(), + passes_job: OrderedDict = None, ) -> exir.ExirExportedProgram: ep = torch.export.export(module, inputs, strict=True) decomposed_ep = ep.run_decompositions(get_decomp_table()) @@ -350,7 +409,8 @@ def capture_program( core_ep = 
ExirExportedProgram(decomposed_ep, False) core_ep.transform(ConvertBinaryOpsWithScalar()) edge_ep = core_ep.to_edge(qnn_edge_config()) - _transform(edge_ep.exported_program, custom_pass_config) + + _transform(edge_ep.exported_program, passes_job) return edge_ep @@ -906,28 +966,34 @@ def generate_multi_graph_program( def generate_composite_llama_program( + llama_model: torch.nn.Module, graph_names: List[str], sample_inputs_list: List[Tuple[Any]], lower_module_dict: Dict[str, List[LoweredBackendModule]], call_delegate_node_name_dict: Dict[str, List[str]], call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]], outputs_dict: Dict[str, List[Tuple[str, int]]], + embedding_quantize: str, backend_config: ExecutorchBackendConfig = None, constant_methods: Optional[Dict[str, Any]] = None, ) -> ExecutorchProgramManager: class CompositeLlamaModule(torch.nn.Module): def __init__( self, + llama_model, lower_module_list, call_delegate_node_name_list, call_delegate_inputs_list, outputs_list, + embedding_quantize, ) -> None: super().__init__() + self.llama_model = llama_model self.lower_module_list = lower_module_list self.call_delegate_node_name_list = call_delegate_node_name_list self.call_delegate_inputs_list = call_delegate_inputs_list self.outputs_list = outputs_list + self.embedding_quantize = embedding_quantize def reorder( self, @@ -960,6 +1026,13 @@ def forward( } for num, arg in enumerate(args): module_input_dict[f"args_{num}"] = arg + + if self.embedding_quantize: + hidden_states = self.llama_model.tok_embeddings(tokens) + module_input_dict["quantized_decomposed_embedding_4bit_dtype"] = ( + hidden_states + ) + for lower_module, call_delegate_node_name, call_delegate_inputs in zip( self.lower_module_list, self.call_delegate_node_name_list, @@ -976,10 +1049,12 @@ def forward( progs_dict = {} for graph_name, sample_inputs in zip(graph_names, sample_inputs_list): composite_llama_module = CompositeLlamaModule( + llama_model, lower_module_dict[graph_name], call_delegate_node_name_dict[graph_name], call_delegate_inputs_dict[graph_name], outputs_dict[graph_name], + embedding_quantize, ) prog = torch.export.export(composite_llama_module, sample_inputs) progs_dict[graph_name] = prog diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 0e2c695ab3..f0d2f4c3f0 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -10,6 +10,9 @@ import numpy as np import torch +from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import ( + ExpandBroadcastTensorShape, +) from executorch.backends.qualcomm.quantizer.annotators import ( QuantizationConfig, QuantizationSpec, @@ -23,10 +26,11 @@ ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.utils.constants import ( - QCOM_PASS_EXPAND_BROADCAST_SHAPE, +from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY +from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + get_capture_program_passes, ) -from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d from executorch.examples.qualcomm.utils import ( build_executorch_binary, get_imagenet_dataset, @@ -111,6 +115,8 @@ def main(args): bias=q_config.bias, ) # lower to QNN + passes_job = get_capture_program_passes() + passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True build_executorch_binary( convert_linear_to_conv2d(get_instance(args.oss_repo, 
args.pretrained_weight)), inputs[0], @@ -121,7 +127,7 @@ def main(args): skip_node_op_set=skip_node_op_set, quant_dtype=QuantDtype.use_8a8w, custom_quantizer=quantizer, - custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE}, + passes_job=passes_job, shared_buffer=args.shared_buffer, ) diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index d7d355ee4d..4059ae7151 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -45,6 +45,8 @@ target_include_directories( qnn_llama_runner PUBLIC ${_common_include_directories} ) +target_link_options_shared_lib(quantized_ops_lib) + target_link_libraries( qnn_llama_runner qnn_executorch_backend @@ -55,6 +57,8 @@ target_link_libraries( gflags re2::re2 custom_ops + quantized_ops_lib + quantized_kernels ) target_compile_options( qnn_llama_runner PUBLIC ${_common_compile_options} diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index e80e0c2808..09f9ce4444 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -14,10 +14,12 @@ import os import sys import time +from collections import OrderedDict from functools import partial from multiprocessing.connection import Client import torch +from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner @@ -29,7 +31,15 @@ from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO + +from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, + option_to_flatbuffer, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANTIZED_IO, +) from executorch.backends.qualcomm.utils.utils import ( capture_program, convert_linear_to_conv2d, @@ -37,9 +47,13 @@ generate_htp_compiler_spec, generate_multi_graph_program, generate_qnn_executorch_compiler_spec, + get_capture_program_passes, get_soc_to_chipset_map, update_spill_fill_size, ) +from executorch.examples.models.llama.source_transformation.quantize import ( + get_quant_embedding_transform, +) from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( LlamaModel, @@ -108,6 +122,7 @@ def _kv_calibrate( tokenizer, max_seq_len=512, updator=smart_mask_updator, + use_i64_token=False, ): _, atten_mask, _, k_caches, v_caches = example_inputs @@ -128,8 +143,10 @@ def _kv_calibrate( with torch.no_grad(): while token_list[-1] != tokenizer.eos_id and pos < max_cache_len: + dtype = torch.int64 if use_i64_token else torch.int32 + token = torch.full((1, 1), token_list[pos], dtype=dtype) logits, new_k_caches, new_v_caches = module( - torch.full((1, 1), token_list[pos], dtype=torch.int32), + token, atten_mask, torch.full((1, 1), pos), *k_caches, @@ -150,6 +167,7 @@ def _prefill_calibrate( module: torch.fx.GraphModule, tokenizer, max_seq_len=512, + use_i64_token=False, ): _, atten_mask = example_inputs max_cache_len = max_seq_len - 1 @@ -168,15 +186,16 @@ def _prefill_calibrate( raise RuntimeError("Unkown tokenizer") pos = len(token_list) + dtype = torch.int64 if use_i64_token else torch.int32 with 
torch.no_grad(): while token_list[-1] != tokenizer.eos_id and pos < max_cache_len: - tmp_token_list = torch.tensor(token_list).reshape(1, -1) + tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1) if pos < max_cache_len: tmp_token_list = torch.cat( [ tmp_token_list, - torch.zeros((1, max_cache_len - pos), dtype=torch.int32), + torch.zeros((1, max_cache_len - pos), dtype=dtype), ], dim=1, ) @@ -197,6 +216,7 @@ def calibrate( tokenizer, max_seq_len=512, kv_updator=smart_mask_updator, + use_i64_token=False, ): if len(example_inputs) == 2: _prefill_calibrate( @@ -205,6 +225,7 @@ def calibrate( module, tokenizer, max_seq_len, + use_i64_token, ) elif len(example_inputs) == 5: _kv_calibrate( @@ -214,6 +235,7 @@ def calibrate( tokenizer, max_seq_len, updator=kv_updator, + use_i64_token=use_i64_token, ) else: raise RuntimeError("Get wrong inputs") @@ -235,6 +257,7 @@ def __init__(self, llama_model, pte_filename) -> None: else: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) + self.llama_graph_module = llama_model def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): if not self.has_quant_io: @@ -340,7 +363,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): with torch.no_grad(): fx_graph_module = torch.export.export( - self.llama_model, self.inputs, strict=True + self.llama_graph_module, self.inputs, strict=True ).module() fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) @@ -352,9 +375,10 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): tokenizer=tokenizer, max_seq_len=self.llama_meta["get_max_seq_len"], kv_updator=args.kv_updator, + use_i64_token=args.embedding_quantize is not None, ) - self.llama_model = convert_pt2e(fx_graph_module) + self.llama_graph_module = convert_pt2e(fx_graph_module) def lowering_modules( self, @@ -362,7 +386,8 @@ def lowering_modules( fixed_point_type, use_fp16=False, soc_model=QcomChipset.SM8650, - num_sharding=0, + num_sharding=1, + passes_job=OrderedDict(), shared_buffer=False, ): executorch_config = ExecutorchBackendConfig( @@ -379,7 +404,7 @@ def lowering_modules( with torch.no_grad(): # backend option backend_options = generate_htp_compiler_spec( - use_fp16=use_fp16, use_multi_contexts=num_sharding > 0 + use_fp16=use_fp16, use_multi_contexts=num_sharding > 1 ) compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=soc_model, @@ -391,10 +416,12 @@ def lowering_modules( compiler_specs, skip_node_op_set=skip_node_op_set ) edge_prog = capture_program( - self.llama_model, self.inputs, custom_pass_config=frozenset() + self.llama_graph_module, + self.inputs, + passes_job, ) - if num_sharding > 0: + if num_sharding > 1: model_sharding.split_graph( edge_prog.exported_program, self.llama_meta["get_n_layers"], @@ -411,7 +438,7 @@ def lowering_modules( compile_config=EdgeCompileConfig(_check_ir_validity=False), ) edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) - if num_sharding > 0: + if num_sharding > 1: update_spill_fill_size(edge_prog_mgr.exported_program()) exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: @@ -444,21 +471,34 @@ def compile(args, pte_filename, tokenizer): ) llama_instance_list = [] + use_i64_token = args.embedding_quantize is not None with torch.device("meta"): if args.model_mode == "kv": llama_instance_list.append( - LlamaModel(kv_config, output_new_cache_only=True) + LlamaModel( + kv_config, 
output_new_cache_only=True, use_i64_token=use_i64_token + ) ) elif args.model_mode == "prefill": llama_instance_list.append( - LlamaModel(prefill_config, output_new_cache_only=False) + LlamaModel( + prefill_config, + output_new_cache_only=False, + use_i64_token=use_i64_token, + ) ) elif args.model_mode == "hybrid": llama_instance_list.append( - LlamaModel(kv_config, output_new_cache_only=True) + LlamaModel( + kv_config, output_new_cache_only=True, use_i64_token=use_i64_token + ) ) llama_instance_list.append( - LlamaModel(prefill_config, output_new_cache_only=False) + LlamaModel( + prefill_config, + output_new_cache_only=False, + use_i64_token=use_i64_token, + ) ) else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") @@ -500,6 +540,7 @@ def compile(args, pte_filename, tokenizer): assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + passes_job = get_capture_program_passes() if args.dtype_override is not None: dtype_override = DType[args.dtype_override] for i in range(len(llama_instance_list)): @@ -508,6 +549,13 @@ def compile(args, pte_filename, tokenizer): ) for i in range(len(llama_instance_list)): + if args.embedding_quantize: + llama_instance_list[i] = get_quant_embedding_transform(args)( + llama_instance_list[i] + ) + passes_job[I64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY]["skip_node"] = { + "tokens" + } llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) llama_instance_list[i] = SingleLlama( llama_instance_list[i].eval(), pte_filename @@ -532,7 +580,7 @@ def compile(args, pte_filename, tokenizer): # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later if i == 0 and args.model_mode == "hybrid": output_indices = 0 - for node in llama_instance.llama_model.graph.nodes: + for node in llama_instance.llama_graph_module.graph.nodes: if node.op == "output": for output in node.args[0]: kv_quant_attrs[output_indices] = output.args[1:] @@ -550,28 +598,33 @@ def compile(args, pte_filename, tokenizer): start_lowering_ts = time.time() quant_attrs = None - if len(llama_instance_list) == 1: + if args.model_mode in ["kv", "prefill"]: llama_instance_list[0].lowering_modules( args.artifact, fixed_point_type, use_fp16=use_fp16, soc_model=get_soc_to_chipset_map()[args.model], num_sharding=args.num_sharding, + passes_job=passes_job, shared_buffer=args.shared_buffer, ) quant_attrs = llama_instance_list[0].get_quant_attrs() - else: + elif args.model_mode == "hybrid": sample_inputs_list = [ llama_instace.inputs for llama_instace in llama_instance_list ] edge_progs = [ - capture_program(llama_instance.llama_model, sample_input) + capture_program( + llama_instance.llama_graph_module, + sample_input, + passes_job=passes_job, + ) for llama_instance, sample_input in zip( llama_instance_list, sample_inputs_list ) ] - if args.num_sharding > 0: + if args.num_sharding > 1: for i in range(len(llama_instance_list)): model_sharding.split_graph( edge_progs[i].exported_program, @@ -585,7 +638,7 @@ def compile(args, pte_filename, tokenizer): fixed_point_type, ) backend_options = generate_htp_compiler_spec( - use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 0 + use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1 ) graph_names = ["kv_forward", "prefill_forward"] compiler_specs = [ @@ -606,9 +659,13 @@ def compile(args, pte_filename, tokenizer): ) for i, edge_prog in enumerate(edge_progs) ] - if args.num_sharding > 0: - for exported_program in exported_programs: - update_spill_fill_size(exported_program) + 
if args.num_sharding > 1: + max_sf_size = update_spill_fill_size(exported_programs) + qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value) + qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = ( + max_sf_size + ) + compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options) executorch_config = ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -622,6 +679,7 @@ def compile(args, pte_filename, tokenizer): extract_delegate_segments=True, ) + bundle_progs_list = [] lower_module_dict = {name: [] for name in graph_names} call_delegate_inputs_dict = {name: [] for name in graph_names} call_delegate_node_name_dict = {name: [] for name in graph_names} @@ -637,11 +695,17 @@ def compile(args, pte_filename, tokenizer): call_delegate_inputs_list = [] for arg in node.args: if arg.op == "call_function": - while "getitem" not in arg.name: - arg = arg.args[0] - call_delegate_inputs_list.append( - (arg.args[0].name, arg.args[1]) - ) + if ( + arg.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + ): + call_delegate_inputs_list.append((arg.name, None)) + else: + while "getitem" not in arg.name: + arg = arg.args[0] + call_delegate_inputs_list.append( + (arg.args[0].name, arg.args[1]) + ) elif arg.op == "placeholder": call_delegate_inputs_list.append((arg.name, None)) # No extra needs to do for get_attr node @@ -651,88 +715,52 @@ def compile(args, pte_filename, tokenizer): elif node.op == "output": for arg in node.args[0]: outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) - - if args.num_sharding > 0: - bundle_progs_list = [] - for num in range(args.num_sharding - 1, -1, -1): - processed_bytes = [] - for prog, graph_name in zip(exported_programs, graph_names): - processed_bytes.append( - getattr( - prog.graph_module, f"lowered_module_{num}" - ).processed_bytes - ) - - call_delegate_node = [ - list(node.users.keys())[0] - for node in prog.graph_module.graph.nodes - if node.op == "get_attr" - and node.name == f"lowered_module_{num}" - ] - input_nodes_dict[graph_name] = [ - node - for node in call_delegate_node[0].args - if node.op == "placeholder" - ] - - prog_mgr, bundle_progs = generate_multi_graph_program( - compiler_specs=compiler_specs[0], - processed_bytes=processed_bytes, - input_nodes_dict=input_nodes_dict, - backend_config=executorch_config, - constant_methods=llama_instance_list[ - 1 - ].llama_meta, # kv method meta - ) - bundle_progs_list.append(bundle_progs) - for graph_name in graph_names: - lower_module_dict[graph_name].append( - prog_mgr.exported_program(graph_name).graph_module._modules.get( - "lowered_module_0" - ) - ) - - exec_prog = generate_composite_llama_program( - graph_names=graph_names, - sample_inputs_list=sample_inputs_list, - lower_module_dict=lower_module_dict, - call_delegate_node_name_dict=call_delegate_node_name_dict, - call_delegate_inputs_dict=call_delegate_inputs_dict, - outputs_dict=outputs_dict, - backend_config=executorch_config, - constant_methods=llama_instance_list[0].llama_meta, # kv method meta - ) - with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: - exec_prog.write_to_file(file) - else: + for num in range(args.num_sharding - 1, -1, -1): processed_bytes = [] - input_nodes_dict = {name: [] for name in graph_names} - output_nodes_dict = {name: [] for name in graph_names} for prog, graph_name in zip(exported_programs, graph_names): processed_bytes.append( - prog.graph_module.lowered_module_0.processed_bytes + getattr(prog.graph_module, 
f"lowered_module_{num}").processed_bytes ) - input_nodes_dict[graph_name] = [ - node + call_delegate_node = [ + list(node.users.keys())[0] for node in prog.graph_module.graph.nodes - if node.op == "placeholder" + if node.op == "get_attr" and node.name == f"lowered_module_{num}" ] - output_nodes_dict[graph_name] = [ + input_nodes_dict[graph_name] = [ node - for node in prog.graph_module.graph.nodes - if node.op == "output" + for node in call_delegate_node[0].args + if node.op == "placeholder" + or node.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype ] - - prog_mgr, _ = generate_multi_graph_program( + prog_mgr, bundle_progs = generate_multi_graph_program( compiler_specs=compiler_specs[0], processed_bytes=processed_bytes, input_nodes_dict=input_nodes_dict, - output_nodes_dict=output_nodes_dict, backend_config=executorch_config, constant_methods=llama_instance_list[0].llama_meta, # kv method meta ) - with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: - prog_mgr.write_to_file(file) + bundle_progs_list.append(bundle_progs) + for graph_name in graph_names: + lower_module_dict[graph_name].append( + prog_mgr.exported_program(graph_name).graph_module._modules.get( + "lowered_module_0" + ) + ) + exec_prog = generate_composite_llama_program( + llama_model=llama_instance_list[1].llama_model, + graph_names=graph_names, + sample_inputs_list=sample_inputs_list, + lower_module_dict=lower_module_dict, + call_delegate_node_name_dict=call_delegate_node_name_dict, + call_delegate_inputs_dict=call_delegate_inputs_dict, + outputs_dict=outputs_dict, + embedding_quantize=args.embedding_quantize, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + exec_prog.write_to_file(file) end_lowering_ts = time.time() logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}") @@ -910,7 +938,7 @@ def main(): parser.add_argument( "--num_sharding", type=int, - default=0, + default=1, help="Specify the number of splits by inserting the fallback custom op. 
The graph will be split evenly by layers.",
    )
@@ -944,6 +972,14 @@
        type=str,
    )

+    parser.add_argument(
+        "-E",
+        "--embedding-quantize",
+        default=None,
+        type=str,
+        help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
+    )
+
    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index d1b618ed07..253abc9578 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -12,10 +12,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from executorch.examples.models.llama.llama_transformer import (
-    ModelArgs,
-    precompute_freqs_cis,
-)
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import precompute_freqs_cis

def apply_rotary_emb_single(
@@ -299,7 +297,9 @@ def forward(
class LlamaModel(nn.Module):
-    def __init__(self, config: ModelArgs, output_new_cache_only=True):
+    def __init__(
+        self, config: ModelArgs, output_new_cache_only=True, use_i64_token=False
+    ):
        super().__init__()
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
@@ -312,6 +312,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
        self.rope_freq_base = config.rope_freq_base
        self.use_kv_cache = config.use_kv_cache
        self.output_new_cache_only = output_new_cache_only
+        self.use_i64_token = use_i64_token

        self.layers = nn.ModuleList(
            [
@@ -390,10 +391,12 @@ def forward(
        return logits, output_k_cache, output_v_cache

    def get_example_inputs(self, use_kv_cache=True):
+        dtype = torch.int64 if self.use_i64_token else torch.int32
        if use_kv_cache:
            tokens = torch.randint(
-                self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
+                self.vocab_size, (self.max_batch_size, 1), dtype=dtype
            )
+
            pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
            k_cache, v_cache = [], []
            atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
@@ -424,7 +427,7 @@ def get_example_inputs(self, use_kv_cache=True):
            )

        max_promp = self.max_seq_len - 1
-        tokens = torch.arange(0, max_promp, 1, dtype=torch.int32).unsqueeze(0)
+        tokens = torch.arange(0, max_promp, 1, dtype=dtype).unsqueeze(0)
        atten_mask = torch.triu(torch.rand((max_promp, max_promp)), 1)
        atten_mask[atten_mask != 0] = -255
        return (
diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
index b2fcef9149..7992913a58 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
@@ -62,7 +62,8 @@ ShiftPointerIoMgr::ShiftPointerIoMgr(
    int32_t num_heads,
    EvalMode eval_mode,
    const std::string& prefill_forward_name,
-    const std::string& kv_forward_name)
+    const std::string& kv_forward_name,
+    const bool use_int64_token)
    : IoMgrBase(modules),
      shard_layers_({num_layers}),
      kv_cache_len_(kv_cache_len),
@@ -73,7 +74,8 @@ ShiftPointerIoMgr::ShiftPointerIoMgr(
      num_heads_(num_heads),
      eval_mode_(eval_mode),
      prefill_forward_name_(prefill_forward_name),
-      kv_forward_name_(kv_forward_name) {
+      kv_forward_name_(kv_forward_name),
+      use_int64_token_(use_int64_token) {
  if (!prefill_forward_name_.empty()) {
    input_tensors_[prefill_forward_name_] =
        std::vector>(modules.size());
@@ -399,7 +401,8 @@ void ShiftPointerIoMgr::update_prefill_to_kv_io(
      prefill_cache_len_ != 0, "prefill_cache_len_ should not equal to 0");
  IO* ptr = static_cast<IO*>(data_ptr_.get());

-  ptr->input_tok = static_cast<int32_t>(cur_token);
+  ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  ptr->input_pos = static_cast<int32_t>(pos);
  // If prompt len is 30, prefill will handle to pos = 30.
  // At this point, pos should be 31.
@@ -455,7 +458,8 @@ void ShiftPointerIoMgr::update_kv_io(
    std::vector>& output_tensors) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  // update input_tok
-  ptr->input_tok = static_cast<int32_t>(cur_token);
+  ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  // update position_ids
  ptr->input_pos = static_cast<int32_t>(pos);
  // update causal mask for next token
@@ -503,20 +507,39 @@ void ShiftPointerIoMgr::update_prefill_io(
    std::vector>& output_tensors) {
  (void)output_tensors;
  IO* ptr = static_cast<IO*>(data_ptr_.get());
-  ptr->prefill_input_toks[pos] = static_cast<int32_t>(cur_token);
+  // Support CPU 4-bit embedding, which requires int64 input.
+  // However, for QNN embedding, only int32 input is needed.
+  // Therefore, we need to cast to the correct type to write the data.
+  if (use_int64_token_) {
+    ptr->prefill_input_toks[pos] = cur_token;
+  } else {
+    int32_t* prefill_input_toks_ptr =
+        reinterpret_cast<int32_t*>(ptr->prefill_input_toks.data());
+    prefill_input_toks_ptr[pos] = static_cast<int32_t>(cur_token);
+  }
}

void ShiftPointerIoMgr::fill_prefill_toks(
    std::vector& prompt_tokens) {
  IO* ptr = static_cast<IO*>(get_mutable_ptr());
  for (int i = 0; i < prompt_tokens.size(); i++) {
-    ptr->prefill_input_toks[i] = static_cast<int32_t>(prompt_tokens[i]);
+    // Support CPU 4-bit embedding, which requires int64 input.
+    // However, for QNN embedding, only int32 input is needed.
+    // Therefore, we need to cast to the correct type to write the data.
+    if (use_int64_token_) {
+      ptr->prefill_input_toks[i] = prompt_tokens[i];
+    } else {
+      int32_t* prefill_input_toks_ptr =
+          reinterpret_cast<int32_t*>(ptr->prefill_input_toks.data());
+      prefill_input_toks_ptr[i] = static_cast<int32_t>(prompt_tokens[i]);
+    }
  }
}

void ShiftPointerIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) {
  IO* ptr = static_cast<IO*>(get_mutable_ptr());
-  ptr->input_tok = static_cast<int32_t>(cur_token);
+  ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  ptr->kv_attention_mask[kv_cache_len_] = 65535;
}

@@ -530,7 +553,8 @@ SmartMaskIoMgr::SmartMaskIoMgr(
    int32_t num_heads,
    EvalMode eval_mode,
    const std::string& prefill_forward_name,
-    const std::string& kv_forward_name)
+    const std::string& kv_forward_name,
+    const bool use_int64_token)
    : IoMgrBase(modules),
      shard_layers_({num_layers}),
      prefill_cache_len_(prefill_cache_len),
@@ -541,7 +565,8 @@ SmartMaskIoMgr::SmartMaskIoMgr(
      num_heads_(num_heads),
      eval_mode_(eval_mode),
      prefill_forward_name_(prefill_forward_name),
-      kv_forward_name_(kv_forward_name) {
+      kv_forward_name_(kv_forward_name),
+      use_int64_token_(use_int64_token) {
  if (!prefill_forward_name_.empty()) {
    input_tensors_[prefill_forward_name_] =
        std::vector>(modules.size());
@@ -630,7 +655,7 @@ void SmartMaskIoMgr::IO::init_io_ptrs(
    std::string key = iter.first;
    size_t size = iter.second;
    if (key == "input_tok_bytes") {
-      input_tok = reinterpret_cast<int32_t*>(cur_ptr);
+      input_tok = reinterpret_cast<int64_t*>(cur_ptr);
    } else if (key == "input_pos_bytes") {
      input_pos = reinterpret_cast<int32_t*>(cur_ptr);
    } else if (key == "cache_in_bytes" || key == "cache_out_bytes") {
@@ -659,7 +684,7 @@ void SmartMaskIoMgr::IO::init_io_ptrs(
    } else if (key == "kv_logits_bytes") {
      kv_logits = reinterpret_cast<uint16_t*>(cur_ptr);
    } else if (key == "prefill_input_toks_bytes") {
-      prefill_input_toks = reinterpret_cast<int32_t*>(cur_ptr);
+      prefill_input_toks = reinterpret_cast<int64_t*>(cur_ptr);
    } else if (key == "prefill_atten_mask_bytes") {
      prefill_atten_mask = reinterpret_cast<uint16_t*>(cur_ptr);
    } else if (key == "prefill_logits_bytes") {
@@ -681,7 +706,7 @@ void SmartMaskIoMgr::IO::add_custom_mem_info(
    executorch::runtime::TensorInfo& tensor_info) {
  if (auto it = io_pos_map.find(static_cast(ptr));
      it == io_pos_map.end()) {
-    ET_LOG(Error, "Shared buffer pointer %p is not found %p", ptr);
+    ET_LOG(Error, "Shared buffer pointer %p is not found", ptr);
  }
  size_t pos = io_pos_map[static_cast(ptr)];
  uint32_t rank = tensor_info.sizes().size();
@@ -890,7 +915,8 @@ void SmartMaskIoMgr::update_kv_io(
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  size_t cache_len = std::max(kv_cache_len_, prefill_cache_len_);
  // update input_tok
-  *ptr->input_tok = static_cast<int32_t>(cur_token);
+  *ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  // update position_ids
  *ptr->input_pos = static_cast<int32_t>(pos);
  // update smart mask for previous cache
@@ -1033,7 +1059,8 @@ void SmartMaskIoMgr::update_prefill_to_kv_io(
    std::vector>& output_tensors) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());

-  *ptr->input_tok = static_cast<int32_t>(cur_token);
+  *ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  *ptr->input_pos = static_cast<int32_t>(pos);
  // pos means the cur_token pos
  for (int i = 0; i < pos; i++) {
@@ -1061,19 +1088,38 @@ void SmartMaskIoMgr::update_prefill_io(
    std::vector>& output_tensors) {
  (void)output_tensors;
  IO* ptr = static_cast<IO*>(data_ptr_.get());
-  ptr->prefill_input_toks[pos] = static_cast<int32_t>(cur_token);
+  // Support CPU 4-bit embedding, which requires int64 input.
+  // However, for QNN embedding, only int32 input is needed.
+  // Therefore, we need to cast to the correct type to write the data.
+  if (use_int64_token_) {
+    ptr->prefill_input_toks[pos] = cur_token;
+  } else {
+    int32_t* prefill_input_toks_ptr =
+        reinterpret_cast<int32_t*>(ptr->prefill_input_toks);
+    prefill_input_toks_ptr[pos] = static_cast<int32_t>(cur_token);
+  }
}

void SmartMaskIoMgr::fill_prefill_toks(std::vector& prompt_tokens) {
  IO* ptr = static_cast<IO*>(get_mutable_ptr());
  for (int i = 0; i < prompt_tokens.size(); i++) {
-    ptr->prefill_input_toks[i] = static_cast<int32_t>(prompt_tokens[i]);
+    // Support CPU 4-bit embedding, which requires int64 input.
+    // However, for QNN embedding, only int32 input is needed.
+    // Therefore, we need to cast to the correct type to write the data.
+    if (use_int64_token_) {
+      ptr->prefill_input_toks[i] = prompt_tokens[i];
+    } else {
+      int32_t* prefill_input_toks_ptr =
+          reinterpret_cast<int32_t*>(ptr->prefill_input_toks);
+      prefill_input_toks_ptr[i] = static_cast<int32_t>(prompt_tokens[i]);
+    }
  }
}

void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) {
  IO* ptr = static_cast<IO*>(get_mutable_ptr());
-  *ptr->input_tok = static_cast<int32_t>(cur_token);
+  *ptr->input_tok =
+      use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
  ptr->kv_attention_mask[kv_cache_len_] = 65535;
}

diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h
index e86b2eab87..3a59ab6924 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h
@@ -89,7 +89,8 @@ class ShiftPointerIoMgr : public IoMgrBase {
      int32_t num_heads,
      EvalMode eval_mode,
      const std::string& prefill_forward_name,
-      const std::string& kv_forward_name);
+      const std::string& kv_forward_name,
+      const bool use_int64_token);

  void init_io() override;
  void prepare_prefill_io(
@@ -118,14 +119,14 @@ class ShiftPointerIoMgr : public IoMgrBase {
      std::vector>& output_tensors) override;
  struct IO {
-    int32_t input_tok;
+    int64_t input_tok;
    int32_t input_pos;
    std::vector>> k_cache;
    std::vector> v_cache;
    std::vector> k_cache_out;
    std::vector kv_attention_mask;
    std::vector kv_logits;
-    std::vector<int32_t> prefill_input_toks;
+    std::vector<int64_t> prefill_input_toks;
    std::vector prefill_atten_mask;
    std::vector prefill_logits;
  };
@@ -165,6 +166,7 @@ class ShiftPointerIoMgr : public IoMgrBase {
  EvalMode eval_mode_;
  std::string prefill_forward_name_;
  std::string kv_forward_name_;
+  const bool use_int64_token_{false};
};

class SmartMaskIoMgr : public IoMgrBase {
@@ -179,7 +181,8 @@ class SmartMaskIoMgr : public IoMgrBase {
      int32_t num_heads,
      EvalMode eval_mode,
      const std::string& prefill_forward_name,
-      const std::string& kv_forward_name);
+      const std::string& kv_forward_name,
+      const bool use_int64_token);

  void init_io() override;
  void prepare_prefill_io(
@@ -213,7 +216,7 @@ class SmartMaskIoMgr : public IoMgrBase {
  struct IO {
    void* shared_buffer_base;
-    int32_t* input_tok;
+    int64_t* input_tok;
    int32_t* input_pos;
    // layer -> head -> head_dim * seq_len
    std::vector> k_cache;
@@ -225,7 +228,7 @@ class SmartMaskIoMgr : public IoMgrBase {
    uint16_t* kv_attention_mask;
    // vocab_size
    uint16_t* kv_logits;
-    int32_t* prefill_input_toks;
+    int64_t* prefill_input_toks;
    // prefill_cache_len_ ^ 2
    uint16_t* prefill_atten_mask;
    // vocab_size * prefill_cache_len_
@@ -283,6 +286,7 @@ class SmartMaskIoMgr : public IoMgrBase {
  EvalMode eval_mode_;
  std::string prefill_forward_name_;
  std::string kv_forward_name_;
+  const bool use_int64_token_{false};
};

} // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 158c6a13ca..4b45863147 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -125,6 +125,8 @@ Error Runner::load() {
  int64_t head_dim = method_meta.output_tensor_meta(1)->sizes()[1]; // k_cache
  int64_t num_heads = (method_meta.num_outputs() - 1) / (num_layers * 2);
  vocab_size_ = method_meta.output_tensor_meta(0)->sizes()[2]; // logit_tensor
+  use_int64_token_ = method_meta.input_tensor_meta(0)->scalar_type() ==
+      executorch::aten::ScalarType::Long;
  ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers");

  if (kv_updator_ == "SmartMask") {
@@ -138,7 +140,8 @@
        num_heads,
        eval_mode_,
        prefill_forward_name_,
-        kv_forward_name_);
+        kv_forward_name_,
+        use_int64_token_);
  } else if (kv_updator_ == "ShiftPointer") {
    io_mgr_ = std::make_unique<ShiftPointerIoMgr>(
        modules_,
@@ -150,7 +153,8 @@
        num_heads,
        eval_mode_,
        prefill_forward_name_,
-        kv_forward_name_);
+        kv_forward_name_,
+        use_int64_token_);
  } else {
    ET_LOG(Error, "Using an unknown updator %s", kv_updator_.c_str());
  }
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 844ea32290..b6ba1360bf 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -106,6 +106,7 @@ class Runner {
  Stats stats_;
  std::unique_ptr<IoMgrBase> io_mgr_;
  EvalMode eval_mode_;
+  bool use_int64_token_{false};
  std::string prefill_forward_name_;
  std::string kv_forward_name_;
  std::vector<std::string> method_names_;
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 23e384dee1..1ba15969e0 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -256,7 +256,7 @@ def build_executorch_binary(
    shared_buffer=False,
    metadata=None,
    dump_intermediate_outputs=False,
-    custom_pass_config=frozenset(),
+    passes_job=None,
    qat_training_data=None,
):
    """
@@ -296,9 +296,9 @@ def build_executorch_binary(
        annotated_model = ptq_calibrate(captured_model, quantizer, dataset)
        quantized_model = convert_pt2e(annotated_model)

-        edge_prog = capture_program(quantized_model, inputs, custom_pass_config)
+        edge_prog = capture_program(quantized_model, inputs, passes_job)
    else:
-        edge_prog = capture_program(model, inputs, custom_pass_config)
+        edge_prog = capture_program(model, inputs, passes_job)

    backend_options = generate_htp_compiler_spec(
        use_fp16=False if quant_dtype else True
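The example-script changes above (fastvit.py, llama.py, utils.py) all follow the same recipe for the reworked pass interface. Below is a minimal sketch, not part of the patch, of how a caller is expected to drive it; `model` and `example_inputs` are placeholder names for whatever module and sample inputs the caller already has.

```python
from executorch.backends.qualcomm._passes import ExpandBroadcastTensorShape, I64toI32
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_ACTIVATE_KEY,
    QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
)
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    get_capture_program_passes,
)

# Start from the default pass table introduced by this patch.
passes_job = get_capture_program_passes()

# ExpandBroadcastTensorShape ships deactivated; opt in per model (as fastvit.py does).
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True

# Constructor kwargs live under QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, e.g. keep the
# "tokens" placeholder in int64 when the CPU 4-bit embedding is used (as llama.py does).
passes_job[I64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY]["skip_node"] = {"tokens"}

# capture_program topologically sorts the activated passes before running them;
# model and example_inputs are assumed to exist in the caller's scope.
edge_prog = capture_program(model, example_inputs, passes_job)
```

Deactivated passes stay in the table, so their relative order is still resolved against the dependency map in `_passes/utils.py`; only the `QCOM_PASS_ACTIVATE_KEY` flag decides whether they actually run.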
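The ordering itself is delegated to torch.fx's pass-manager constraints rather than a hand-rolled scheduler. A rough, self-contained illustration of the mechanism `_topological_sort_passes` builds on, using two stand-in passes rather than the real Qualcomm ones:

```python
from torch.fx.passes.infra.pass_manager import (
    PassManager,
    this_before_that_pass_constraint,
)


def remove_redundancy(graph_module):  # stand-in for RemoveRedundancy
    return graph_module


def layout_transform(graph_module):  # stand-in for LayoutTransform
    return graph_module


pm = PassManager()
# Register in the "wrong" order on purpose.
pm.add_pass(layout_transform)
pm.add_pass(remove_redundancy)
# One dependency-table entry becomes one constraint: this must run before that.
pm.add_constraint(this_before_that_pass_constraint(remove_redundancy, layout_transform))
pm.solve_constraints()
# remove_redundancy now precedes layout_transform in pm.passes.
print([p.__name__ for p in pm.passes])
```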