support input_pos > 0 for prefill model #8127
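In short: the Qualcomm static Llama prefill graph is switched to a KV-cache graph that emits only the newly produced cache slices (output_new_cache_only=True), and the rotary tables freqs_cos / freqs_sin become explicit graph inputs in place of the old input_pos index, so the position a chunk starts at is decided outside the graph. A minimal sketch of the resulting input layout, reusing names that appear in the diff; the shapes and the surrounding wiring are assumptions, not part of the change:

# Sketch only: mirrors SingleLlama.__init__ / LlamaModel.get_example_inputs below.
# `llama_model`, `llama_meta` and `input_len` are objects the script already builds.
tokens, atten_mask, freqs_cos, freqs_sin, k_caches, v_caches = (
    llama_model.get_example_inputs(llama_meta, input_len)
)
inputs = (tokens, atten_mask, freqs_cos, freqs_sin, *k_caches, *v_caches)

# The graph keeps the (logits, new K slices, new V slices) output contract.
logits, new_k_caches, new_v_caches = llama_model(*inputs)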

Open · wants to merge 1 commit into base: main
46 changes: 31 additions & 15 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -126,18 +126,32 @@ def _kv_calibrate(
else:
raise RuntimeError("Unknown tokenizer")


with torch.no_grad():
while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
logits, new_k_caches, new_v_caches = module(
torch.full((1, 1), token_list[pos], dtype=torch.int32),
atten_mask,
torch.full((1, 1), pos),
freq_cos,
freq_sin,
*k_caches,
*v_caches,
)
atten_mask, pos, k_caches, v_caches = updator(
atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
)
k_caches = [
torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
for i, k_cache in enumerate(k_caches)
]
v_caches = [
torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
for i, v_cache in enumerate(v_caches)
]

pos += 1
atten_mask[0][-pos - 1] = 0
print("pos", pos)
if pos >= len(token_list):
token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
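
The calibration loop above now updates the caches inline instead of going through the old updator helper: each decode step drops the oldest slot, appends the K/V slice the model just produced, and unmasks one more position of atten_mask. A self-contained illustration of the same roll-and-append pattern on dummy tensors; the sizes are made up and only the concat dims match the code above:

import torch

k_cache = torch.zeros(1, 4, 8)   # toy K cache, concatenated on the last dim
v_cache = torch.zeros(1, 8, 4)   # toy V cache, concatenated on dim 1
new_k = torch.ones(1, 4, 1)      # one freshly computed position
new_v = torch.ones(1, 1, 4)

# Drop the oldest entry, append the newest; the cache length stays fixed.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

# The mask starts fully masked (-255.0) except for the current token and is
# opened up one extra position per step, as in atten_mask[0][-pos - 1] = 0.
atten_mask = torch.full((1, 8), -255.0)
atten_mask[:, -1] = 0
pos = 1
atten_mask[0][-pos - 1] = 0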

@@ -206,7 +220,7 @@ def calibrate(
tokenizer,
max_seq_len,
)
elif len(example_inputs) == 5:
elif len(example_inputs) == 6:
_kv_calibrate(
example_inputs,
user_prompts,
@@ -220,18 +234,17 @@ def calibrate(


class SingleLlama:
def __init__(self, llama_model, pte_filename) -> None:
def __init__(self, llama_model, pte_filename, input_len) -> None:
super().__init__()
self.llama_model = llama_model
self.quant_dtype = None
self.llama_meta = self.llama_model.get_metadata()
self.has_quant_io = False
self.pte_filename = pte_filename
self.input_len = input_len
if self.llama_meta["get_use_kv_cache"]:
tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
use_kv_cache=True
)
self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
tokens, atten_mask, freq_cos, freq_sin, k_caches, v_caches = self.get_example_inputs(self.input_len)
self.inputs = (tokens, atten_mask, freq_cos, freq_sin, *k_caches, *v_caches)
else:
tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
self.inputs = (tokens, atten_mask)
@@ -346,7 +359,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):

logging.info("Quantizing the model...")
calibrate(
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
self.get_example_inputs(self.input_len),
args.prompt,
fx_graph_module,
tokenizer=tokenizer,
@@ -417,8 +430,8 @@ def lowering_modules(
with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
exec_prog_mgr.write_to_file(file)

def get_example_inputs(self, use_kv_cache=True):
return self.llama_model.get_example_inputs(use_kv_cache)
def get_example_inputs(self, input_len):
return self.llama_model.get_example_inputs(self.llama_meta, input_len)

def get_quant_attrs(self):
return self.quant_attrs
@@ -437,7 +450,7 @@ def compile(args, pte_filename, tokenizer):

prefill_config = copy.copy(kv_config)
prefill_config.max_seq_len = args.prefill_seq_len
prefill_config.use_kv_cache = False
prefill_config.use_kv_cache = True

state_dict = torch.load(
args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -451,14 +464,14 @@
)
elif args.model_mode == "prefill":
llama_instance_list.append(
LlamaModel(prefill_config, output_new_cache_only=False)
LlamaModel(prefill_config, output_new_cache_only=True)
)
elif args.model_mode == "hybrid":
llama_instance_list.append(
LlamaModel(kv_config, output_new_cache_only=True)
)
llama_instance_list.append(
LlamaModel(prefill_config, output_new_cache_only=False)
LlamaModel(prefill_config, output_new_cache_only=True)
)
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
@@ -506,11 +519,13 @@ def compile(args, pte_filename, tokenizer):
llama_instance_list[i] = llama_instance_list[i].to(
dtype_override.to_torch_dtype()
)

for i in range(len(llama_instance_list)):
llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
print(llama_instance_list[i].output_new_cache_only)
seq_len = 1 if i == 0 else args.prefill_seq_len
llama_instance_list[i] = SingleLlama(
llama_instance_list[i].eval(), pte_filename
llama_instance_list[i].eval(), pte_filename, seq_len
)

if args.ptq:
@@ -523,6 +538,7 @@
if args.ptq != None:
kv_quant_attrs = {}
for i, llama_instance in enumerate(llama_instance_list):
print(f"Quantizing {i}th model")
llama_instance.quantize(
quant_dtype=quant_dtype,
args=args,
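The compile path now threads a sequence length into each wrapped model: in hybrid mode, graph 0 is the token-by-token KV graph and gets length 1, while graph 1 is the prefill graph and gets args.prefill_seq_len; that length later decides how many rows of the rotary tables the example inputs carry. A sketch of the flow, reusing names from the diff (the literal 128 is only an example value):

# seq_len = 1 if i == 0 else args.prefill_seq_len   (from the hunk above)
decode_graph = SingleLlama(llama_instance_list[0].eval(), pte_filename, 1)
prefill_graph = SingleLlama(llama_instance_list[1].eval(), pte_filename, 128)

# SingleLlama.get_example_inputs(input_len) forwards input_len to
# LlamaModel.get_example_inputs, which slices the precomputed rotary tables
# to freqs_cos[:input_len] / freqs_sin[:input_len], so the prefill graph is
# calibrated with one (cos, sin) row per prompt position.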
34 changes: 13 additions & 21 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -14,7 +14,7 @@
import torch.nn.functional as F
from executorch.examples.models.llama.llama_transformer import (
ModelArgs,
precompute_freqs_cis,
Rope,
)


@@ -309,9 +309,11 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
self.n_kv_heads = config.n_kv_heads
self.n_layers = config.n_layers
self.vocab_size = config.vocab_size
self.rope_freq_base = config.rope_freq_base
self.use_kv_cache = config.use_kv_cache
self.output_new_cache_only = output_new_cache_only
rope = Rope(config)
pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
self.freqs_cos, self.freqs_sin = rope.get_freqs(pos_ids, self.max_seq_len)

self.layers = nn.ModuleList(
[
Expand All @@ -322,13 +324,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
freqs_cos, freqs_sin = precompute_freqs_cis(
config.dim // config.n_heads,
config.max_seq_len,
config.rope_freq_base,
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)


def prepare_output_conv(self):
def forward_output_conv(x):
@@ -350,20 +346,14 @@ def forward(
self,
tokens: torch.Tensor,
atten_mask: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
*args,
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:

output_k_cache = []
output_v_cache = []
# following tensors should be invariant across batches
freqs_cos = (
self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos[:-1]
)
freqs_sin = (
self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin[:-1]
)


hidden_states = self.tok_embeddings(tokens)
for ind, decoder_layer in enumerate(self.layers):
k_caches = None
@@ -389,12 +379,13 @@ def forward(

return logits, output_k_cache, output_v_cache
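
With input_pos gone from the signature, position information reaches the graph only through the freqs_cos / freqs_sin inputs (and the attention mask), so running a chunk that starts at a nonzero position comes down to slicing the precomputed tables at the right offset before the call. An illustrative invocation; the offset slicing is an assumption drawn from the PR title, not something shown in this diff:

# Hypothetical chunk of 16 tokens starting at absolute position 32.
start, chunk = 32, 16
logits, new_k, new_v = model(
    tokens,                                  # (1, chunk) int32 ids prepared by the caller
    atten_mask,                              # mask opened up for positions < start + chunk
    model.freqs_cos[start : start + chunk],  # rotary rows for the absolute positions
    model.freqs_sin[start : start + chunk],
    *k_caches,                               # caches carried over from earlier chunks
    *v_caches,
)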

def get_example_inputs(self, use_kv_cache=True):
def get_example_inputs(self, llama_meta, input_len):
use_kv_cache = llama_meta["get_use_kv_cache"]
if use_kv_cache:
tokens = torch.randint(
self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
)
pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)

k_cache, v_cache = [], []
atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
atten_mask[:, -1] = 0
@@ -418,7 +409,8 @@ def get_example_inputs(self, use_kv_cache=True):
return (
tokens,
atten_mask,
pos_ids,
self.freqs_cos[:input_len],
self.freqs_sin[:input_len],
k_cache,
v_cache,
)