From dbd9a7c2f5bfd07c97c58a150bd9e0a270fe2f8e Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Fri, 21 Jun 2024 14:10:36 +0800
Subject: [PATCH 1/5] Qualcomm AI Engine Direct - Enable llama model in quantized and fp

Summary:
- Fully delegate Meta llama model in fp and quantized modes
- Add simple calibration
- Use custom fallback op to split the graph
- Add model sharding argument
- Add spill-fill feature

Note that, due to memory limitations on the device, you need to specify num_sharding if you want to run llama 7b. It is also recommended to reboot the device before running to ensure it has enough memory.

---
 .../qualcomm/partition/qnn_partitioner.py | 17 +--
 backends/qualcomm/passes/layout_transform.py | 2 +
 examples/models/llama2/export_llama_lib.py | 41 ++++++-
 examples/models/llama2/llama_transformer.py | 3 +
 examples/models/llama2/model.py | 10 +-
 examples/models/llama2/runner/runner.cpp | 24 ++--
 examples/models/llama2/runner/runner.h | 2 +-
 .../llama2/source_transformation/sdpa.py | 105 +++++++++++++++++-
 extension/llm/custom_ops/model_sharding.py | 93 ++++++++++++++++
 extension/llm/custom_ops/op_fallback.cpp | 48 ++++++++
 extension/llm/custom_ops/op_fallback.h | 20 ++++
 extension/llm/custom_ops/targets.bzl | 4 +-
 extension/llm/export/builder.py | 31 +++++-
 extension/llm/export/partitioner_lib.py | 26 ++---
 extension/llm/export/quantizer_lib.py | 12 +-
 15 files changed, 372 insertions(+), 66 deletions(-)
 create mode 100644 extension/llm/custom_ops/model_sharding.py
 create mode 100644 extension/llm/custom_ops/op_fallback.cpp
 create mode 100644 extension/llm/custom_ops/op_fallback.h

diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index c3afc23daeb..90aa96f7146 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -40,16 +40,7 @@ def __init__( ): self.node_visitors = node_visitor.get_node_visitors(edge_program) - self.skip_node_op_builder_set = set() - if skip_node_op_set is not None: - self.skip_node_op_builder_set = set( - [ - self.node_visitors[val] - for val in skip_node_op_set - if val in self.node_visitors - ] - ) - + self.skip_node_op_set = skip_node_op_set self.skip_node_id_set = skip_node_id_set self.nodes_to_wrappers = defaultdict(dict) self.qnn_manager = PyQnnManager.QnnManager( @@ -69,11 +60,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False - if ( - self.skip_node_op_builder_set is not None - and self.node_visitors[node.target.__name__] - in self.skip_node_op_builder_set - ): + if node.target.__name__ in self.skip_node_op_set: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False
diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py index bdee2c8196a..bd898e2decd 100644 --- a/backends/qualcomm/passes/layout_transform.py +++ b/backends/qualcomm/passes/layout_transform.py @@ -53,6 +53,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.hardswish.default, exir_ops.edge.aten.hardsigmoid.default, exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.index.Tensor, + exir_ops.edge.aten.index_put.default, exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten._log_softmax.default,
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index d3148c95421..0c21ecf78c4 100644 ---
a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -52,7 +52,9 @@ from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, + replace_kv_cache_with_simple_kv_cache, replace_sdpa_with_custom_op, + replace_sdpa_with_flex_sdpa, replace_sdpa_with_simple_sdpa, ) @@ -197,6 +199,12 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", help="Whether to use sdpa_with_kv_cache update op when using kv cache", ) + parser.add_argument( + "--num_sharding", + type=int, + default=None, + help="Specify the number of splits which is generated with custom op. Expect to be able to divide num layer.", + ) parser.add_argument( "--disable_dynamic_shape", dest="enable_dynamic_shape", @@ -385,7 +393,12 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: transforms.append(replace_sdpa_with_custom_op) if args.use_kv_cache: - if args.qnn or args.coreml or args.mps: + if args.qnn: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + + elif args.coreml or args.mps: # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition # to get free perf gain. transforms.append(replace_sdpa_with_simple_sdpa) @@ -486,11 +499,11 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 modelname = f"coreml_{modelname}" if args.qnn: + from executorch.examples.models.llama2.custom_ops import model_sharding + partitioners.append( get_qnn_partitioner( - quant_dtype, - args.use_kv_cache, - args.pt2e_quantize, + args.use_kv_cache, args.pt2e_quantize, args.num_sharding ) ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` @@ -498,6 +511,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program` _transform(builder_exported_to_edge.edge_manager.exported_program()) + if args.num_sharding is not None: + model_sharding.split_graph( + builder_exported_to_edge.edge_manager.exported_program(), + builder_exported_to_edge.metadata["get_n_layers"], + shares=args.num_sharding, + ) if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: @@ -506,7 +525,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 logging.info("Generating etrecord") # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive. 
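A rough sketch (not part of the patch) of how the --num_sharding flow above maps to graph splits: the split_graph helper added later in this patch (extension/llm/custom_ops/model_sharding.py) picks the layer indices at which to insert the identity llama.fallback op, and the QNN partitioner then breaks the graph at those markers. The numbers below are illustrative.

num_layers = 32                  # builder metadata "get_n_layers" for a 7B-style model
num_sharding = 4                 # value passed via --num_sharding
shard_layers = list(range(0, num_layers, num_layers // num_sharding))
print(shard_layers)              # [0, 8, 16, 24]: a fallback marker is inserted where the
                                 # layer index crosses 8, 16 and 24, yielding 4 partitions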
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager) - builder = builder_exported_to_edge.to_backend(partitioners).to_executorch() + builder = builder_exported_to_edge.to_backend(partitioners) + if args.num_sharding is not None: + from executorch.backends.qualcomm.utils.utils import canonicalize_program + + canonicalize_program(builder.edge_manager.exported_program()) + builder = builder.to_executorch() # Generate ETRecord if edge_manager_copy: @@ -517,7 +541,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 ) logging.info("Generated etrecord.bin") else: - builder = builder_exported_to_edge.to_backend(partitioners).to_executorch() + builder = builder_exported_to_edge.to_backend(partitioners) + if args.num_sharding is not None: + from executorch.backends.qualcomm.utils.utils import canonicalize_program + + canonicalize_program(builder.edge_manager.exported_program()) + builder = builder.to_executorch() if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 56bf4a96c39..dacf9eb1fdc 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -161,6 +161,9 @@ def __init__( else: cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim self.transpose_cache = transpose_cache self.enable_dynamic_shape = enable_dynamic_shape self.register_buffer( diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index fdf0dc707e4..6efca35e34a 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -203,7 +203,7 @@ def get_example_inputs(self): else: return ( torch.tensor( - [[1, 2, 3]], dtype=torch.long + [[1, 2, 3]], dtype=torch.int32 ), # tokens, with kv cache our input token length is always just 1 token. ) @@ -211,15 +211,15 @@ def get_example_inputs(self): def get_example_inputs_kvcache_sdpa(self): if self.enable_dynamic_shape: return ( - torch.tensor([[2, 3, 4]], dtype=torch.long), - torch.tensor([0], dtype=torch.long), + torch.tensor([[2, 3, 4]], dtype=torch.int32), + torch.tensor([0, 1, 2], dtype=torch.int32), ) else: return ( torch.tensor( - [[1]], dtype=torch.long + [[1]], dtype=torch.int32 ), # tokens, with kv cache our input token length is always just 1 token. torch.tensor( - [0], dtype=torch.long + [0], dtype=torch.int32 ), # start_pos, what token of output are we on. ) diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index cd5346bacda..e376b1c4149 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -259,7 +259,7 @@ Result Runner::prefill( // Given an input token. Set up the inputs for the model and execute a single // step. Returning the logits tensor. Result Runner::run_model_step( - int64_t input_token, + int32_t input_token, ManagedTensor& managed_tokens, ManagedTensor& managed_start_pos, size_t max_seq_len) { @@ -270,7 +270,7 @@ Result Runner::run_model_step( // When using kv-cache our input is always 1 token, so just update to the // latest. 
- tokens.mutable_data_ptr()[0] = input_token; + tokens.mutable_data_ptr()[0] = input_token; Result> outputs_res = module_->forward({tokens, start_pos}); @@ -283,7 +283,7 @@ Result Runner::run_model_step( "Non Tensor Output returned from executing LLM"); // Bump start_pos by 1 - start_pos.mutable_data_ptr()[0]++; + start_pos.mutable_data_ptr()[0]++; // Return the logits tensor return outputs_res.get()[0].toTensor(); @@ -295,7 +295,7 @@ Result Runner::run_model_step( // When not using kv-cache our input is the entire history of tokens we have // seen, so resize input to be 1 larger and append the new token to the end. // TODO does this work in ATen mode? - tokens.mutable_data_ptr()[tokens.size(1) - 1] = input_token; + tokens.mutable_data_ptr()[tokens.size(1) - 1] = input_token; // inputs:[tokens] inputs.push_back(tokens); @@ -365,12 +365,12 @@ Error Runner::generate( "Sequence length exceeded - please increase the seq_len value passed to generate()"); // start the main loop - int64_t pos = 0; // position in the sequence + int32_t pos = 0; // position in the sequence - std::vector token_data; // allocate space for the tokens + std::vector token_data; // allocate space for the tokens std::vector token_shape = {1, seq_len}; - std::vector start_pos_data; // allocate space for the tokens + std::vector start_pos_data; // allocate space for the tokens std::vector start_pos_shape = {1}; token_data.resize(seq_len); @@ -381,17 +381,17 @@ Error Runner::generate( } // initialize tensor wrappers - ManagedTensor tokens_managed( - token_data.data(), token_shape, ScalarType::Long); + ManagedTensor tokens_managed(token_data.data(), token_shape, ScalarType::Int); // Create with the max shape to approapriately set the capacity of this // tensor, then resize back to 1 for first input. 
tokens_managed.resize({1, 1}); ManagedTensor start_pos_managed( - start_pos_data.data(), start_pos_shape, ScalarType::Long); - int64_t prev_token; - int64_t cur_token = prompt_tokens[0]; + start_pos_data.data(), start_pos_shape, ScalarType::Int); + + int32_t prev_token; + int32_t cur_token = prompt_tokens[0]; // Prefill first // Here feed all tokens to the model and get the next predicted token diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h index 407527531df..47c32faab23 100644 --- a/examples/models/llama2/runner/runner.h +++ b/examples/models/llama2/runner/runner.h @@ -54,7 +54,7 @@ class Runner { ManagedTensor& managed_start_pos, std::function token_callback); Result run_model_step( - int64_t input_token, + int32_t input_token, ManagedTensor& tokens, ManagedTensor& start_pos, size_t max_seq_len); diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py index 4e0ac718689..7e2e9d95544 100644 --- a/examples/models/llama2/source_transformation/sdpa.py +++ b/examples/models/llama2/source_transformation/sdpa.py @@ -9,6 +9,7 @@ # Example script for exporting Llama2 to flatbuffer import math +from typing import Tuple import torch @@ -112,6 +113,43 @@ def forward( return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) +class SDPAFlex(torch.nn.Module): + + def __init__( + self, + kv_cache: KVCache, + dim: int, + ): + super().__init__() + self.kv_cache = kv_cache + self.dim = dim + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz, + seqlen, + mask, + ): + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k + v = v + + k, v = self.kv_cache.update(input_pos, k, v) + attn_mask = mask[input_pos] + + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + y = attn_weight @ v + + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): for name, child in module.named_children(): if isinstance(child, SDPA): @@ -125,6 +163,71 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): return module +def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, SDPA): + setattr( + module, + name, + SDPAFlex(child.kv_cache, child.dim), + ) + else: + replace_sdpa_with_flex_sdpa(child) + return module + + +class KVCacheSimple(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + self.register_buffer( + "past_k_caches", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + "past_v_caches", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val) + v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val) + + k_out = k_out.transpose(1, 2) + v_out = v_out.transpose(1, 2) + return k_out, v_out + + +def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module): + for 
name, child in module.named_children(): + if isinstance(child, KVCache): + setattr( + module, + name, + KVCacheSimple( + child.max_batch_size, + child.max_seq_length, + child.n_heads, + child.head_dim, + child.k_cache.dtype, + ), + ) + else: + replace_kv_cache_with_simple_kv_cache(child) + return module + + def replace_causal_mask(module: torch.nn.Module): for buffer_fqn_name, buffer in module.named_buffers(): buffer_name = buffer_fqn_name.split(".")[-1] @@ -132,7 +235,7 @@ def replace_causal_mask(module: torch.nn.Module): max_seq_len = buffer.shape[-1] mask = torch.full( (max_seq_len, max_seq_len), - float("-inf"), + float("-255"), device="cpu", ) diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py new file mode 100644 index 00000000000..5d3bcc1ee32 --- /dev/null +++ b/extension/llm/custom_ops/model_sharding.py @@ -0,0 +1,93 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import re +from typing import List + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export.exported_program import ExportedProgram +from torch.library import impl, Library + + +fallback_op_lib = Library("llama", "DEF") +# registering an operator. +fallback_op_lib.define("fallback(Tensor input) -> Tensor") + + +@impl(fallback_op_lib, "fallback") +def fallback_impl(a: torch.Tensor) -> torch.Tensor: + return a + + +# registering the out variant. +fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") + + +@impl(fallback_op_lib, "fallback.out") +def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(a) + return out + + +class SplitGraph(ExportPass): + """ + Handle to split the llama model to multiple partitions. + Because there are limited memory on the device, it could + not load all llama model in one pte. 
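The SDPAFlex module added to sdpa.py above decomposes scaled dot-product attention into plain matmul, mask-add and softmax so the QNN backend can lower it. A small self-contained check (not from the patch; shapes and values are illustrative) that the decomposition matches PyTorch's reference SDPA when the mask is passed explicitly:

import math

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1, 64)   # (bs, n_heads, seqlen=1, head_dim), the kv-cache decode step
k = torch.randn(1, 8, 16, 64)  # cached keys returned by kv_cache.update()
v = torch.randn(1, 8, 16, 64)  # cached values
mask = torch.zeros(1, 16)      # the mask[input_pos] slice; all zeros here for simplicity

# What SDPAFlex.forward computes
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale_factor + mask, dim=-1)
y_flex = attn_weight @ v

# Reference fused op
y_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(y_flex, y_ref, atol=1e-5)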
+ """ + + def __init__(self, shard_layers: List[int]): + super().__init__() + self.shard_layers = shard_layers + + def _insert_fallback_op( + self, graph_module: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + pattern = r"layers.(\d+)" + prev_node = None + prev_layer = None + for node in graph_module.graph.nodes: + if node.op != "call_function" or "nn_module_stack" not in node.meta: + continue + + module_values_list = list(node.meta["nn_module_stack"].values()) + full_qualified_name = module_values_list[-1][0] + match = re.search(pattern, full_qualified_name) + if match is None: + continue + + cur_layer = int(match.group(1)) + # Check the current node which is the last node of the layer + if cur_layer in self.shard_layers and prev_layer == cur_layer - 1: + with graph_module.graph.inserting_after(prev_node): + users = list(prev_node.users.keys()) + inserted_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.llama.fallback.default, + (prev_node,), + ) + inserted_node.meta["val"] = prev_node.meta["val"] + if prev_node.meta.get("quant_attrs", None): + inserted_node.meta["quant_attrs"] = prev_node.meta[ + "quant_attrs" + ] + for user in users: + user.replace_input_with(prev_node, inserted_node) + + prev_layer = cur_layer + prev_node = node + + def call(self, graph_module: torch.fx.GraphModule): + self._insert_fallback_op(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + +def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int): + graph_module = edge_program.graph_module + shard_layers = list(range(0, num_layers, int(num_layers / shares))) + return SplitGraph(shard_layers)(graph_module) diff --git a/extension/llm/custom_ops/op_fallback.cpp b/extension/llm/custom_ops/op_fallback.cpp new file mode 100644 index 00000000000..05f8186853e --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +namespace torch { +namespace executor { + +namespace native { + +// Copy from op_clone.cpp +Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, in.sizes()) == torch::executor::Error::Ok, + InvalidArgument, + out); + + // The input and out shall share same dtype and size + ET_KERNEL_CHECK( + ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); + + if (in.nbytes() > 0) { + // Note that this check is important. It's valid for a tensor with numel 0 + // to have a null data pointer, but in some environments it's invalid to + // pass a null pointer to memcpy() even when the size is zero. + memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes()); + } + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +EXECUTORCH_LIBRARY( + llama, + "fallback.out", + torch::executor::native::fallback_out); diff --git a/extension/llm/custom_ops/op_fallback.h b/extension/llm/custom_ops/op_fallback.h new file mode 100644 index 00000000000..62a2c0d53eb --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace torch { +namespace executor { + +namespace native { +Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out); +} // namespace native +} // namespace executor +} // namespace torch diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 8c38eb6a0a0..e3ed9fe0a99 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -8,8 +8,8 @@ def define_common_targets(): """ runtime.cxx_library( name = "custom_ops", - srcs = ["op_sdpa.cpp"], - exported_headers = ["op_sdpa.h"], + srcs = ["op_sdpa.cpp", "op_fallback.cpp"], + exported_headers = ["op_sdpa.h", "op_fallback.h"], exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 264e1e95ad3..78ccd0ea02d 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -166,7 +166,33 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": ) return self - def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager": + def calibrate(self, module: torch.fx.GraphModule): + from sentencepiece import SentencePieceProcessor + + sp_model = SentencePieceProcessor(model_file="tokenizer.model") + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + user_prompts = ["Once", "upon", "a", "time"] + for prompt in user_prompts: + token_list += sp_model.encode(prompt) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + logits = module( + torch.full((1, 1), token_list[pos]), + torch.full((1, 1), pos), + ) + pos += 1 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + def pt2e_quantize( + self, quantizers: Optional[List[Quantizer]] + ) -> "LlamaEdgeManager": """ Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model. 
Args: @@ -189,7 +215,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage ), "Please run capture_pre_autograd_graph first" m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer) # Calibrate - m(*self.example_inputs) + self.calibrate(m) + # m(*self.example_inputs) m = convert_pt2e(m) DuplicateDynamicQuantChainPass()(m) self.pre_autograd_graph_module = m diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index bcbeeeee159..dfdab1e0b19 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -105,7 +105,9 @@ def get_coreml_partitioner( def get_qnn_partitioner( - quant_dtype, use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None + use_kv_cache: bool = False, + pt2e_quantize: Optional[str] = None, + num_sharding: int = None, ): assert ( use_kv_cache is True @@ -116,9 +118,6 @@ def get_qnn_partitioner( QnnPartitioner, ) - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` - from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema` from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, @@ -135,27 +134,20 @@ def get_qnn_partitioner( ) use_fp16 = True - skip_node_op_set = {} + skip_node_op_set = {"llama.fallback.default"} if pt2e_quantize is not None: use_fp16 = False - # TODO: fix the lowering error without skipping nodes - - if quant_dtype == QuantDtype.use_8a8w: - raise NotImplementedError("8a8w for llama is still under development") - - elif quant_dtype == QuantDtype.use_16a16w: - raise NotImplementedError("16a16w for llama is still under development") - - elif quant_dtype == QuantDtype.use_16a4w: - raise NotImplementedError("16a4w for llama is still under development") return QnnPartitioner( generate_qnn_executorch_compiler_spec( soc_model=QcomChipset.SM8650, # default to SM8650 - backend_options=generate_htp_compiler_spec(use_fp16=use_fp16), + backend_options=generate_htp_compiler_spec( + use_fp16=use_fp16, + use_multi_contexts=num_sharding is not None, + ), debug=False, saver=False, ), - skip_node_id_set={}, + skip_node_id_set=None, skip_node_op_set=skip_node_op_set, ) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index fe6ad1c201a..2c1c4552236 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -153,6 +153,7 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) + from torch.ao.quantization.observer import MinMaxObserver except ImportError: raise ImportError( @@ -169,19 +170,20 @@ def get_qnn_quantizer( # more custom quantization are supported including 16a4w etc. 
default to 8bit quantized custom_annotations = () if quant_config == "8a8w": - raise NotImplementedError("8a8w for llama is still under development") quant_dtype = QuantDtype.use_8a8w pass elif quant_config == "16a16w": - raise NotImplementedError("16a16w for llama is still under development") quant_dtype = QuantDtype.use_16a16w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + qnn_quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) + ) elif quant_config == "16a4w": - raise NotImplementedError("16a4w for llama is still under development") quant_dtype = QuantDtype.use_16a4w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + qnn_quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + ) qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") else: raise AssertionError( From d8d05576cb54722e26f0e70e8d728aaae5ccb924 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Wed, 10 Jul 2024 15:56:00 +0800 Subject: [PATCH 2/5] Back to int64 for inputs for minimum changed. But it will result in embedding op fallback. If change pos_ids to int32, it will be fully delegated. --- examples/models/llama2/model.py | 10 +++++----- examples/models/llama2/runner/runner.cpp | 24 ++++++++++++------------ examples/models/llama2/runner/runner.h | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index 6efca35e34a..a5755571c15 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -203,7 +203,7 @@ def get_example_inputs(self): else: return ( torch.tensor( - [[1, 2, 3]], dtype=torch.int32 + [[1, 2, 3]], dtype=torch.long ), # tokens, with kv cache our input token length is always just 1 token. ) @@ -211,15 +211,15 @@ def get_example_inputs(self): def get_example_inputs_kvcache_sdpa(self): if self.enable_dynamic_shape: return ( - torch.tensor([[2, 3, 4]], dtype=torch.int32), - torch.tensor([0, 1, 2], dtype=torch.int32), + torch.tensor([[2, 3, 4]], dtype=torch.long), + torch.tensor([0, 1, 2], dtype=torch.long), ) else: return ( torch.tensor( - [[1]], dtype=torch.int32 + [[1]], dtype=torch.long ), # tokens, with kv cache our input token length is always just 1 token. torch.tensor( - [0], dtype=torch.int32 + [0], dtype=torch.long ), # start_pos, what token of output are we on. ) diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index e376b1c4149..cd5346bacda 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -259,7 +259,7 @@ Result Runner::prefill( // Given an input token. Set up the inputs for the model and execute a single // step. Returning the logits tensor. Result Runner::run_model_step( - int32_t input_token, + int64_t input_token, ManagedTensor& managed_tokens, ManagedTensor& managed_start_pos, size_t max_seq_len) { @@ -270,7 +270,7 @@ Result Runner::run_model_step( // When using kv-cache our input is always 1 token, so just update to the // latest. 
- tokens.mutable_data_ptr()[0] = input_token; + tokens.mutable_data_ptr()[0] = input_token; Result> outputs_res = module_->forward({tokens, start_pos}); @@ -283,7 +283,7 @@ Result Runner::run_model_step( "Non Tensor Output returned from executing LLM"); // Bump start_pos by 1 - start_pos.mutable_data_ptr()[0]++; + start_pos.mutable_data_ptr()[0]++; // Return the logits tensor return outputs_res.get()[0].toTensor(); @@ -295,7 +295,7 @@ Result Runner::run_model_step( // When not using kv-cache our input is the entire history of tokens we have // seen, so resize input to be 1 larger and append the new token to the end. // TODO does this work in ATen mode? - tokens.mutable_data_ptr()[tokens.size(1) - 1] = input_token; + tokens.mutable_data_ptr()[tokens.size(1) - 1] = input_token; // inputs:[tokens] inputs.push_back(tokens); @@ -365,12 +365,12 @@ Error Runner::generate( "Sequence length exceeded - please increase the seq_len value passed to generate()"); // start the main loop - int32_t pos = 0; // position in the sequence + int64_t pos = 0; // position in the sequence - std::vector token_data; // allocate space for the tokens + std::vector token_data; // allocate space for the tokens std::vector token_shape = {1, seq_len}; - std::vector start_pos_data; // allocate space for the tokens + std::vector start_pos_data; // allocate space for the tokens std::vector start_pos_shape = {1}; token_data.resize(seq_len); @@ -381,17 +381,17 @@ Error Runner::generate( } // initialize tensor wrappers - ManagedTensor tokens_managed(token_data.data(), token_shape, ScalarType::Int); + ManagedTensor tokens_managed( + token_data.data(), token_shape, ScalarType::Long); // Create with the max shape to approapriately set the capacity of this // tensor, then resize back to 1 for first input. tokens_managed.resize({1, 1}); ManagedTensor start_pos_managed( + start_pos_data.data(), start_pos_shape, ScalarType::Long); - start_pos_data.data(), start_pos_shape, ScalarType::Int); - - int32_t prev_token; - int32_t cur_token = prompt_tokens[0]; + int64_t prev_token; + int64_t cur_token = prompt_tokens[0]; // Prefill first // Here feed all tokens to the model and get the next predicted token diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h index 47c32faab23..407527531df 100644 --- a/examples/models/llama2/runner/runner.h +++ b/examples/models/llama2/runner/runner.h @@ -54,7 +54,7 @@ class Runner { ManagedTensor& managed_start_pos, std::function token_callback); Result run_model_step( - int32_t input_token, + int64_t input_token, ManagedTensor& tokens, ManagedTensor& start_pos, size_t max_seq_len); From ddd9c6233f835e7c9524cdc08aeb6ff72c4d2a6f Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Tue, 16 Jul 2024 18:03:24 +0800 Subject: [PATCH 3/5] annotate matmul 16a8w --- extension/llm/export/quantizer_lib.py | 81 ++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 2c1c4552236..8dab9d593ca 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -160,6 +160,81 @@ def get_qnn_quantizer( "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html" ) + def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: + """ + This function is specific for matmul op 16a8w. 
+ """ + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, + ) + from torch.fx import Node + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: + input = node.args[0] + value = node.args[2] + + input_qspec_map = {} + input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((input, node)), + _annotated=True, + ) + + def annotate_single_in_single_out( + node: Node, quantization_config: QuantizationConfig + ) -> None: + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_matmul_input1(node: Node): + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + while isinstance(node, Node) and node.op == "call_function": + if node.target == torch.ops.aten.index_put_.default: + annotate_index_put(node, quantization_config_8a8w) + break + annotate_single_in_single_out(node, quantization_config_8a8w) + node = node.args[0] + + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + annotate_matmul(node, quantization_config_16a8w) + annotate_matmul_input1(node.args[1]) + backend, quant_config = pt2e_quantize.split("_") assert ( backend == "qnn" @@ -168,16 +243,17 @@ def get_qnn_quantizer( qnn_quantizer.set_per_channel_conv_quant(enable=True) qnn_quantizer.set_per_channel_linear_quant(enable=True) # more custom quantization are supported including 16a4w etc. default to 8bit quantized - custom_annotations = () + if quant_config == "8a8w": quant_dtype = QuantDtype.use_8a8w - pass + custom_annotations = () elif quant_config == "16a16w": quant_dtype = QuantDtype.use_16a16w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) qnn_quantizer.set_bit16_op_quant_config( get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) ) + custom_annotations = () elif quant_config == "16a4w": quant_dtype = QuantDtype.use_16a4w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) @@ -185,6 +261,7 @@ def get_qnn_quantizer( get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) ) qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + custom_annotations = (annotate_matmul_16a8w, ) else: raise AssertionError( f"No support for quant type {quant_config}. 
Support 8a8w, 16a16w and 16a4w." From 0eb89cfb34972e7148e8b2348ec71dfc7776cd12 Mon Sep 17 00:00:00 2001 From: chunit-quic Date: Tue, 6 Aug 2024 15:56:54 +0800 Subject: [PATCH 4/5] [rebase to dev 20240806] --- examples/models/llama2/export_llama_lib.py | 2 +- extension/llm/custom_ops/op_fallback.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 0c21ecf78c4..3f7e4ea466f 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -499,7 +499,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 modelname = f"coreml_{modelname}" if args.qnn: - from executorch.examples.models.llama2.custom_ops import model_sharding + from executorch.extension.llm.custom_ops import model_sharding partitioners.append( get_qnn_partitioner( diff --git a/extension/llm/custom_ops/op_fallback.cpp b/extension/llm/custom_ops/op_fallback.cpp index 05f8186853e..11a1b4e7faf 100644 --- a/extension/llm/custom_ops/op_fallback.cpp +++ b/extension/llm/custom_ops/op_fallback.cpp @@ -5,7 +5,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include +#include #include #include From e68e225a742cc0132064cd061343319c8216a8ef Mon Sep 17 00:00:00 2001 From: chunit-quic Date: Tue, 6 Aug 2024 16:38:39 +0800 Subject: [PATCH 5/5] [llama 3 wip] - Support GQA, repeating kv caches - Support TicTokenizer for llm/export/builder.py - Support --embedding-quantize option for qualcomm lowering flow --- backends/qualcomm/partition/common_defs.py | 1 + .../qualcomm/partition/qnn_partitioner.py | 2 +- .../qualcomm/passes/replace_inf_buffer.py | 5 +- examples/models/llama2/export_llama_lib.py | 2 +- .../llama2/source_transformation/quantize.py | 22 ++++-- .../llama2/source_transformation/sdpa.py | 23 +++++- extension/llm/export/builder.py | 34 ++++++--- extension/llm/export/quantizer_lib.py | 71 +++++++++++++++---- 8 files changed, 127 insertions(+), 33 deletions(-) diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index c60afc2dd33..b6d1f3708a4 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -14,6 +14,7 @@ exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] allow_list_operator = [ diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 90aa96f7146..8080947f929 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -60,7 +60,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False - if node.target.__name__ in self.skip_node_op_set: + if self.skip_node_op_set is not None and node.target.__name__ in self.skip_node_op_set: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False diff --git a/backends/qualcomm/passes/replace_inf_buffer.py b/backends/qualcomm/passes/replace_inf_buffer.py index bafa3fdb18b..1dc06630ca3 100644 --- a/backends/qualcomm/passes/replace_inf_buffer.py +++ b/backends/qualcomm/passes/replace_inf_buffer.py @@ -14,8 +14,9 @@ def __init__(self): def call(self, graph_module: torch.fx.GraphModule): for 
buf_name, tensor in graph_module.named_buffers(): if tensor.is_floating_point(): - tensor[tensor == float("inf")] = torch.finfo(torch.float32).max - tensor[tensor == float("-inf")] = torch.finfo(torch.float32).min + # An arbitrary number + tensor[tensor == float("inf")] = 1000 + tensor[tensor == float("-inf")] = -1000 setattr(graph_module, buf_name, tensor) graph_module.recompile() diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 3f7e4ea466f..6e88277af68 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -120,7 +120,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--embedding-quantize", default=None, type=str, - help="type of embedding quantization, ',', e.g., '8,1024'.", + help="type of embedding quantization, ',,', e.g., '8,1024,32'.", ) parser.add_argument( "--pt2e_quantize", diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index bb014145bd8..de2d39f422d 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -384,7 +384,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def replace_embedding_weight_only_grouped_int8_per_channel( - module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False + module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False, qge_dtype=torch.half ): for name, child in module.named_children(): # print(f"name: {name}") @@ -400,11 +400,12 @@ def replace_embedding_weight_only_grouped_int8_per_channel( embedding_dim=child.weight.shape[1], group_size=group_size, packed=packed, + dtype=qge_dtype, ), ) else: replace_embedding_weight_only_grouped_int8_per_channel( - child, device, bitwidth, group_size, packed + child, device, bitwidth, group_size, packed, qge_dtype ) @@ -417,6 +418,7 @@ def __init__( bitwidth: int = 8, group_size: Optional[int] = None, packed=False, + qge_dtype=torch.half, ): if isinstance(packed, str): packed = packed == "True" @@ -425,6 +427,7 @@ def __init__( self.group_size = group_size self.bitwidth = bitwidth self.packed = packed + self.qge_dtype = qge_dtype if (bitwidth != 4) and packed: raise RuntimeError("pack only works with bitsize 4") @@ -484,7 +487,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict: def convert_for_runtime(self) -> nn.Module: replace_embedding_weight_only_grouped_int8_per_channel( - self.mod, self.device, self.bitwidth, self.group_size, self.packed + self.mod, self.device, self.bitwidth, self.group_size, self.packed, self.qge_dtype ) return self.mod @@ -554,17 +557,28 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: def get_quant_embedding_transform(args): - bitwidth, group_size = args.embedding_quantize.split(",") + quant_args = [a.strip() for a in args.embedding_quantize.split(",")] + bitwidth, group_size = quant_args[:2] if group_size == "none" or group_size == "None" or group_size == "0": group_size = None else: group_size = int(group_size) bitwidth = int(bitwidth) + + if len(quant_args) == 3: + qge_dtype = quant_args[2] + if qge_dtype in ("32", "torch.float32"): + qge_dtype = torch.float32 + else: + print(f"Use default qge_dtype, {torch.half}") + qge_dtype = torch.half + return lambda model: EmbeddingQuantHandler( model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth == 4), + qge_dtype=qge_dtype, ).quantized_model() diff --git 
a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py index 7e2e9d95544..3f8f19d890d 100644 --- a/examples/models/llama2/source_transformation/sdpa.py +++ b/examples/models/llama2/source_transformation/sdpa.py @@ -9,13 +9,27 @@ # Example script for exporting Llama2 to flatbuffer import math -from typing import Tuple +from typing import Tuple, List import torch from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + + new_kv = [] + batch, n_heads, seqlen, head_dim = hidden_states.shape + n_heads *= n_rep + for h in hidden_states[0]: + new_kv += [h] * n_rep + return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim) + + class SDPACustom(torch.nn.Module): def __init__( self, @@ -135,10 +149,13 @@ def forward( mask, ): q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k - v = v k, v = self.kv_cache.update(input_pos, k, v) + + k_repeat_num = q.shape[1] // k.shape[1] + v_repeat_num = q.shape[1] // v.shape[1] + k = repeat_kv(k, k_repeat_num) + v = repeat_kv(v, v_repeat_num) attn_mask = mask[input_pos] scale_factor = 1 / math.sqrt(q.size(-1)) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 78ccd0ea02d..0760f80e08b 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -11,8 +11,11 @@ import logging from enum import Enum from typing import Any, Callable, List, Optional +from functools import partial import torch +from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken +from sentencepiece import SentencePieceProcessor from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( DuplicateDynamicQuantChainPass, ) @@ -167,28 +170,23 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": return self def calibrate(self, module: torch.fx.GraphModule): - from sentencepiece import SentencePieceProcessor - - sp_model = SentencePieceProcessor(model_file="tokenizer.model") + tokenizer = SimpleTokenizer("tokenizer.model") # TODO: change criteria & support batch inputs if necessary pos = torch.tensor(0, dtype=torch.int32) - token_list = [sp_model.bos_id()] - user_prompts = ["Once", "upon", "a", "time"] - for prompt in user_prompts: - token_list += sp_model.encode(prompt) + token_list = [tokenizer.bos_id] + tokenizer.encode("Once upon a time") with torch.no_grad(): - while token_list[-1] != sp_model.eos_id() and pos < 128: + while token_list[-1] != tokenizer.eos_id and pos < 128: logits = module( torch.full((1, 1), token_list[pos]), - torch.full((1, 1), pos), + torch.tensor((pos, )), ) pos += 1 if pos >= len(token_list): token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) - print(f"calibration data:\n{sp_model.decode(token_list)}") + print(f"calibration data:\n{tokenizer.decode(token_list)}") def pt2e_quantize( self, quantizers: Optional[List[Quantizer]] @@ -321,3 +319,19 @@ def get_saved_pte_filename(self) -> Optional[str]: Return the filename of the most recenet saved .pte file. Return None if the model is not saved. 
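The repeat_kv helper added to sdpa.py above (patch 5) expands grouped-query-attention KV heads so they line up with the query heads. A self-contained check (not from the patch) that, for the batch-size-1 case the export path uses, it matches torch.repeat_interleave along the head dimension; the helper is copied here so the snippet runs on its own:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # as added in the patch above
    new_kv = []
    batch, n_heads, seqlen, head_dim = hidden_states.shape
    n_heads *= n_rep
    for h in hidden_states[0]:
        new_kv += [h] * n_rep
    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)

kv = torch.randn(1, 2, 5, 4)  # (bs=1, n_kv_heads, seqlen, head_dim)
n_rep = 3                     # n_heads // n_kv_heads for a hypothetical GQA config
assert torch.equal(repeat_kv(kv, n_rep), torch.repeat_interleave(kv, n_rep, dim=1))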
""" return self._saved_pte_filename + +class SimpleTokenizer: + def __init__(self, model_path): + try: + module = SentencePieceProcessor(model_file=model_path) + self.bos_id = module.bos_id() + self.eos_id = module.eos_id() + self.encode = module.encode + self.decode = module.decode + except Exception: + print("Using Tiktokenizer") + module = Tiktoken(model_path=model_path) + self.bos_id = module.bos_id + self.eos_id = module.eos_id + self.encode = partial(module.encode, bos=False, eos=False) + self.decode = module.decode diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 8dab9d593ca..7e4f237bb9e 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -8,7 +8,7 @@ import logging from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Sequence import torch @@ -191,7 +191,10 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): output_qspec=quantization_config.output_activation, _annotated=True, ) - def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: + + def annotate_index_put( + node: Node, quantization_config: QuantizationConfig + ) -> None: input = node.args[0] value = node.args[2] @@ -219,21 +222,65 @@ def annotate_single_in_single_out( _annotated=True, ) - def annotate_matmul_input1(node: Node): - quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) - while isinstance(node, Node) and node.op == "call_function": - if node.target == torch.ops.aten.index_put_.default: - annotate_index_put(node, quantization_config_8a8w) - break - annotate_single_in_single_out(node, quantization_config_8a8w) - node = node.args[0] + def annotate_cat(node: Node, quantization_config: QuantizationConfig): + input_nodes = node.args[0] + + assert isinstance(input_nodes, Sequence) + + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(node, Node) + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + + def is_edge_condition(node: Node): + if not isinstance(node, Node) or node.op != "call_function": + return True + return False + + def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): + if is_edge_condition(node): + return + + if node.target == torch.ops.aten.index_put_.default: + annotate_index_put(node, quantization_config) + annotate_matmul_input1(node.args[0], quantization_config) + elif node.target == torch.ops.aten.cat.default: + annotate_cat(node, quantization_config) + # Expect that the inputs of the cat op are select ops + for arg in node.args[0][1:]: + annotate_single_in_single_out(arg, quantization_config) + annotate_matmul_input1(node.args[0][0], quantization_config) + else: + annotate_single_in_single_out(node, quantization_config) + annotate_matmul_input1(node.args[0], quantization_config) + # Annotate 16a8w for matmul op to get better performance quantization_config_16a8w = get_16a8w_qnn_ptq_config() + # Annotate 8a8w for 
second input of matmul until past_kv_cache + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.matmul.default + ): annotate_matmul(node, quantization_config_16a8w) - annotate_matmul_input1(node.args[1]) + annotate_matmul_input1(node.args[1], quantization_config_8a8w) backend, quant_config = pt2e_quantize.split("_") assert (