diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 9b05ad871f4..f8552e4fd4b 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3154,9 +3154,9 @@ def test_llama3_2_1b(self): "llama3_2", "--model_mode", "hybrid", - "--prefill_seq_len", + "--prefill_ar_len", "32", - "--kv_seq_len", + "--max_seq_len", "512", "--num_sharding", "4", @@ -3234,9 +3234,9 @@ def test_llama_stories_110m(self): "stories110m", "--model_mode", "hybrid", - "--prefill_seq_len", + "--prefill_ar_len", "32", - "--kv_seq_len", + "--max_seq_len", "128", ] if self.compile_only: diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 439278cb424..cd468eebb26 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -8,11 +8,16 @@ This file provides you the instructions to run LLAMA model with different parame We offer the following modes to execute the model: -Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for encoding the user's prompt. - KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt. -Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. +Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. + - AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode. + - Prompt processing with AR-N model: +
+<figure>
+  <img src="./assets/PromptProcessingWithARN.png" alt="Prompt Processing With AR-N Model">
+  <figcaption>Prompt processing is done using a for-loop: an N-token block is taken and the KV cache is updated for that block. This repeats until all tokens are consumed, with the last block padded if necessary. For flexibility, the AR-N model can handle any input length less than the maximum sequence length. For time-to-first-token (TTFT), the number of blocks processed therefore depends on the actual prompt length rather than always being the same (a simplified sketch follows the figure).</figcaption>
+</figure>
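For intuition only, here is a minimal, self-contained Python sketch of the block-wise loop described above. It is not the repository's implementation: `run_ar_n_block`, `AR_LEN`, and `PAD_TOKEN` are hypothetical placeholders standing in for one forward pass of the exported AR-N model and its configuration.

```python
# Illustrative sketch of AR-N prompt processing (not the repository's implementation).
from typing import List, Tuple

AR_LEN = 4     # hypothetical block size (maps conceptually to --prefill_ar_len)
PAD_TOKEN = 0  # hypothetical padding token id


def run_ar_n_block(block: List[int], kv_cache: List[int]) -> Tuple[List[int], List[int]]:
    """Stand-in for one forward pass of the AR-N model.

    A real model would return logits for the AR_LEN positions plus the new
    key/value entries for those positions; here both are faked.
    """
    new_kv = list(block)
    logits = [t + 1 for t in block]
    return logits, new_kv


def process_prompt(prompt_tokens: List[int]) -> List[int]:
    """Consume the prompt AR_LEN tokens at a time, padding the last block."""
    kv_cache: List[int] = []
    logits: List[int] = []
    for start in range(0, len(prompt_tokens), AR_LEN):
        block = prompt_tokens[start : start + AR_LEN]
        block += [PAD_TOKEN] * (AR_LEN - len(block))   # pad only the final block
        logits, new_kv = run_ar_n_block(block, kv_cache)
        kv_cache.extend(new_kv)                        # KV cache grows one block per step
    return logits  # logits of the last block seed subsequent KV-cache-mode decoding


if __name__ == "__main__":
    # 10 prompt tokens -> 3 blocks (the last one padded with 2 PAD_TOKENs).
    print(process_prompt(list(range(10))))
```

The number of loop iterations, and hence TTFT, grows with the prompt length in steps of `AR_LEN`; the on-device runner additionally masks the padded positions in the attention mask, which this sketch omits.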
## Instructions @@ -50,13 +55,13 @@ At the end of this step, users should have the following files ready: `consolida ### Step3: Run default examples using hybrid mode. #### LLAMA2 ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time" +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time" ``` #### LLAMA3.2 Default example using hybrid mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" ``` ### KV Cache update mechanism @@ -109,16 +114,16 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can ### Additional Configs when running the script If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example: ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only ``` On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example: ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE} +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE} ``` You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". 
By default, it is set to "smart_mask". `KV_UPDATER` = "shift_pointer" ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER} +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER} ``` diff --git a/examples/qualcomm/oss_scripts/llama/assets/PromptProcessingWithARN.png b/examples/qualcomm/oss_scripts/llama/assets/PromptProcessingWithARN.png new file mode 100644 index 00000000000..228b846f7c3 Binary files /dev/null and b/examples/qualcomm/oss_scripts/llama/assets/PromptProcessingWithARN.png differ diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index e853812a949..9cad2499730 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -89,32 +89,38 @@ logging.getLogger().setLevel(logging.INFO) -def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches): - for i, k_cache in enumerate(k_caches): - k_cache[:, :, pos] = new_k_caches[i][:, :, 0] +def smart_mask_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. + if pos >= ar_len: + for i, k_cache in enumerate(k_caches): + k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0] - for i, v_cache in enumerate(v_caches): - v_cache[:, pos, :] = new_v_caches[i] + for i, v_cache in enumerate(v_caches): + v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :] + atten_mask[:, :, pos - ar_len] = 0 - atten_mask[0][pos] = 0 pos += 1 return (atten_mask, pos, k_caches, v_caches) -def shift_pointer_updator( - atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +def shift_pointer_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches ): - k_caches = [ - torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1) - for i, k_cache in enumerate(k_caches) - ] - v_caches = [ - torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) - for i, v_cache in enumerate(v_caches) - ] + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. 
+ if pos >= ar_len: + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + atten_mask[:, :, -pos - 1] = 0 pos += 1 - atten_mask[0][-pos - 1] = 0 return (atten_mask, pos, k_caches, v_caches) @@ -123,15 +129,15 @@ def _kv_calibrate( user_prompts, module: torch.fx.GraphModule, tokenizer, + ar_len=1, max_seq_len=512, - updator=smart_mask_updator, + updater=smart_mask_updater, use_i64_token=False, ): _, atten_mask, _, k_caches, v_caches = example_inputs # TODO: change criteria & support batch inputs if necessary - pos = torch.tensor(0, dtype=torch.int32) - max_cache_len = max_seq_len - 1 + all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0) token_list = [] # Llama2 tokenizer has no special tokens @@ -144,21 +150,50 @@ def _kv_calibrate( else: raise RuntimeError("Unkown tokenizer") + pos = len(token_list) if len(token_list) < ar_len else ar_len + dtype = torch.int64 if use_i64_token else torch.int32 + with torch.no_grad(): - while token_list[-1] != tokenizer.eos_id and pos < max_cache_len: - dtype = torch.int64 if use_i64_token else torch.int32 - token = torch.full((1, 1), token_list[pos], dtype=dtype) + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + tmp_token_list = torch.tensor( + token_list[pos - ar_len : pos], dtype=dtype + ).reshape(1, -1) + tmp_pos = all_pos[:, pos - ar_len : pos] + tmp_atten_mask = atten_mask + if pos < ar_len: + tmp_token_list = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=dtype), + torch.tensor(token_list, dtype=dtype).reshape(1, -1), + ], + dim=1, + ) + tmp_pos = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=torch.int32), + all_pos[:, :pos], + ], + dim=1, + ) + tmp_atten_mask = torch.cat( + [ + torch.ones(1, ar_len, max_seq_len - pos) * -255.0, + atten_mask[:, :, -pos:], + ], + dim=-1, + ) + logits, new_k_caches, new_v_caches = module( - token, - atten_mask, - torch.full((1, 1), pos), + tmp_token_list, + tmp_atten_mask, + tmp_pos, *k_caches, *v_caches, ) - atten_mask, pos, k_caches, v_caches = updator( - atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + atten_mask, pos, k_caches, v_caches = updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches ) - if pos >= len(token_list): + if pos > len(token_list): token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) print(f"kv calibration data:\n{tokenizer.decode(token_list)}") @@ -173,7 +208,6 @@ def _prefill_calibrate( use_i64_token=False, ): _, atten_mask = example_inputs - max_cache_len = max_seq_len - 1 # TODO: change criteria & support batch inputs if necessary @@ -192,20 +226,24 @@ def _prefill_calibrate( dtype = torch.int64 if use_i64_token else torch.int32 with torch.no_grad(): - while token_list[-1] != tokenizer.eos_id and pos < max_cache_len: + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1) - if pos < max_cache_len: + if pos < max_seq_len: tmp_token_list = torch.cat( [ tmp_token_list, - torch.zeros((1, max_cache_len - pos), dtype=dtype), + torch.zeros((1, max_seq_len - pos), dtype=dtype), ], dim=1, ) - logits, new_k_caches, new_v_caches = module( + results = module( tmp_token_list, atten_mask, ) + if len(results) == 3: + logits, new_k_caches, new_v_caches = results + elif len(results) == 1: + logits = results 
token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item()) pos += 1 @@ -217,8 +255,9 @@ def calibrate( user_prompts, module: torch.fx.GraphModule, tokenizer, + ar_len=1, max_seq_len=512, - kv_updator=smart_mask_updator, + kv_updater=smart_mask_updater, use_i64_token=False, ): if len(example_inputs) == 2: @@ -236,8 +275,9 @@ def calibrate( user_prompts, module, tokenizer, + ar_len, max_seq_len, - updator=kv_updator, + updater=kv_updater, use_i64_token=use_i64_token, ) else: @@ -268,56 +308,36 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): # shape of k caches and v caches kv_cache_shape = { - # single head, kv mode input + # single head, kv input (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), - # single head, kv mode output - (self.llama_meta["get_head_dim"], 1), - (1, self.llama_meta["get_head_dim"]), - # single head, bert mode - (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"] - 1), - (self.llama_meta["get_max_seq_len"] - 1, self.llama_meta["get_head_dim"]), + # single head, kv output + (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), } io_shape = { - # kv mode + # logit output ( self.llama_meta["get_max_batch_size"], - 1, - self.llama_meta["get_vocab_size"], - ), - # bert mode - ( - self.llama_meta["get_max_batch_size"], - self.llama_meta["get_max_seq_len"] - 1, + self.llama_meta["get_ar_len"], self.llama_meta["get_vocab_size"], ), } atten_mask_shape = { - # kv mode - (self.llama_meta["get_max_batch_size"], self.llama_meta["get_max_seq_len"]), - # bert mode ( - self.llama_meta["get_max_seq_len"] - 1, - self.llama_meta["get_max_seq_len"] - 1, + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_max_seq_len"], ), } freq_shape = { - # kv mode - (1, self.llama_meta["get_head_dim"] // 2), - # bert mode - ( - self.llama_meta["get_max_seq_len"] - 1, - self.llama_meta["get_head_dim"] // 2, - ), + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"] // 2), } freq_op = { - # kv mode exir_ops.edge.aten.select.int, - # bert mode - exir_ops.edge.aten.slice_copy.Tensor, } for n in gm.graph.nodes: @@ -376,8 +396,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): args.prompt, fx_graph_module, tokenizer=tokenizer, + ar_len=self.llama_meta["get_ar_len"], max_seq_len=self.llama_meta["get_max_seq_len"], - kv_updator=args.kv_updator, + kv_updater=args.kv_updater, use_i64_token=args.embedding_quantize is not None, ) @@ -467,12 +488,14 @@ def compile(args, pte_filename, tokenizer): kv_config = ModelArgs(**json.load(f)) # TODO: support batch inputs if necessary kv_config.max_batch_size = 1 - kv_config.max_seq_len = args.kv_seq_len + kv_config.max_seq_len = args.max_seq_len kv_config.use_kv_cache = True prefill_config = copy.copy(kv_config) - prefill_config.max_seq_len = args.prefill_seq_len - prefill_config.use_kv_cache = False + prefill_config.max_seq_len = args.max_seq_len + prefill_config.use_kv_cache = ( + False if args.max_seq_len == args.prefill_ar_len else True + ) state_dict = torch.load( args.checkpoint, weights_only=True, map_location="cpu", mmap=True @@ -484,27 +507,29 @@ def compile(args, pte_filename, tokenizer): if args.model_mode == "kv": llama_instance_list.append( LlamaModel( - kv_config, output_new_cache_only=True, use_i64_token=use_i64_token - ) - ) - elif args.model_mode == "prefill": - 
llama_instance_list.append( - LlamaModel( - prefill_config, - output_new_cache_only=False, + kv_config, + ar_len=1, + output_new_cache_only=True, + output_cache=True, use_i64_token=use_i64_token, ) ) elif args.model_mode == "hybrid": llama_instance_list.append( LlamaModel( - kv_config, output_new_cache_only=True, use_i64_token=use_i64_token + kv_config, + ar_len=1, + output_new_cache_only=True, + output_cache=True, + use_i64_token=use_i64_token, ) ) llama_instance_list.append( LlamaModel( prefill_config, - output_new_cache_only=False, + ar_len=args.prefill_ar_len, + output_new_cache_only=True, + output_cache=True, use_i64_token=use_i64_token, ) ) @@ -606,7 +631,7 @@ def compile(args, pte_filename, tokenizer): start_lowering_ts = time.time() quant_attrs = None - if args.model_mode in ["kv", "prefill"]: + if args.model_mode in ["kv"]: llama_instance_list[0].lowering_modules( args.artifact, fixed_point_type, @@ -783,12 +808,10 @@ def compile(args, pte_filename, tokenizer): def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" - if args.model_mode == "prefill": + if args.model_mode == "kv": eval_mode = 0 - elif args.model_mode == "kv": - eval_mode = 1 elif args.model_mode == "hybrid": - eval_mode = 2 + eval_mode = 1 else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") @@ -807,7 +830,7 @@ def post_process(): with open(f"{args.artifact}/outputs/outputs.txt", "r") as f: outputs.append(f.read()) - seq_len = args.prefill_seq_len if args.model_mode == "prefill" else args.kv_seq_len + seq_len = args.max_seq_len runner_args = " ".join( [ f'--prompt "{args.prompt}"', @@ -824,9 +847,9 @@ def post_process(): # x86 emulator is intended for CI and not performance. Check only the first few tokens. seq_len = min(seq_len, 16) - if args.kv_updator == smart_mask_updator: + if args.kv_updater == smart_mask_updater: logging.warning( - "x86 only support ShiftPointer, overwrite kv_updator to ShiftPointer" + "x86 only support ShiftPointer, overwrite kv_updater to ShiftPointer" ) qnn_sdk = os.getenv("QNN_SDK_ROOT") @@ -839,7 +862,7 @@ def post_process(): f"--model_path {pte_path}", f"--seq_len {seq_len}", f"--output_path {args.artifact}/outputs/outputs.txt", - f"--kv_updator ShiftPointer", + f"--kv_updater ShiftPointer", runner_args, ] ) @@ -859,7 +882,7 @@ def post_process(): f"--model_path {pte_filename}.pte", f"--seq_len {seq_len}", "--output_path outputs/outputs.txt", - f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}", + f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}", runner_args, ] ) @@ -998,28 +1021,28 @@ def _build_parser(): parser.add_argument( "--model_mode", - help="Export and inference prefill mode, kv mode or hybrid mode", + help="Export and inference kv mode or hybrid mode", default="kv", - choices=["prefill", "kv", "hybrid"], + choices=["kv", "hybrid"], type=str, ) parser.add_argument( - "--prefill_seq_len", - help="Ouput sequence length for llama. Use this option for prefill or hybrid mode", - default=32, + "--max_seq_len", + help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.", + default=512, type=int, ) parser.add_argument( - "--kv_seq_len", - help="Ouput sequence length for llama. 
Use this option for kv or hybrid mode", - default=512, + "--prefill_ar_len", + help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.", + default=32, type=int, ) parser.add_argument( - "--kv_updator", + "--kv_updater", help="Choose how to update kv cache during runtime", choices=["smart_mask", "shift_pointer"], default="smart_mask", @@ -1045,12 +1068,10 @@ def export_llama(args) -> None: if args.model_mode == "kv": pte_filename = "kv_llama_qnn" - elif args.model_mode == "prefill": - pte_filename = "prefill_llama_qnn" elif args.model_mode == "hybrid": assert ( - args.kv_seq_len >= args.prefill_seq_len - ), "Please ensure kv_seq_len is >= prefill_seq_len" + args.max_seq_len >= args.prefill_ar_len + ), "Please ensure max_seq_len is >= prefill_ar_len" pte_filename = "hybrid_llama_qnn" else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") @@ -1073,13 +1094,13 @@ def export_llama(args) -> None: else: raise RuntimeError(f"Unknown llama_model: {args.llama_model}.") - if args.kv_updator == "smart_mask": + if args.kv_updater == "smart_mask": args.shared_buffer = True - args.kv_updator = smart_mask_updator - elif args.kv_updator == "shift_pointer": - args.kv_updator = shift_pointer_updator + args.kv_updater = smart_mask_updater + elif args.kv_updater == "shift_pointer": + args.kv_updater = shift_pointer_updater else: - exit(f"Using an unkown kv update {args.kv_updator}") + exit(f"Using an unkown kv update {args.kv_updater}") if args.pre_gen_pte: quant_attrs = json.load( diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 253abc9578c..09cc7504224 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -153,10 +153,7 @@ def forward_sha( y = y.reshape(bsz, seq_len, -1) if self.output_new_cache_only: - if k_caches and v_caches: - return y, k, v - # batch_prefill mode. 
Consider to remove, it's not really used - return y, k[-1], v[-1] + return y, k, v return y, kh, vh @@ -298,7 +295,12 @@ def forward( class LlamaModel(nn.Module): def __init__( - self, config: ModelArgs, output_new_cache_only=True, use_i64_token=False + self, + config: ModelArgs, + ar_len=1, + output_new_cache_only=True, + output_cache=True, + use_i64_token=False, ): super().__init__() self.dim = config.dim @@ -311,8 +313,10 @@ def __init__( self.vocab_size = config.vocab_size self.rope_freq_base = config.rope_freq_base self.use_kv_cache = config.use_kv_cache + self.ar_len = ar_len self.output_new_cache_only = output_new_cache_only self.use_i64_token = use_i64_token + self.output_cache = output_cache self.layers = nn.ModuleList( [ @@ -359,10 +363,10 @@ def forward( output_v_cache = [] # following tensors should be invariant across batches freqs_cos = ( - self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos[:-1] + self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos ) freqs_sin = ( - self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin[:-1] + self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin ) hidden_states = self.tok_embeddings(tokens) @@ -388,19 +392,36 @@ def forward( hidden_states = self.norm(hidden_states) logits = self.output(hidden_states) - return logits, output_k_cache, output_v_cache + if self.output_cache: + return logits, output_k_cache, output_v_cache + return logits def get_example_inputs(self, use_kv_cache=True): dtype = torch.int64 if self.use_i64_token else torch.int32 - if use_kv_cache: - tokens = torch.randint( - self.vocab_size, (self.max_batch_size, 1), dtype=dtype - ) + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, self.ar_len), dtype=dtype + ) - pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + atten_mask = torch.full((self.ar_len, self.ar_len), torch.tensor(-255.0)) + mask_cond = torch.arange(atten_mask.size(-1)) + atten_mask.masked_fill_( + mask_cond < (mask_cond + 1).view(atten_mask.size(-1), 1), 0 + ) + if self.max_seq_len != self.ar_len: + atten_mask = torch.cat( + [ + torch.ones(self.ar_len, self.max_seq_len - self.ar_len) * -255.0, + atten_mask, + ], + dim=-1, + ) + atten_mask = atten_mask[None, :, :].expand( + self.max_batch_size, self.ar_len, self.max_seq_len + ) + if use_kv_cache: + pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32) k_cache, v_cache = [], [] - atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) - atten_mask[:, -1] = 0 + for _ in range(self.n_layers): for _ in range(self.n_kv_heads): # transpose first to decrease the runtime efforts @@ -408,13 +429,13 @@ def get_example_inputs(self, use_kv_cache=True): torch.zeros( self.max_batch_size, self.head_dim, - self.max_seq_len - 1, + self.max_seq_len - self.ar_len, ) ) v_cache.append( torch.zeros( self.max_batch_size, - self.max_seq_len - 1, + self.max_seq_len - self.ar_len, self.head_dim, ) ) @@ -426,10 +447,6 @@ def get_example_inputs(self, use_kv_cache=True): v_cache, ) - max_promp = self.max_seq_len - 1 - tokens = torch.arange(0, max_promp, 1, dtype=dtype).unsqueeze(0) - atten_mask = torch.triu(torch.rand((max_promp, max_promp)), 1) - atten_mask[atten_mask != 0] = -255 return ( tokens, atten_mask, @@ -438,6 +455,7 @@ def get_example_inputs(self, use_kv_cache=True): def get_metadata(self): # TODO: modify this when enabling LLAMA 7B return { + "get_ar_len": self.ar_len, "get_bos_id": 1, "get_eos_id": 2, "get_dim": self.dim, diff --git 
a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 1bc90a11f9d..0a1635223e6 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -48,11 +48,11 @@ DEFINE_int32( DEFINE_int32( eval_mode, 1, - "0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)"); + "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)"); DEFINE_double(logits_scale, 0.0, "Logits scale"); DEFINE_int32(logits_offset, 0, "Logits offset"); DEFINE_string( - kv_updator, + kv_updater, "How to update kv cache. Choose between SmartMask and ShiftPointer", "SmartMask"); @@ -67,7 +67,7 @@ int main(int argc, char** argv) { FLAGS_logits_offset, FLAGS_temperature, FLAGS_eval_mode, - FLAGS_kv_updator); + FLAGS_kv_updater); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp index badaea0ca73..cfa3b392894 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp @@ -54,7 +54,10 @@ std::vector IoMgrBase::get_output_tensors( ShiftPointerIoMgr::ShiftPointerIoMgr( std::vector>& modules, + int32_t context_len, + int32_t prefill_ar_len, int32_t prefill_cache_len, + int32_t kv_ar_len, int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, @@ -66,7 +69,10 @@ ShiftPointerIoMgr::ShiftPointerIoMgr( const bool use_int64_token) : IoMgrBase(modules), shard_layers_({num_layers}), + context_len_(context_len), + kv_ar_len_(kv_ar_len), kv_cache_len_(kv_cache_len), + prefill_ar_len_(prefill_ar_len), prefill_cache_len_(prefill_cache_len), vocab_size_(vocab_size), num_layers_(num_layers), @@ -75,7 +81,8 @@ ShiftPointerIoMgr::ShiftPointerIoMgr( eval_mode_(eval_mode), prefill_forward_name_(prefill_forward_name), kv_forward_name_(kv_forward_name), - use_int64_token_(use_int64_token) { + use_int64_token_(use_int64_token), + is_bert_(prefill_cache_len_ == 0) { if (!prefill_forward_name_.empty()) { input_tensors_[prefill_forward_name_] = std::vector>(modules.size()); @@ -113,15 +120,14 @@ void ShiftPointerIoMgr::init_io() { IO* ptr = static_cast(data_ptr_.get()); std::memset(ptr, 0, sizeof(IO)); - int32_t max_cache_len = std::max(kv_cache_len_, prefill_cache_len_); - int32_t k_in_size = (head_dim_ + 1) * max_cache_len; - int32_t v_cache_size = (num_heads_ + 1) * max_cache_len * head_dim_; - int32_t k_cache_out_size = num_heads_ * head_dim_; - if (eval_mode_ == EvalMode::kHybrid || eval_mode_ == EvalMode::kPrefill) { - k_cache_out_size *= prefill_cache_len_; - } + int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_); + int32_t k_in_size = (head_dim_ + 1) * kv_cache_len_; + // Use context length to prevent exceeding the range when the AR-N model + // updates the last block in hybrid mode. + int32_t v_cache_size = (num_heads_ + 1) * context_len_ * head_dim_; + int32_t k_cache_out_size = num_heads_ * max_ar_len * head_dim_; - // Init kv vector shape, general enough to be shared across all 3 modes. + // Init kv vector shape, general enough to be shared across all modes. 
ptr->k_cache_out.reserve(num_layers_); ptr->v_cache.reserve(num_layers_); for (int layer = 0; layer < num_layers_; layer++) { @@ -130,14 +136,15 @@ void ShiftPointerIoMgr::init_io() { } auto init_prefill = [&]() { - ptr->prefill_input_toks.resize(prefill_cache_len_); - ptr->prefill_atten_mask.resize(prefill_cache_len_ * prefill_cache_len_); - ptr->prefill_logits.resize(prefill_cache_len_ * vocab_size_); + ptr->prefill_input_toks.resize(prefill_ar_len_, 0); + ptr->prefill_input_pos.resize(prefill_ar_len_, 0); + ptr->prefill_attention_mask.resize((prefill_ar_len_ * context_len_), 0); + ptr->prefill_logits.resize(prefill_ar_len_ * vocab_size_); }; auto init_kv = [&]() { - ptr->kv_logits.resize(vocab_size_); - ptr->kv_attention_mask.resize((kv_cache_len_ + 1), 0); + ptr->kv_logits.resize(kv_ar_len_ * vocab_size_); + ptr->kv_attention_mask.resize((kv_ar_len_ * context_len_), 0); ptr->k_cache.reserve(num_layers_); for (int layer = 0; layer < num_layers_; layer++) { ptr->k_cache.emplace_back(); @@ -149,9 +156,6 @@ void ShiftPointerIoMgr::init_io() { }; switch (eval_mode_) { - case EvalMode::kPrefill: - init_prefill(); - break; case EvalMode::kKVCached: init_kv(); break; @@ -177,37 +181,38 @@ void ShiftPointerIoMgr::prepare_kv_io( IO* ptr = static_cast(data_ptr_.get()); // [I]: input_tokens - Result input_tok = methods_meta[0]->input_tensor_meta(0); - input_tok_ = std::make_unique( - input_tok->scalar_type(), - input_tok->sizes().size(), - const_cast(input_tok->sizes().data()), - &ptr->input_tok, - const_cast(input_tok->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(input_tok_.get()); + Result kv_input_toks = methods_meta[0]->input_tensor_meta(0); + kv_input_toks_ = std::make_unique( + kv_input_toks->scalar_type(), + kv_input_toks->sizes().size(), + const_cast(kv_input_toks->sizes().data()), + &ptr->kv_input_toks, + const_cast(kv_input_toks->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_input_toks_.get()); // [I]: atten_mask - Result atten_mask = methods_meta[0]->input_tensor_meta(1); - attention_mask_ = std::make_unique( - atten_mask->scalar_type(), - atten_mask->sizes().size(), - const_cast(atten_mask->sizes().data()), + Result kv_attention_mask = methods_meta[0]->input_tensor_meta(1); + kv_attention_mask_ = std::make_unique( + kv_attention_mask->scalar_type(), + kv_attention_mask->sizes().size(), + const_cast(kv_attention_mask->sizes().data()), ptr->kv_attention_mask.data(), - const_cast(atten_mask->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(attention_mask_.get()); + const_cast( + kv_attention_mask->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_attention_mask_.get()); // [I]: input_pos - Result input_pos = methods_meta[0]->input_tensor_meta(2); - input_pos_ = std::make_unique( - input_pos->scalar_type(), - input_pos->sizes().size(), - const_cast(input_pos->sizes().data()), - &ptr->input_pos, - const_cast(input_pos->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(input_pos_.get()); + Result kv_input_pos = methods_meta[0]->input_tensor_meta(2); + kv_input_pos_ = std::make_unique( + kv_input_pos->scalar_type(), + kv_input_pos->sizes().size(), + const_cast(kv_input_pos->sizes().data()), + &ptr->kv_input_pos, + const_cast(kv_input_pos->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_input_pos_.get()); // [I] kv_cache - int index = 3; // bypass input_tokens, input_pos, atten_mask + int index = 3; // bypass input_tokens, atten_mask, input_pos for (int 
offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { @@ -304,7 +309,7 @@ void ShiftPointerIoMgr::prepare_prefill_io( IO* ptr = static_cast(data_ptr_.get()); - // [I]: pre_input_tokens + // [I]: prefill_input_tokens Result prefill_input_toks = methods_meta[0]->input_tensor_meta(0); prefill_input_toks_ = std::make_unique( prefill_input_toks->scalar_type(), @@ -314,25 +319,81 @@ void ShiftPointerIoMgr::prepare_prefill_io( const_cast( prefill_input_toks->dim_order().data())); input_tensors_[prefill_forward_name_][0].push_back(prefill_input_toks_.get()); - // [I]: prefill_attn_mask - for (int i = 0; i < prefill_cache_len_; ++i) { - for (int j = 0; j < prefill_cache_len_; ++j) { - if (i < j) { - ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0; - } else { - ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 65535; + // [I]: prefill_attention_mask + for (int i = 0; i < prefill_ar_len_; ++i) { + for (int j = 0, + offset = i * context_len_ + (context_len_ - prefill_ar_len_); + j < prefill_ar_len_; + ++j) { + if (i >= j) { + ptr->prefill_attention_mask[j + offset] = 65535; } } } - Result prefill_atten_mask = methods_meta[0]->input_tensor_meta(1); - prefill_attn_mask_ = std::make_unique( - prefill_atten_mask->scalar_type(), - prefill_atten_mask->sizes().size(), - const_cast(prefill_atten_mask->sizes().data()), - ptr->prefill_atten_mask.data(), + Result prefill_attention_mask = + methods_meta[0]->input_tensor_meta(1); + prefill_attention_mask_ = std::make_unique( + prefill_attention_mask->scalar_type(), + prefill_attention_mask->sizes().size(), + const_cast( + prefill_attention_mask->sizes().data()), + ptr->prefill_attention_mask.data(), const_cast( - prefill_atten_mask->dim_order().data())); - input_tensors_[prefill_forward_name_][0].push_back(prefill_attn_mask_.get()); + prefill_attention_mask->dim_order().data())); + input_tensors_[prefill_forward_name_][0].push_back( + prefill_attention_mask_.get()); + + if (!is_bert_) { + // [I]: prefill_input_pos + Result prefill_input_pos = + methods_meta[0]->input_tensor_meta(2); + prefill_input_pos_ = std::make_unique( + prefill_input_pos->scalar_type(), + prefill_input_pos->sizes().size(), + const_cast(prefill_input_pos->sizes().data()), + ptr->prefill_input_pos.data(), + const_cast( + prefill_input_pos->dim_order().data())); + input_tensors_[prefill_forward_name_][0].push_back( + prefill_input_pos_.get()); + + // [I] kv_cache + int index = 3; // bypass input_tokens, atten_mask, input_pos + // Add prefill offset to align the v_out pointer with the decode model. + for (int offset = 0, + shard_index = 0, + v_stride = kv_cache_len_ * head_dim_, + prefill_offset = (kv_cache_len_ - prefill_cache_len_) * head_dim_; + shard_index < modules_.size(); + offset += shard_layers_[shard_index], shard_index++) { + for (int cache_group = 0; cache_group < 2; ++cache_group) { + for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { + for (int head = 0; head < num_heads_; ++head, ++index) { + Result kv_cache = + methods_meta[shard_index]->input_tensor_meta(index); + std::vector>& cache = + (cache_group == 0 ? k_cache_in_[prefill_forward_name_] + : v_cache_in_[prefill_forward_name_]); + void* cache_ptr = (cache_group == 0) + ? 
static_cast(ptr->k_cache[layer + offset][head].data()) + : static_cast( + ptr->v_cache[layer + offset].data() + head * v_stride + + prefill_offset); + + cache.emplace_back(std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast( + kv_cache->dim_order().data()))); + input_tensors_[prefill_forward_name_][shard_index].push_back( + cache.back().get()); + } + } + } + } + } // [O]: logits int logit_index = 0; Result logits = @@ -348,18 +409,11 @@ void ShiftPointerIoMgr::prepare_prefill_io( // [O] kv_cache int index = 1; - // prefill_k_stride should be equal to prefill_v_stride in prefill mode. // In hybrid mode, we use kv mode cache len for v stride since we want to // update prefill's result onto kv modes input. - int32_t prefill_k_stride = prefill_cache_len_ * head_dim_; - int32_t prefill_v_stride = - std::max(prefill_cache_len_, kv_cache_len_) * head_dim_; + int32_t prefill_k_stride = prefill_ar_len_ * head_dim_; + int32_t prefill_v_stride = kv_cache_len_ * head_dim_; - if (eval_mode_ == EvalMode::kPrefill) { - ET_CHECK_MSG( - prefill_k_stride == prefill_v_stride, - "prefill_k_stride should be equal to prefill_v_stride"); - } for (int offset = 0, shard_index = 0; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -397,13 +451,11 @@ void ShiftPointerIoMgr::update_prefill_to_kv_io( int64_t pos, std::vector>& output_tensors) { ET_CHECK_MSG(kv_cache_len_ != 0, "k_cache_len_ should not equal to 0"); - ET_CHECK_MSG( - prefill_cache_len_ != 0, "prefill_cache_len_ should not equal to 0"); IO* ptr = static_cast(data_ptr_.get()); - ptr->input_tok = + ptr->kv_input_toks = use_int64_token_ ? cur_token : static_cast(cur_token); - ptr->input_pos = static_cast(pos); + ptr->kv_input_pos = static_cast(pos); // If prompt len is 30, prefill will handle to pos = 30. // At this point, pos should be 31. for (int i = 0; i < pos + 1; i++) { @@ -435,17 +487,29 @@ void ShiftPointerIoMgr::update_prefill_to_kv_io( } } + // Update k_cache std::vector>& k_cache_in = k_cache_in_[kv_forward_name_]; std::vector>& k_cache_out = k_cache_out_[prefill_forward_name_]; + // copy from last to prevent from overwriting values + size_t copied_size = pos * sizeof(uint8_t); for (int i = 0; i < k_cache_in.size(); ++i) { uint8_t* ptr_in = k_cache_in[i]->mutable_data(); - const uint8_t* ptr_out = k_cache_out[i]->data(); - for (size_t j = 0, offset = kv_cache_len_; j < head_dim_; - ++j, offset += kv_cache_len_) { - for (int k = 0, k_stride = j * prefill_cache_len_; k < pos; k++) { - ptr_in[offset + k] = ptr_out[k_stride + k]; + if (is_bert_) { + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = kv_cache_len_; j < head_dim_; + ++j, offset += kv_cache_len_) { + for (int k = 0, k_stride = j * prefill_ar_len_; k < pos; k++) { + ptr_in[offset + k] = ptr_out[k_stride + k]; + } + } + } else { + for (int j = head_dim_; j > -1; --j) { + memcpy( + ptr_in + j * kv_cache_len_, + ptr_in + j * prefill_cache_len_, + copied_size); } } k_cache_in[i]->set_data(ptr_in + pos); @@ -458,10 +522,10 @@ void ShiftPointerIoMgr::update_kv_io( std::vector>& output_tensors) { IO* ptr = static_cast(data_ptr_.get()); // update input_tok - ptr->input_tok = + ptr->kv_input_toks = use_int64_token_ ? 
cur_token : static_cast(cur_token); // update position_ids - ptr->input_pos = static_cast(pos); + ptr->kv_input_pos = static_cast(pos); // update causal mask for next token ptr->kv_attention_mask[kv_cache_len_ - pos] = 65535; @@ -505,47 +569,101 @@ void ShiftPointerIoMgr::update_prefill_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { + (void)cur_token; (void)output_tensors; - IO* ptr = static_cast(data_ptr_.get()); - // Support CPU 4-bit embedding, which requires int64 input. - // However, for QNN embedding, only int32 input is needed. - // Therefore, we need to cast to the correct type to write the data. - if (use_int64_token_) { - ptr->prefill_input_toks[pos] = cur_token; - } else { - int32_t* prefill_input_toks_ptr = - reinterpret_cast(ptr->prefill_input_toks.data()); - prefill_input_toks_ptr[pos] = static_cast(cur_token); + + if (!is_bert_) { + // update v_cache + auto& v_cache_in = v_cache_in_[prefill_forward_name_]; + auto& v_cache_out = v_cache_out_[prefill_forward_name_]; + for (int i = 0; i < v_cache_in.size(); i++) { + v_cache_in[i]->set_data( + v_cache_in[i]->mutable_data() + prefill_ar_len_ * head_dim_); + v_cache_out[i]->set_data( + v_cache_out[i]->mutable_data() + + prefill_ar_len_ * head_dim_); + } + + for (int shard = 0; shard < output_tensors.size(); shard++) { + for (int index = 0; index < output_tensors[shard].size(); index++) { + ET_CHECK_MSG( + modules_[shard]->set_output( + prefill_forward_name_, output_tensors[shard][index], index) == + Error::Ok, + "failed to set output tensor for module %d's %d'th output " + "while updating kv_cache output tensors", + shard, + index); + } + } + + auto& k_cache_in = k_cache_in_[prefill_forward_name_]; + auto& k_cache_out = k_cache_out_[prefill_forward_name_]; + // update k_cache by single thread, this part is cpu cache sensitive + for (int i = 0; i < k_cache_in.size(); ++i) { + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = prefill_cache_len_; j < head_dim_; + ++j, offset += prefill_cache_len_) { + for (int k = 0, k_stride = j * prefill_ar_len_; k < prefill_ar_len_; + k++) { + ptr_in[offset + k] = ptr_out[k_stride + k]; + } + } + k_cache_in[i]->set_data(ptr_in + prefill_ar_len_); + } } } void ShiftPointerIoMgr::fill_prefill_toks( + int64_t start_pos, std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); - for (int i = 0; i < prompt_tokens.size(); i++) { - // Support CPU 4-bit embedding, which requires int64 input. - // However, for QNN embedding, only int32 input is needed. - // Therefore, we need to cast to the correct type to write the data. - if (use_int64_token_) { - ptr->prefill_input_toks[i] = prompt_tokens[i]; - } else { - int32_t* prefill_input_toks_ptr = - reinterpret_cast(ptr->prefill_input_toks.data()); - prefill_input_toks_ptr[i] = static_cast(prompt_tokens[i]); + for (int i = 0; i < prefill_ar_len_; i++) { + if (!is_bert_) { + ptr->prefill_input_pos[i] = start_pos + i; + } + + if (start_pos + i < prompt_tokens.size()) { + // Support CPU 4-bit embedding, which requires int64 input. + // However, for QNN embedding, only int32 input is needed. + // Therefore, we need to cast to the correct type to write the data. 
+ if (use_int64_token_) { + ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i]; + } else { + int32_t* prefill_input_toks_ptr = + reinterpret_cast(ptr->prefill_input_toks.data()); + prefill_input_toks_ptr[i] = + static_cast(prompt_tokens[start_pos + i]); + } + } + if (start_pos >= prefill_ar_len_) { + for (int j = 0, + offset = i * context_len_ + + (context_len_ - prefill_ar_len_ - start_pos); + j < prefill_ar_len_; + ++j) { + ptr->prefill_attention_mask[offset + j] = 65535; + } } } } void ShiftPointerIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { IO* ptr = static_cast(get_mutable_ptr()); - ptr->input_tok = + ptr->kv_input_toks = use_int64_token_ ? cur_token : static_cast(cur_token); + ptr->kv_input_pos = static_cast(pos); + ; ptr->kv_attention_mask[kv_cache_len_] = 65535; } SmartMaskIoMgr::SmartMaskIoMgr( std::vector>& modules, + int32_t context_len, + int32_t prefill_ar_len, int32_t prefill_cache_len, + int32_t kv_ar_len, int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, @@ -557,7 +675,10 @@ SmartMaskIoMgr::SmartMaskIoMgr( const bool use_int64_token) : IoMgrBase(modules), shard_layers_({num_layers}), + context_len_(context_len), + kv_ar_len_(kv_ar_len), kv_cache_len_(kv_cache_len), + prefill_ar_len_(prefill_ar_len), prefill_cache_len_(prefill_cache_len), vocab_size_(vocab_size), num_layers_(num_layers), @@ -566,12 +687,17 @@ SmartMaskIoMgr::SmartMaskIoMgr( eval_mode_(eval_mode), prefill_forward_name_(prefill_forward_name), kv_forward_name_(kv_forward_name), - use_int64_token_(use_int64_token) { + use_int64_token_(use_int64_token), + is_bert_(prefill_cache_len == 0) { if (!prefill_forward_name_.empty()) { input_tensors_[prefill_forward_name_] = std::vector>(modules.size()); output_tensors_[prefill_forward_name_] = std::vector>(modules.size()); + k_cache_in_[prefill_forward_name_] = + std::vector>(); + v_cache_in_[prefill_forward_name_] = + std::vector>(); k_cache_out_[prefill_forward_name_] = std::vector>(); v_cache_out_[prefill_forward_name_] = @@ -597,20 +723,20 @@ SmartMaskIoMgr::SmartMaskIoMgr( } std::unordered_map SmartMaskIoMgr::get_io_elements() { - size_t cache_len = std::max(kv_cache_len_, prefill_cache_len_); - size_t cache_in_ele = num_layers_ * num_heads_ * head_dim_ * cache_len; - size_t cache_out_ele = num_layers_ * num_heads_ * head_dim_; + int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_); + size_t cache_in_ele = num_layers_ * num_heads_ * head_dim_ * kv_cache_len_; + size_t cache_out_ele = num_layers_ * num_heads_ * head_dim_ * max_ar_len; return std::unordered_map{ - {"input_tok_ele", 1}, - {"input_pos_ele", 1}, + {"kv_input_toks_ele", kv_ar_len_}, + {"kv_input_pos_ele", kv_ar_len_}, {"cache_in_ele", cache_in_ele}, {"cache_out_ele", cache_out_ele}, - // 1 for the input prompt - {"atten_mask_ele", cache_len + 1}, - {"kv_logits_ele", vocab_size_}, - {"prefill_input_toks_ele", prefill_cache_len_}, - {"prefill_atten_mask_ele", prefill_cache_len_ * prefill_cache_len_}, - {"prefill_logits_ele", prefill_cache_len_ * vocab_size_}}; + {"kv_attention_mask_ele", kv_ar_len_ * context_len_}, + {"kv_logits_ele", kv_ar_len_ * vocab_size_}, + {"prefill_input_toks_ele", prefill_ar_len_}, + {"prefill_input_pos_ele", prefill_ar_len_}, + {"prefill_attention_mask_ele", prefill_ar_len_ * context_len_}, + {"prefill_logits_ele", prefill_ar_len_ * vocab_size_}}; } std::unordered_map SmartMaskIoMgr::get_io_bytes() { @@ -623,21 +749,23 @@ std::unordered_map SmartMaskIoMgr::get_io_bytes() { byte % static_cast(alignment)); }; return std::unordered_map{ - 
{"input_tok_bytes", - align(element_map["input_tok_ele"] * sizeof(int32_t))}, - {"input_pos_bytes", - align(element_map["input_pos_ele"] * sizeof(int32_t))}, + {"kv_input_toks_bytes", + align(element_map["kv_input_toks_ele"] * sizeof(int32_t))}, + {"kv_input_pos_bytes", + align(element_map["kv_input_pos_ele"] * sizeof(int32_t))}, {"cache_in_bytes", align(element_map["cache_in_ele"] * sizeof(uint8_t))}, {"cache_out_bytes", align(element_map["cache_out_ele"] * sizeof(uint8_t))}, - {"atten_mask_bytes", - align(element_map["atten_mask_ele"] * sizeof(uint16_t))}, + {"kv_attention_mask_bytes", + align(element_map["kv_attention_mask_ele"] * sizeof(uint16_t))}, {"kv_logits_bytes", align(element_map["kv_logits_ele"] * sizeof(uint16_t))}, {"prefill_input_toks_bytes", align(element_map["prefill_input_toks_ele"] * sizeof(int32_t))}, - {"prefill_atten_mask_bytes", - align(element_map["prefill_atten_mask_ele"] * sizeof(uint16_t))}, + {"prefill_input_pos_bytes", + align(element_map["prefill_input_pos_ele"] * sizeof(int32_t))}, + {"prefill_attention_mask_bytes", + align(element_map["prefill_attention_mask_ele"] * sizeof(uint16_t))}, {"prefill_logits_bytes", align(element_map["prefill_logits_ele"] * sizeof(uint16_t))}}; } @@ -654,10 +782,10 @@ void SmartMaskIoMgr::IO::init_io_ptrs( for (const auto& iter : io_bytes_map) { std::string key = iter.first; size_t size = iter.second; - if (key == "input_tok_bytes") { - input_tok = reinterpret_cast(cur_ptr); - } else if (key == "input_pos_bytes") { - input_pos = reinterpret_cast(cur_ptr); + if (key == "kv_input_toks_bytes") { + kv_input_toks = reinterpret_cast(cur_ptr); + } else if (key == "kv_input_pos_bytes") { + kv_input_pos = reinterpret_cast(cur_ptr); } else if (key == "cache_in_bytes" || key == "cache_out_bytes") { auto& k_cache_ref = (key == "cache_in_bytes") ? k_cache : k_cache_out; auto& v_cache_ref = (key == "cache_in_bytes") ? 
v_cache : v_cache_out; @@ -679,14 +807,16 @@ void SmartMaskIoMgr::IO::init_io_ptrs( } } continue; - } else if (key == "atten_mask_bytes") { + } else if (key == "kv_attention_mask_bytes") { kv_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "kv_logits_bytes") { kv_logits = reinterpret_cast(cur_ptr); } else if (key == "prefill_input_toks_bytes") { prefill_input_toks = reinterpret_cast(cur_ptr); - } else if (key == "prefill_atten_mask_bytes") { - prefill_atten_mask = reinterpret_cast(cur_ptr); + } else if (key == "prefill_input_pos_bytes") { + prefill_input_pos = reinterpret_cast(cur_ptr); + } else if (key == "prefill_attention_mask_bytes") { + prefill_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "prefill_logits_bytes") { prefill_logits = reinterpret_cast(cur_ptr); } else { @@ -720,15 +850,10 @@ void SmartMaskIoMgr::init_io() { std::unordered_map io_bytes_map = get_io_bytes(); switch (eval_mode_) { - case EvalMode::kPrefill: - io_bytes_map.erase("input_tok_bytes"); - io_bytes_map.erase("input_pos_bytes"); - io_bytes_map.erase("atten_mask_bytes"); - io_bytes_map.erase("kv_logits_bytes"); - break; case EvalMode::kKVCached: io_bytes_map.erase("prefill_input_toks_bytes"); - io_bytes_map.erase("prefill_atten_mask_bytes"); + io_bytes_map.erase("prefill_input_pos_bytes"); + io_bytes_map.erase("prefill_attention_mask_bytes"); io_bytes_map.erase("prefill_logits_bytes"); break; case EvalMode::kHybrid: @@ -774,53 +899,55 @@ void SmartMaskIoMgr::prepare_kv_io( std::unordered_map io_bytes_map = get_io_bytes(); // [I]: input_tokens - Result input_tok = methods_meta[0]->input_tensor_meta(0); - input_tok_ = std::make_unique( - input_tok->scalar_type(), - input_tok->sizes().size(), - const_cast(input_tok->sizes().data()), - ptr->input_tok, - const_cast(input_tok->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(input_tok_.get()); + Result kv_input_toks = methods_meta[0]->input_tensor_meta(0); + kv_input_toks_ = std::make_unique( + kv_input_toks->scalar_type(), + kv_input_toks->sizes().size(), + const_cast(kv_input_toks->sizes().data()), + ptr->kv_input_toks, + const_cast(kv_input_toks->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_input_toks_.get()); ptr->add_custom_mem_info( - ptr->input_tok, - io_bytes_map["input_tok_bytes"], - input_tok->scalar_type(), - input_tok.get()); + ptr->kv_input_toks, + io_bytes_map["kv_input_toks_bytes"], + kv_input_toks->scalar_type(), + kv_input_toks.get()); // [I]: atten_mask - Result atten_mask = methods_meta[0]->input_tensor_meta(1); - attention_mask_ = std::make_unique( - atten_mask->scalar_type(), - atten_mask->sizes().size(), - const_cast(atten_mask->sizes().data()), + std::fill_n(ptr->kv_attention_mask, kv_ar_len_ * context_len_, 0); + Result kv_attention_mask = methods_meta[0]->input_tensor_meta(1); + kv_attention_mask_ = std::make_unique( + kv_attention_mask->scalar_type(), + kv_attention_mask->sizes().size(), + const_cast(kv_attention_mask->sizes().data()), ptr->kv_attention_mask, - const_cast(atten_mask->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(attention_mask_.get()); + const_cast( + kv_attention_mask->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_attention_mask_.get()); ptr->add_custom_mem_info( ptr->kv_attention_mask, - io_bytes_map["atten_mask_bytes"], - atten_mask->scalar_type(), - atten_mask.get()); + io_bytes_map["kv_attention_mask_bytes"], + kv_attention_mask->scalar_type(), + kv_attention_mask.get()); // [I]: input_pos - 
Result input_pos = methods_meta[0]->input_tensor_meta(2); - input_pos_ = std::make_unique( - input_pos->scalar_type(), - input_pos->sizes().size(), - const_cast(input_pos->sizes().data()), - ptr->input_pos, - const_cast(input_pos->dim_order().data())); - input_tensors_[kv_forward_name_][0].push_back(input_pos_.get()); + Result kv_input_pos = methods_meta[0]->input_tensor_meta(2); + kv_input_pos_ = std::make_unique( + kv_input_pos->scalar_type(), + kv_input_pos->sizes().size(), + const_cast(kv_input_pos->sizes().data()), + ptr->kv_input_pos, + const_cast(kv_input_pos->dim_order().data())); + input_tensors_[kv_forward_name_][0].push_back(kv_input_pos_.get()); ptr->add_custom_mem_info( - ptr->input_pos, - io_bytes_map["input_pos_bytes"], - input_pos->scalar_type(), - input_pos.get()); + ptr->kv_input_pos, + io_bytes_map["kv_input_pos_bytes"], + kv_input_pos->scalar_type(), + kv_input_pos.get()); // [I] kv_cache size_t layered_head_count = num_layers_ * num_heads_; - int index = 3; // bypass input_tokens, input_pos, atten_mask + int index = 3; // bypass input_tokens, atten_mask, input_pos for (int offset = 0, shard_index = 0; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -913,12 +1040,11 @@ void SmartMaskIoMgr::update_kv_io( int64_t pos, std::vector>& output_tensors) { IO* ptr = static_cast(data_ptr_.get()); - size_t cache_len = std::max(kv_cache_len_, prefill_cache_len_); // update input_tok - *ptr->input_tok = + *ptr->kv_input_toks = use_int64_token_ ? cur_token : static_cast(cur_token); // update position_ids - *ptr->input_pos = static_cast(pos); + *ptr->kv_input_pos = static_cast(pos); // update smart mask for previous cache ptr->kv_attention_mask[pos] = 65535; @@ -937,7 +1063,8 @@ void SmartMaskIoMgr::update_kv_io( for (int i = 0; i < k_cache_in.size(); ++i) { uint8_t* ptr_in = k_cache_in[i]->mutable_data() + pos; const uint8_t* ptr_out = k_cache_out[i]->data(); - for (size_t j = 0, offset = 0; j < head_dim_; ++j, offset += cache_len) { + for (size_t j = 0, offset = 0; j < head_dim_; + ++j, offset += kv_cache_len_) { ptr_in[offset] = ptr_out[j]; } } @@ -958,7 +1085,6 @@ void SmartMaskIoMgr::prepare_prefill_io( IO* ptr = static_cast(data_ptr_.get()); std::unordered_map io_bytes_map = get_io_bytes(); - int32_t cache_len = methods_meta[0]->input_tensor_meta(0)->sizes()[1]; // [I]: pre_input_tokens Result prefill_input_toks = methods_meta[0]->input_tensor_meta(0); prefill_input_toks_ = std::make_unique( @@ -975,30 +1101,92 @@ void SmartMaskIoMgr::prepare_prefill_io( executorch::aten::ScalarType::Int, prefill_input_toks.get()); - // [I]: prefill_attn_mask - for (int i = 0; i < cache_len; ++i) { - for (int j = 0; j < cache_len; ++j) { + // [I]: prefill_attention_mask + for (int i = 0; i < prefill_ar_len_; ++i) { + for (int j = 0, + offset = i * context_len_ + (context_len_ - prefill_ar_len_); + j < prefill_ar_len_; + ++j) { if (i < j) { - ptr->prefill_atten_mask[i * cache_len + j] = 0; + ptr->prefill_attention_mask[j + offset] = 0; } else { - ptr->prefill_atten_mask[i * cache_len + j] = 65535; + ptr->prefill_attention_mask[j + offset] = 65535; } } } - Result prefill_atten_mask = methods_meta[0]->input_tensor_meta(1); - prefill_attn_mask_ = std::make_unique( - prefill_atten_mask->scalar_type(), - prefill_atten_mask->sizes().size(), - const_cast(prefill_atten_mask->sizes().data()), - ptr->prefill_atten_mask, + Result prefill_attention_mask = + methods_meta[0]->input_tensor_meta(1); + 
prefill_attention_mask_ = std::make_unique( + prefill_attention_mask->scalar_type(), + prefill_attention_mask->sizes().size(), + const_cast( + prefill_attention_mask->sizes().data()), + ptr->prefill_attention_mask, const_cast( - prefill_atten_mask->dim_order().data())); - input_tensors_[prefill_forward_name_][0].push_back(prefill_attn_mask_.get()); + prefill_attention_mask->dim_order().data())); + input_tensors_[prefill_forward_name_][0].push_back( + prefill_attention_mask_.get()); ptr->add_custom_mem_info( - ptr->prefill_atten_mask, - io_bytes_map["prefill_atten_mask_bytes"], + ptr->prefill_attention_mask, + io_bytes_map["prefill_attention_mask_bytes"], executorch::aten::ScalarType::Bits16, - prefill_atten_mask.get()); + prefill_attention_mask.get()); + + if (!is_bert_) { + // [I]: prefill_input_pos + Result prefill_input_pos = + methods_meta[0]->input_tensor_meta(2); + prefill_input_pos_ = std::make_unique( + prefill_input_pos->scalar_type(), + prefill_input_pos->sizes().size(), + const_cast(prefill_input_pos->sizes().data()), + ptr->prefill_input_pos, + const_cast( + prefill_input_pos->dim_order().data())); + input_tensors_[prefill_forward_name_][0].push_back( + prefill_input_pos_.get()); + ptr->add_custom_mem_info( + ptr->prefill_input_pos, + io_bytes_map["prefill_input_pos_bytes"], + prefill_input_pos->scalar_type(), + prefill_input_pos.get()); + + // [I] kv_cache + size_t layered_head_count = num_layers_ * num_heads_; + int index = 3; // bypass input_tokens, atten_mask, input_pos + for (int offset = 0, shard_index = 0; shard_index < modules_.size(); + offset += shard_layers_[shard_index], shard_index++) { + for (int cache_group = 0; cache_group < 2; ++cache_group) { + for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { + for (int head = 0; head < num_heads_; ++head, ++index) { + Result kv_cache = + methods_meta[shard_index]->input_tensor_meta(index); + std::vector>& cache = + (cache_group == 0 ? k_cache_in_[prefill_forward_name_] + : v_cache_in_[prefill_forward_name_]); + uint8_t* cache_ptr = (cache_group == 0) + ? ptr->k_cache[layer + offset][head] + : ptr->v_cache[layer + offset][head]; + + cache.emplace_back(std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast( + kv_cache->dim_order().data()))); + ptr->add_custom_mem_info( + cache_ptr, + io_bytes_map["cache_in_bytes"] / layered_head_count, + kv_cache->scalar_type(), + kv_cache.get()); + input_tensors_[prefill_forward_name_][shard_index].push_back( + cache.back().get()); + } + } + } + } + } // [O]: logits int logit_index = 0; @@ -1031,8 +1219,8 @@ void SmartMaskIoMgr::prepare_prefill_io( (cache_group == 0 ? k_cache_out_[prefill_forward_name_] : v_cache_out_[prefill_forward_name_]); void* cache_ptr = (cache_group == 0) - ? ptr->k_cache[layer + offset][head] - : ptr->v_cache[layer + offset][head]; + ? 
ptr->k_cache_out[layer + offset][head] + : ptr->v_cache_out[layer + offset][head]; cache.emplace_back(std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), @@ -1042,7 +1230,7 @@ void SmartMaskIoMgr::prepare_prefill_io( kv_cache->dim_order().data()))); ptr->add_custom_mem_info( cache_ptr, - io_bytes_map["cache_in_bytes"] / layered_head_count, + io_bytes_map["cache_out_bytes"] / layered_head_count, executorch::aten::ScalarType::Byte, kv_cache.get()); output_tensors_[prefill_forward_name_][shard_index].push_back( @@ -1059,24 +1247,50 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( std::vector>& output_tensors) { IO* ptr = static_cast(data_ptr_.get()); - *ptr->input_tok = + *ptr->kv_input_toks = use_int64_token_ ? cur_token : static_cast(cur_token); - *ptr->input_pos = static_cast(pos); + *ptr->kv_input_pos = static_cast(pos); // pos means the cur_token pos for (int i = 0; i < pos; i++) { ptr->kv_attention_mask[i] = 65535; } - // Update K is enough, copy from last to prevent from overwriting values - size_t copied_size = prefill_cache_len_ * sizeof(uint8_t); - for (int l = 0; l < num_layers_; l++) { - for (int h = 0; h < num_heads_; h++) { - uint8_t* k_cache = ptr->k_cache[l][h]; - for (int hd = head_dim_ - 1; hd > -1; hd--) { - memcpy( - k_cache + (kv_cache_len_ * hd), - k_cache + (prefill_cache_len_ * hd), - copied_size); + if (is_bert_) { + // update v_cache + auto& v_cache_in = v_cache_in_[kv_forward_name_]; + auto& v_cache_out = v_cache_out_[prefill_forward_name_]; + // update v_cache by single thread, this part is cpu cache sensitive + size_t copied_size = kv_cache_len_ * head_dim_ * sizeof(uint8_t); + for (int i = 0; i < v_cache_in.size(); ++i) { + uint8_t* ptr_in = v_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = v_cache_out[i]->data(); + memcpy(ptr_in, ptr_out, copied_size); + } + + auto& k_cache_in = k_cache_in_[kv_forward_name_]; + auto& k_cache_out = k_cache_out_[prefill_forward_name_]; + for (int i = 0; i < k_cache_in.size(); ++i) { + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = 0; j < head_dim_; + ++j, offset += kv_cache_len_) { + for (size_t k = 0, k_stride = j * prefill_ar_len_; k < pos; k++) { + ptr_in[offset + k] = ptr_out[k_stride + k]; + } + } + } + } else { + // Update K is enough, copy from last to prevent from overwriting values + size_t copied_size = pos * sizeof(uint8_t); + for (int l = 0; l < num_layers_; l++) { + for (int h = 0; h < num_heads_; h++) { + uint8_t* k_cache = ptr->k_cache[l][h]; + for (int hd = head_dim_ - 1; hd > -1; hd--) { + memcpy( + k_cache + (kv_cache_len_ * hd), + k_cache + (prefill_cache_len_ * hd), + copied_size); + } } } } @@ -1087,38 +1301,71 @@ void SmartMaskIoMgr::update_prefill_io( int64_t pos, std::vector>& output_tensors) { (void)output_tensors; - IO* ptr = static_cast(data_ptr_.get()); - // Support CPU 4-bit embedding, which requires int64 input. - // However, for QNN embedding, only int32 input is needed. - // Therefore, we need to cast to the correct type to write the data. 
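The prompt processor's attention mask above is now laid out as prefill_ar_len_ rows by context_len_ columns: the trailing prefill_ar_len_ columns hold the causal triangle for the current token block, and fill_prefill_toks later unmasks the leading columns as earlier blocks land in the KV cache. The following standalone sketch reproduces that layout; build_ar_n_mask, its parameters, and the n_cached bookkeeping are illustrative names, not code from this patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Build an AR-N style prefill mask: ar_len rows by context_len columns,
// 65535 = attend, 0 = masked. n_cached is the number of prompt tokens that
// earlier AR-N blocks have already written into the KV cache
// (assumed n_cached <= context_len - ar_len).
std::vector<uint16_t> build_ar_n_mask(int ar_len, int context_len, int n_cached) {
  std::vector<uint16_t> mask(static_cast<size_t>(ar_len) * context_len, 0);
  for (int i = 0; i < ar_len; ++i) {
    // Leading columns: tokens already sitting in the KV cache are fully visible.
    for (int j = 0; j < n_cached; ++j) {
      mask[static_cast<size_t>(i) * context_len + j] = 65535;
    }
    // Trailing ar_len columns: causal triangle over the current token block.
    size_t block_offset =
        static_cast<size_t>(i) * context_len + (context_len - ar_len);
    for (int j = 0; j <= i; ++j) {
      mask[block_offset + j] = 65535;
    }
  }
  return mask;
}
```

For example, with an AR length of 32 and a context length of 128, each row has up to 96 leading cache columns plus a 32-wide causal block on the right.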
- if (use_int64_token_) { - ptr->prefill_input_toks[pos] = cur_token; - } else { - int32_t* prefill_input_toks_ptr = - reinterpret_cast(ptr->prefill_input_toks); - prefill_input_toks_ptr[pos] = static_cast(cur_token); + + if (!is_bert_) { + // update v_cache + auto& v_cache_in = v_cache_in_[prefill_forward_name_]; + auto& v_cache_out = v_cache_out_[prefill_forward_name_]; + // update v_cache by single thread, this part is cpu cache sensitive + size_t copied_size = prefill_ar_len_ * head_dim_ * sizeof(uint8_t); + for (int i = 0; i < v_cache_in.size(); ++i) { + uint8_t* ptr_in = + v_cache_in[i]->mutable_data() + pos * head_dim_; + const uint8_t* ptr_out = v_cache_out[i]->data(); + memcpy(ptr_in, ptr_out, copied_size); + } + + auto& k_cache_in = k_cache_in_[prefill_forward_name_]; + auto& k_cache_out = k_cache_out_[prefill_forward_name_]; + for (int i = 0; i < k_cache_in.size(); ++i) { + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = pos; j < head_dim_; + ++j, offset += prefill_cache_len_) { + for (size_t k = 0, k_stride = j * prefill_ar_len_; k < prefill_ar_len_; + k++) { + ptr_in[offset + k] = ptr_out[k_stride + k]; + } + } + } } } -void SmartMaskIoMgr::fill_prefill_toks(std::vector& prompt_tokens) { +void SmartMaskIoMgr::fill_prefill_toks( + int64_t start_pos, + std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); - for (int i = 0; i < prompt_tokens.size(); i++) { - // Support CPU 4-bit embedding, which requires int64 input. - // However, for QNN embedding, only int32 input is needed. - // Therefore, we need to cast to the correct type to write the data. - if (use_int64_token_) { - ptr->prefill_input_toks[i] = prompt_tokens[i]; - } else { - int32_t* prefill_input_toks_ptr = - reinterpret_cast(ptr->prefill_input_toks); - prefill_input_toks_ptr[i] = static_cast(prompt_tokens[i]); + for (int i = 0; i < prefill_ar_len_; i++) { + if (!is_bert_) { + ptr->prefill_input_pos[i] = start_pos + i; + } + + if (start_pos + i < prompt_tokens.size()) { + // Support CPU 4-bit embedding, which requires int64 input. + // However, for QNN embedding, only int32 input is needed. + // Therefore, we need to cast to the correct type to write the data. + if (use_int64_token_) { + ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i]; + } else { + int32_t* prefill_input_toks_ptr = + reinterpret_cast(ptr->prefill_input_toks); + prefill_input_toks_ptr[i] = + static_cast(prompt_tokens[start_pos + i]); + } + } + if (start_pos >= prefill_ar_len_) { + for (int j = 0, offset = i * context_len_ + (start_pos - prefill_ar_len_); + j < prefill_ar_len_; + ++j) { + ptr->prefill_attention_mask[offset + j] = 65535; + } } } } void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { IO* ptr = static_cast(get_mutable_ptr()); - *ptr->input_tok = + *ptr->kv_input_toks = use_int64_token_ ? 
cur_token : static_cast(cur_token); ptr->kv_attention_mask[kv_cache_len_] = 65535; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h index 3a59ab6924e..f1887b99280 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h @@ -23,8 +23,7 @@ namespace example { enum EvalMode { - kPrefill = 0, - kKVCached, + kKVCached = 0, kHybrid, kUnsupported, }; @@ -42,7 +41,9 @@ class IoMgrBase { const std::vector< executorch::runtime::Result>& methods_meta) = 0; - virtual void fill_prefill_toks(std::vector& prompt_tokens) = 0; + virtual void fill_prefill_toks( + int64_t start_pos, + std::vector& prompt_tokens) = 0; virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0; virtual void update_prefill_to_kv_io( int64_t cur_token, @@ -81,7 +82,10 @@ class ShiftPointerIoMgr : public IoMgrBase { public: ShiftPointerIoMgr( std::vector>& modules, + int32_t context_len, + int32_t prefill_ar_len, int32_t prefill_cache_len, + int32_t kv_ar_len, int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, @@ -101,7 +105,9 @@ class ShiftPointerIoMgr : public IoMgrBase { const std::vector< executorch::runtime::Result>& methods_meta) override; - void fill_prefill_toks(std::vector& prompt_tokens) override; + void fill_prefill_toks( + int64_t start_pos, + std::vector& prompt_tokens) override; void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override; void update_prefill_to_kv_io( int64_t cur_token, @@ -119,25 +125,26 @@ class ShiftPointerIoMgr : public IoMgrBase { std::vector>& output_tensors) override; struct IO { - int64_t input_tok; - int32_t input_pos; + int64_t kv_input_toks; + int32_t kv_input_pos; std::vector>> k_cache; std::vector> v_cache; std::vector> k_cache_out; std::vector kv_attention_mask; std::vector kv_logits; std::vector prefill_input_toks; - std::vector prefill_atten_mask; + std::vector prefill_input_pos; + std::vector prefill_attention_mask; std::vector prefill_logits; }; private: - std::unique_ptr input_tok_; - std::unique_ptr input_pos_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; + std::unique_ptr kv_input_toks_; + std::unique_ptr kv_input_pos_; + std::unique_ptr kv_attention_mask_; std::unique_ptr prefill_input_toks_; - std::unique_ptr prefill_attn_mask_; + std::unique_ptr prefill_input_pos_; + std::unique_ptr prefill_attention_mask_; std::unique_ptr prefill_logits_; std::unordered_map< std::string, @@ -157,7 +164,10 @@ class ShiftPointerIoMgr : public IoMgrBase { v_cache_out_; std::unique_ptr kv_logits_; std::vector shard_layers_; + int32_t context_len_{0}; + int32_t kv_ar_len_{0}; int32_t kv_cache_len_{0}; + int32_t prefill_ar_len_{0}; int32_t prefill_cache_len_{0}; int32_t vocab_size_; int32_t num_layers_; @@ -167,13 +177,17 @@ class ShiftPointerIoMgr : public IoMgrBase { std::string prefill_forward_name_; std::string kv_forward_name_; const bool use_int64_token_{false}; + const bool is_bert_{false}; }; class SmartMaskIoMgr : public IoMgrBase { public: SmartMaskIoMgr( std::vector>& modules, + int32_t context_len, + int32_t prefill_ar_len, int32_t prefill_cache_len, + int32_t kv_ar_len, int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, @@ -193,7 +207,9 @@ class SmartMaskIoMgr : public IoMgrBase { const std::vector< executorch::runtime::Result>& methods_meta) override; - void fill_prefill_toks(std::vector& prompt_tokens) override; + void fill_prefill_toks( + int64_t start_pos, + std::vector& 
prompt_tokens) override; void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override; void update_prefill_to_kv_io( int64_t cur_token, @@ -216,22 +232,24 @@ class SmartMaskIoMgr : public IoMgrBase { struct IO { void* shared_buffer_base; - int64_t* input_tok; - int32_t* input_pos; + int64_t* kv_input_toks; + int32_t* kv_input_pos; // layer -> head -> head_dim * seq_len std::vector> k_cache; std::vector> v_cache; // layer -> head -> head_dim std::vector> k_cache_out; std::vector> v_cache_out; - // max_seq_len + // kv_ar_len_ * context_len_ uint16_t* kv_attention_mask; - // vocab_size + // kv_ar_len_ * vocab_size uint16_t* kv_logits; + // prefill_ar_len_ int64_t* prefill_input_toks; - // prefill_cache_len_ ^ 2 - uint16_t* prefill_atten_mask; - // vocab_size * prefill_cache_len_ + int32_t* prefill_input_pos; + // prefill_ar_len_ * context_len_ + uint16_t* prefill_attention_mask; + // vocab_size * prefill_ar_len_ uint16_t* prefill_logits; size_t num_layers_; @@ -252,12 +270,12 @@ class SmartMaskIoMgr : public IoMgrBase { }; private: - std::unique_ptr input_tok_; - std::unique_ptr input_pos_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; + std::unique_ptr kv_input_toks_; + std::unique_ptr kv_input_pos_; + std::unique_ptr kv_attention_mask_; std::unique_ptr prefill_input_toks_; - std::unique_ptr prefill_attn_mask_; + std::unique_ptr prefill_input_pos_; + std::unique_ptr prefill_attention_mask_; std::unique_ptr prefill_logits_; std::unordered_map< std::string, @@ -277,7 +295,10 @@ class SmartMaskIoMgr : public IoMgrBase { v_cache_out_; std::unique_ptr kv_logits_; std::vector shard_layers_; + int32_t context_len_{0}; + int32_t kv_ar_len_{0}; int32_t kv_cache_len_{0}; + int32_t prefill_ar_len_{0}; int32_t prefill_cache_len_{0}; int32_t vocab_size_; int32_t num_layers_; @@ -287,6 +308,9 @@ class SmartMaskIoMgr : public IoMgrBase { std::string prefill_forward_name_; std::string kv_forward_name_; const bool use_int64_token_{false}; + // If the cache length is zero, it indicates a BERT model, which does not use + // position ids or KV cache inputs. 
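In the non-BERT path of update_prefill_io above, each AR-N step scatters its K/V output block back into the prompt processor's cache inputs: K is stored head-dimension-major with a row stride equal to the cache length, while V is token-major, so the patch copies K element by element and V with a single memcpy. The helper below is a minimal sketch of that scatter under those assumed layouts; it is not the patch's code, and for the byte-quantized caches a per-row memcpy is equivalent to the patch's inner K loop.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Assumed layouts (per layer and head), matching the loops in the patch:
//   k_cache_in : head_dim rows x cache_len cols, row-major (stride cache_len)
//   k_cache_out: head_dim rows x ar_len cols, row-major (stride ar_len)
//   v_cache_in : cache_len tokens x head_dim, contiguous per token
//   v_cache_out: ar_len tokens x head_dim
// pos is the column / token slot where this AR-N block starts.
void scatter_ar_block(
    uint8_t* k_cache_in,
    const uint8_t* k_cache_out,
    uint8_t* v_cache_in,
    const uint8_t* v_cache_out,
    size_t head_dim,
    size_t cache_len,
    size_t ar_len,
    size_t pos) {
  // K: one strided row copy per head-dim element, landing at column pos.
  for (size_t d = 0; d < head_dim; ++d) {
    std::memcpy(
        k_cache_in + d * cache_len + pos, k_cache_out + d * ar_len, ar_len);
  }
  // V: token-major, so the whole block is a single contiguous copy.
  std::memcpy(v_cache_in + pos * head_dim, v_cache_out, ar_len * head_dim);
}
```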
+ const bool is_bert_{false}; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 70ba25a0972..da1997a5060 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -45,7 +45,7 @@ Runner::Runner( const int32_t logits_offset, const float temperature, const int eval_mode, - const std::string& kv_updator) + const std::string& kv_updater) : n_bos_(1), n_eos_(1), tokenizer_path_(tokenizer_path), @@ -53,7 +53,7 @@ Runner::Runner( logits_offset_(logits_offset), temperature_(temperature), eval_mode_(static_cast(eval_mode)), - kv_updator_(kv_updator) { + kv_updater_(kv_updater) { for (size_t i = 0; i < models_path.size(); ++i) { modules_.push_back(std::make_shared( models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); @@ -77,10 +77,6 @@ Error Runner::load() { } switch (eval_mode_) { - case EvalMode::kPrefill: - prefill_forward_name_ = "forward"; - method_names_.emplace_back(prefill_forward_name_); - break; case EvalMode::kKVCached: kv_forward_name_ = "forward"; method_names_.emplace_back(kv_forward_name_); @@ -106,17 +102,22 @@ Error Runner::load() { } if (!prefill_forward_name_.empty()) { - // Use input tokens length to retrieve prefill cache len - // Cache len equals to prefill model seq_len - 1 - prefill_cache_len_ = get_methods_meta(prefill_forward_name_)[0] - ->input_tensor_meta(0) - ->sizes()[1]; + // Use attention mask length to retrieve prefill_ar_len and context length + // Prefill cache length equals to context_len - prefill_ar_len + auto atten_mask_meta = + get_methods_meta(prefill_forward_name_)[0]->input_tensor_meta(1); + prefill_ar_len_ = atten_mask_meta->sizes()[1]; + context_len_ = atten_mask_meta->sizes()[2]; + prefill_cache_len_ = context_len_ - prefill_ar_len_; } if (!kv_forward_name_.empty()) { - // Use k cache length to retirieve kv cache len - // Cache len equals to kv model seq_len - 1 - kv_cache_len_ = - get_methods_meta(kv_forward_name_)[0]->input_tensor_meta(3)->sizes()[2]; + // Use attention mask length to retrieve kv ar len and context length + // Cache len equals to kv model context_len - kv_ar_len + auto atten_mask_meta = + get_methods_meta(kv_forward_name_)[0]->input_tensor_meta(1); + kv_ar_len_ = atten_mask_meta->sizes()[1]; + context_len_ = atten_mask_meta->sizes()[2]; + kv_cache_len_ = context_len_ - kv_ar_len_; } // retrieve any method meta, can be either prefill or kv @@ -130,10 +131,13 @@ Error Runner::load() { executorch::aten::ScalarType::Long; ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); - if (kv_updator_ == "SmartMask") { + if (kv_updater_ == "SmartMask") { io_mgr_ = std::make_unique( modules_, + context_len_, + prefill_ar_len_, prefill_cache_len_, + kv_ar_len_, kv_cache_len_, vocab_size_, num_layers, @@ -143,10 +147,13 @@ Error Runner::load() { prefill_forward_name_, kv_forward_name_, use_int64_token_); - } else if (kv_updator_ == "ShiftPointer") { + } else if (kv_updater_ == "ShiftPointer") { io_mgr_ = std::make_unique( modules_, + context_len_, + prefill_ar_len_, prefill_cache_len_, + kv_ar_len_, kv_cache_len_, vocab_size_, num_layers, @@ -157,16 +164,13 @@ Error Runner::load() { kv_forward_name_, use_int64_token_); } else { - ET_LOG(Error, "Using an unknown updator %s", kv_updator_.c_str()); + ET_LOG(Error, "Using an unknown updater %s", kv_updater_.c_str()); } ET_LOG(Info, "creating io_memory"); // prepare io io_mgr_->init_io(); switch (eval_mode_) { - 
case EvalMode::kPrefill: - io_mgr_->prepare_prefill_io(get_methods_meta(prefill_forward_name_)); - break; case EvalMode::kKVCached: io_mgr_->prepare_kv_io(get_methods_meta(kv_forward_name_)); break; @@ -324,8 +328,7 @@ Error Runner::generate( break; } - int max_seq_len = std::max(prefill_cache_len_, kv_cache_len_) + 1; - seq_len = (seq_len > 0 && seq_len <= max_seq_len) ? seq_len : max_seq_len; + seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_; Result> encode_res = tokenizer_->encode(prompt_, n_bos_, 0); ET_CHECK_OK_OR_RETURN_ERROR( @@ -333,61 +336,46 @@ Error Runner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - ET_CHECK_MSG(num_prompt_tokens < max_seq_len, "max seq length exceeded"); ET_CHECK_MSG( num_prompt_tokens < seq_len, "sequence length exceeded - please increase the seq_len value"); - if (eval_mode_ == EvalMode::kHybrid) { - int prefill_seq_len = get_methods_meta(prefill_forward_name_)[0] - ->input_tensor_meta(0) - ->sizes()[1] + - 1; - ET_CHECK_MSG( - num_prompt_tokens < prefill_seq_len, - "For hybrid mode, please ensure prompt length(%d) is less than prefill's seq_len(%d)", - num_prompt_tokens, - prefill_seq_len); - } int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; if (token_callback) { token_callback(prompt_); } auto prefill_execute = [&](const std::string& method_name) { - io_mgr_->fill_prefill_toks(prompt_tokens); + int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_); + ET_LOG( + Info, + "Prompt Processor: total %d tokens (AR-%d * %d iters)", + num_prompt_tokens, + prefill_ar_len_, + num_iters); - pos = num_prompt_tokens - 1; - cur_token = prompt_tokens[pos]; - while (pos < seq_len - 1) { - // inference + for (int i = 0; i < num_iters; i++) { + io_mgr_->fill_prefill_toks(pos, prompt_tokens); run_model_step(method_name, inputs[method_name]); - Tensor& logits_tensor = output_tensors[method_name].back()[0]; - prev_token = cur_token; - long sample_start_time_ms = time_in_ms(); - cur_token = logitsToToken(logits_tensor, pos); - stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; - - io_mgr_->update_prefill_io(cur_token, ++pos, output_tensors[method_name]); - auto piece_res = tokenizer_->decode(prev_token, cur_token); - ET_CHECK(piece_res.ok()); - if (token_callback) { - token_callback(piece_res.get().c_str()); - } - - if (pos == num_prompt_tokens) { - stats_.first_token_ms = time_in_ms(); - stats_.prompt_eval_end_ms = time_in_ms(); - } - - if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { - ET_LOG(Info, "\nReached to the end of generation"); - break; - } - // prefill model inferences once for prompt in the hybrid mode - if (eval_mode_ == EvalMode::kHybrid) { - break; - } + io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]); + pos += prefill_ar_len_; } + Tensor& logits_tensor = output_tensors[method_name].back()[0]; + prev_token = prompt_tokens[num_prompt_tokens - 1]; + long sample_start_time_ms = time_in_ms(); + cur_token = logitsToToken( + logits_tensor, + (num_prompt_tokens + prefill_ar_len_ - 1) % prefill_ar_len_); + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; + + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + if (token_callback) { + token_callback(piece_res.get().c_str()); + } + + pos = num_prompt_tokens; + stats_.first_token_ms = time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); }; auto kv_execute = [&](const std::string& 
method_name) { @@ -429,9 +417,6 @@ Error Runner::generate( }; switch (eval_mode_) { - case EvalMode::kPrefill: - prefill_execute(prefill_forward_name_); - break; case EvalMode::kKVCached: kv_execute(kv_forward_name_); break; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index b6ba1360bff..e659ac55164 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -33,7 +33,7 @@ class Runner { const int32_t logits_offset, const float temperature, const int eval_mode, - const std::string& kv_updator); + const std::string& kv_updater); struct Stats { // Scaling factor for timestamps - in this case, we use ms. @@ -89,7 +89,10 @@ class Runner { std::string prompt_; // metadata + int32_t context_len_{0}; + int32_t prefill_ar_len_{0}; int32_t prefill_cache_len_{0}; + int32_t kv_ar_len_{0}; int32_t kv_cache_len_{0}; int32_t vocab_size_; int32_t bos_id_; @@ -111,7 +114,7 @@ class Runner { std::string kv_forward_name_; std::vector method_names_; LlamaVersion llama_version_; - std::string kv_updator_; + std::string kv_updater_; }; } // namespace example diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 720877f0555..dde6a397d9a 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -890,7 +890,7 @@ def _unsafe_adjust_original_program( # noqa: C901 del original_program._state_dict[input_target] elif input_spec.kind == InputKind.BUFFER: if input_spec.persistent: - del original_program._state_dict[input_target] + original_program._state_dict.pop(input_target, None) else: del original_program._constants[input_spec.target] elif input_spec.kind == InputKind.CONSTANT_TENSOR:
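With these runner changes, the sequence lengths come from the exported attention-mask shape (AR length and context length, with cache length as their difference), and the prompt is consumed in AR-N sized blocks rather than a single prefill pass. The snippet below is a small, self-contained illustration of the iteration count and of the logits row sampled after the last block; PromptPlan and plan_prompt are illustrative names that mirror the formulas in prefill_execute, not part of the patch.

```cpp
#include <cstdio>

// How many AR-N forward passes a prompt needs, and which logits row of the
// last block feeds the first decoded token.
struct PromptPlan {
  int num_iters;        // AR-N passes needed to consume the prompt
  int last_logit_index; // row of the final block's logits to sample from
};

PromptPlan plan_prompt(int num_prompt_tokens, int prefill_ar_len) {
  PromptPlan p;
  p.num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len);
  p.last_logit_index =
      (num_prompt_tokens + prefill_ar_len - 1) % prefill_ar_len;
  return p;
}

int main() {
  // A 100-token prompt with --prefill_ar_len 32: 4 passes (32+32+32+4 padded),
  // and the next token is sampled from logits row 3 of the last pass.
  PromptPlan p = plan_prompt(100, 32);
  std::printf("iters=%d last_logit_index=%d\n", p.num_iters, p.last_logit_index);
  return 0;
}
```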