From 38d22fbbd36a3b89c8d230edf7803a178f18d4f9 Mon Sep 17 00:00:00 2001
From: Min Guo
Date: Fri, 31 Jan 2025 16:22:28 -0800
Subject: [PATCH] support input_pos > 0 for prefill model

Summary: Test input_pos > 0 for the prefill model; not intended for landing, but for syncing with QC.

Differential Revision: D68847677
---
 examples/qualcomm/oss_scripts/llama/llama.py | 46 +++++++++++++------
 .../oss_scripts/llama/model/static_llama.py  | 34 ++++++--------
 2 files changed, 44 insertions(+), 36 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index e80e0c2808a..be3867e249e 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -126,18 +126,32 @@ def _kv_calibrate(
     else:
         raise RuntimeError("Unkown tokenizer")
 
+    with torch.no_grad():
         while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 atten_mask,
-                torch.full((1, 1), pos),
+                freq_cos,
+                freq_sin,
                 *k_caches,
                 *v_caches,
             )
             atten_mask, pos, k_caches, v_caches = updator(
                 atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
             )
+            k_caches = [
+                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                for i, k_cache in enumerate(k_caches)
+            ]
+            v_caches = [
+                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                for i, v_cache in enumerate(v_caches)
+            ]
+
+            pos += 1
+            atten_mask[0][-pos - 1] = 0
+            print("pos", pos)
             if pos >= len(token_list):
                 token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
@@ -206,7 +220,7 @@ def calibrate(
             tokenizer,
             max_seq_len,
         )
-    elif len(example_inputs) == 5:
+    elif len(example_inputs) == 6:
         _kv_calibrate(
             example_inputs,
             user_prompts,
@@ -220,18 +234,17 @@ class SingleLlama:
-    def __init__(self, llama_model, pte_filename) -> None:
+    def __init__(self, llama_model, pte_filename, input_len) -> None:
         super().__init__()
         self.llama_model = llama_model
         self.quant_dtype = None
         self.llama_meta = self.llama_model.get_metadata()
         self.has_quant_io = False
         self.pte_filename = pte_filename
+        self.input_len = input_len
         if self.llama_meta["get_use_kv_cache"]:
-            tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
-                use_kv_cache=True
-            )
-            self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
+            tokens, atten_mask, freq_cos, freq_sin, k_caches, v_caches = self.get_example_inputs(self.input_len)
+            self.inputs = (tokens, atten_mask, freq_cos, freq_sin, *k_caches, *v_caches)
         else:
             tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
             self.inputs = (tokens, atten_mask)
 
@@ -346,7 +359,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         logging.info("Quantizing the model...")
 
         calibrate(
-            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            self.get_example_inputs(self.input_len),
             args.prompt,
             fx_graph_module,
             tokenizer=tokenizer,
@@ -417,8 +430,8 @@ def lowering_modules(
         with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)
 
-    def get_example_inputs(self, use_kv_cache=True):
-        return self.llama_model.get_example_inputs(use_kv_cache)
+    def get_example_inputs(self, input_len):
+        return self.llama_model.get_example_inputs(self.llama_meta, input_len)
 
     def get_quant_attrs(self):
         return self.quant_attrs
@@ -437,7 +450,7 @@ def compile(args, pte_filename, tokenizer):
 
     prefill_config = copy.copy(kv_config)
     prefill_config.max_seq_len = args.prefill_seq_len
-    prefill_config.use_kv_cache = False
+    prefill_config.use_kv_cache = True
 
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -451,14 +464,14 @@ def compile(args, pte_filename, tokenizer):
         )
     elif args.model_mode == "prefill":
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     elif args.model_mode == "hybrid":
         llama_instance_list.append(
             LlamaModel(kv_config, output_new_cache_only=True)
         )
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
@@ -506,11 +519,13 @@ def compile(args, pte_filename, tokenizer):
             llama_instance_list[i] = llama_instance_list[i].to(
                 dtype_override.to_torch_dtype()
             )
-
+
     for i in range(len(llama_instance_list)):
         llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
+        print(llama_instance_list[i].output_new_cache_only)
+        seq_len = 1 if i == 0 else args.prefill_seq_len
         llama_instance_list[i] = SingleLlama(
-            llama_instance_list[i].eval(), pte_filename
+            llama_instance_list[i].eval(), pte_filename, seq_len
         )
 
     if args.ptq:
@@ -523,6 +538,7 @@ def compile(args, pte_filename, tokenizer):
     if args.ptq != None:
         kv_quant_attrs = {}
         for i, llama_instance in enumerate(llama_instance_list):
+            print(f"Quantizing {i}th model")
            llama_instance.quantize(
                 quant_dtype=quant_dtype,
                 args=args,
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index d1b618ed071..b9875c95136 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.llama_transformer import (
     ModelArgs,
-    precompute_freqs_cis,
+    Rope,
 )
 
 
@@ -309,9 +309,11 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
         self.n_kv_heads = config.n_kv_heads
         self.n_layers = config.n_layers
         self.vocab_size = config.vocab_size
-        self.rope_freq_base = config.rope_freq_base
         self.use_kv_cache = config.use_kv_cache
         self.output_new_cache_only = output_new_cache_only
+        rope = Rope(config)
+        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
+        self.freqs_cos, self.freqs_sin = rope.get_freqs(pos_ids, self.max_seq_len)
 
         self.layers = nn.ModuleList(
             [
@@ -322,13 +324,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
         self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        freqs_cos, freqs_sin = precompute_freqs_cis(
-            config.dim // config.n_heads,
-            config.max_seq_len,
-            config.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+
 
     def prepare_output_conv(self):
         def forward_output_conv(x):
@@ -350,20 +346,14 @@ def forward(
         self,
         tokens: torch.Tensor,
         atten_mask: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         output_k_cache = []
         output_v_cache = []
-        # following tensors should be invariant across batches
-        freqs_cos = (
-            self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos[:-1]
-        )
-        freqs_sin = (
-            self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin[:-1]
-        )
-
+
         hidden_states = self.tok_embeddings(tokens)
         for ind, decoder_layer in enumerate(self.layers):
             k_caches = None
@@ -389,12 +379,13 @@ def forward(
 
         return logits, output_k_cache, output_v_cache
 
-    def get_example_inputs(self, use_kv_cache=True):
+    def get_example_inputs(self, llama_meta, input_len):
+        use_kv_cache = llama_meta["get_use_kv_cache"]
         if use_kv_cache:
             tokens = torch.randint(
                 self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
             )
-            pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
+
             k_cache, v_cache = [], []
             atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
             atten_mask[:, -1] = 0
@@ -418,7 +409,8 @@ def get_example_inputs(self, use_kv_cache=True):
             return (
                 tokens,
                 atten_mask,
-                pos_ids,
+                self.freqs_cos[:input_len],
+                self.freqs_sin[:input_len],
                 k_cache,
                 v_cache,
             )
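
Note (illustrative, not part of the patch): with this change the model consumes freqs_cos/freqs_sin slices instead of an input_pos tensor, so prefilling a chunk that starts at a position greater than zero comes down to handing the model the RoPE rows for exactly those positions. Below is a minimal standalone sketch of that slicing; the toy cos/sin table and the rope_rows helper are assumptions for illustration only and do not use the Rope class imported by the patch.

    import torch

    def rope_rows(freqs_cos, freqs_sin, start_pos, seq_len):
        # Select the cos/sin rows for positions [start_pos, start_pos + seq_len),
        # i.e. what forward() now receives in place of input_pos.
        return (
            freqs_cos[start_pos : start_pos + seq_len],
            freqs_sin[start_pos : start_pos + seq_len],
        )

    # Toy RoPE table: one row per position, head_dim // 2 frequencies per row.
    max_seq_len, half_head_dim = 32, 4
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(half_head_dim, dtype=torch.float32) / half_head_dim))
    angles = torch.outer(positions, inv_freq)
    freqs_cos, freqs_sin = torch.cos(angles), torch.sin(angles)

    # Prefill 8 tokens starting at position 5 rather than 0.
    cos, sin = rope_rows(freqs_cos, freqs_sin, start_pos=5, seq_len=8)
    print(cos.shape, sin.shape)  # torch.Size([8, 4]) torch.Size([8, 4])

The patch's get_example_inputs slices self.freqs_cos[:input_len], i.e. a start position of 0; a nonzero start position would slice at an offset as above.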
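
Note (illustrative, not part of the patch): the calibration loop in _kv_calibrate updates the static caches by dropping the oldest entry and appending the newly produced one; K concatenates on dim=-1 and V on dim=1, which matches a K layout of (batch, head_dim, cache_len) and a V layout of (batch, cache_len, head_dim). A standalone sketch of that shift-and-append step on dummy tensors; the shapes are assumptions chosen only to match the concatenation axes in the diff.

    import torch

    batch, head_dim, cache_len = 1, 4, 8
    k_cache = torch.zeros(batch, head_dim, cache_len)  # (batch, head_dim, cache_len)
    v_cache = torch.zeros(batch, cache_len, head_dim)  # (batch, cache_len, head_dim)

    # One decode step yields a single new K column and V row.
    new_k = torch.randn(batch, head_dim, 1)
    new_v = torch.randn(batch, 1, head_dim)

    # Drop the oldest position, append the newest; cache sizes stay fixed.
    k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
    v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

    print(k_cache.shape, v_cache.shape)  # torch.Size([1, 4, 8]) torch.Size([1, 8, 4])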