From 3824a28ad99d8f830cfb259697a69c6ac7c0bbeb Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Mon, 29 Apr 2024 14:56:52 -0700 Subject: [PATCH 1/7] gpt-fast copy-pasta --- torchao/prototype/models/llama3/eval.py | 270 ++++++++ torchao/prototype/models/llama3/generate.py | 426 +++++++++++++ torchao/prototype/models/llama3/model.py | 255 ++++++++ torchao/prototype/models/llama3/quantize.py | 624 +++++++++++++++++++ torchao/prototype/models/llama3/tokenizer.py | 111 ++++ 5 files changed, 1686 insertions(+) create mode 100644 torchao/prototype/models/llama3/eval.py create mode 100644 torchao/prototype/models/llama3/generate.py create mode 100644 torchao/prototype/models/llama3/model.py create mode 100644 torchao/prototype/models/llama3/quantize.py create mode 100644 torchao/prototype/models/llama3/tokenizer.py diff --git a/torchao/prototype/models/llama3/eval.py b/torchao/prototype/models/llama3/eval.py new file mode 100644 index 0000000000..d38abf8625 --- /dev/null +++ b/torchao/prototype/models/llama3/eval.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import sys +import time +from pathlib import Path +from typing import Optional + +import torch +import torch._dynamo.config +import torch._inductor.config + +torch._dynamo.config.automatic_dynamic_shapes = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.epilogue_fusion = False +torch._inductor.config.triton.cudagraphs = True +torch._dynamo.config.cache_size_limit = 100000 + +from tokenizer import get_tokenizer + +from model import Transformer + +try: + import lm_eval + lm_eval_available = True +except: + lm_eval_available = False + +from generate import _load_model, encode_tokens, model_forward + +if lm_eval_available: + try: # lm_eval version 0.4 + from lm_eval.models.huggingface import HFLM as eval_wrapper + from lm_eval.tasks import get_task_dict + from lm_eval.evaluator import evaluate + except: #lm_eval version 0.3 + from lm_eval import base + from lm_eval import tasks + from lm_eval import evaluator + eval_wrapper=base.BaseLM + get_task_dict=tasks.get_task_dict + evaluate=evaluator.evaluate + + +def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + max_seq_length: Optional[int] = None, +): + """ + Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length + that are needed for prefill or model_forward + + Args: + model (LLaMA): The model whose cache gets set up + prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. + max_new_tokens (int): The desired maximum number of new tokens that can be generated. + max_seq_length (Optional[int], optional): The maximum sequence length allowed. 
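+            If None, this is set to min(len(prompt) + max_new_tokens, model.config.block_size).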
+ + Returns: + seq (torch.Tensor): prompt but padded with zeros to size max_seq_length + input_pos (torch.Tensor): tensor of integers in increasing order + max_seq_length (int): The maximum sequence length allowed, updated based on other numbers + """ + T = prompt.size(0) + T_new = T + max_new_tokens + if max_seq_length is None: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + return seq, input_pos, max_seq_length + +class GPTFastEvalWrapper(eval_wrapper): + """ + A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. + """ + def __init__( + self, + model: Transformer, + tokenizer, + max_seq_length: Optional[int]=None, + ): + super().__init__() + self._model = model + self._tokenizer = tokenizer + self._device = torch.device('cuda') + self._max_seq_length = 2048 if max_seq_length is None else max_seq_length + + @property + def eot_token_id(self): + return self._tokenizer.eos_id() + + @property + def max_length(self): + return self._max_seq_length + + @property + def max_gen_toks(self): + return 50 + + @property + def batch_size(self): + return 1 + + @property + def device(self): + return self._device + + def tok_encode(self, string: str, **kwargs): + encoded = encode_tokens(self._tokenizer, + string, bos=True, device=self._device) + # encoded is a pytorch tensor, but some internal logic in the + # eval harness expects it to be a list instead + # TODO: verify this for multi-batch as well + encoded = encoded.tolist() + return encoded + + def tok_decode(self, tokens): + decoded = self._tokenizer.decode(tokens) + return decoded + + def _model_call(self, inps): + # TODO: make batches work + inps = inps.squeeze(0) + + max_new_tokens = 1 + seq, input_pos, max_seq_length = \ + setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( + self._model, + inps, + max_new_tokens, + self.max_length, + ) + x = seq.index_select(0, input_pos).view(1, -1) + logits = model_forward(self._model, x, input_pos) + return logits + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception('unimplemented') + + +@torch.no_grad() +def eval( + model: Transformer, + tokenizer, + tasks: list = ["hellaswag"], + limit: Optional[int] = None, + max_seq_length: Optional[int] = None, +) -> dict: + """ + Evaluates a language model on a specified task using the lm-evaluation-harness library. + + Args: + model (Transformer): The pre-trained language model to evaluate. + tokenizer: The tokenizer to use for encoding/decoding text. + task (str): The name of the evaluation task to perform. + limit (Optional[int]): The maximum number of samples to evaluate (None for all available). + max_seq_length (Optional[int]): The maximum sequence length allowed for input text. + + Returns: + eval_results (dict): A dictionary of evaluation results for the specified task(s). 
+ """ + model_eval_wrapper = GPTFastEvalWrapper( + model, + tokenizer, + max_seq_length, + ) + + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + if 'hendrycks_test' in tasks: + tasks.remove('hendrycks_test') + tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] + task_dict = get_task_dict(tasks) + + eval_results = evaluate( + model_eval_wrapper, + task_dict, + limit=limit, + ) + return eval_results + + +def main( + checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), + compile: bool = False, + tasks: list = ["hellaswag"], + limit: Optional[int] = None, + max_seq_length: Optional[int] = None, +) -> None: + """Evaluates model on a task from the `lm-evaluation-harness` library. + + Args: + checkpoint_path (Path): The path to the model checkpoint file to load. + compile (bool): Whether or not to compile the model for optimization. + task (Optional[str]): The name of the evaluation task or a list of tasks to perform. + limit (Optional[int]): The maximum number of samples to evaluate (None for all available). + max_seq_length (Optional[int]): The maximum sequence length allowed for input text. + + """ + + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + device = 'cuda' + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, False) + + torch.cuda.synchronize() + print(f"Time to load model: {time.time() - t0:.02f} seconds.") + + model.eval() + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + torch.manual_seed(1234) + + if compile: + global model_forward + model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) + torch._inductor.config.coordinate_descent_tuning = True + + t1 = time.time() + result = eval( + model, + tokenizer, + tasks, + limit, + max_seq_length, + ) + print(f"Time to run eval: {time.time() - t1:.02f} seconds.") + print(f"For model {checkpoint_path}") + for task, res in result["results"].items(): + print(f"{task}: {res}") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2') + parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate') + parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate') + + args = parser.parse_args() + main( + Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length, + ) diff --git a/torchao/prototype/models/llama3/generate.py b/torchao/prototype/models/llama3/generate.py new file mode 100644 index 0000000000..24ba553d9c --- /dev/null +++ b/torchao/prototype/models/llama3/generate.py @@ -0,0 +1,426 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
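+# Standalone generation utilities for the model defined in model.py: prefill plus
+# incremental decoding, top-k/temperature sampling, optional speculative decoding
+# with a draft model, and a CLI entry point in main().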
+import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import Transformer +from tokenizer import get_tokenizer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +def speculative_decode( + model: Transformer, + draft_model: Transformer, + cur_token: torch.Tensor, + input_pos: int, + speculate_k: int, + **sampling_kwargs +) -> torch.Tensor: + # draft model inference sequentially + device = cur_token.device + orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) + draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) + + draft_tokens = torch.cat(draft_tokens) + # parallel inference on target model using draft tokens + target_logits 
= model_forward( + model, + torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), + torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) + ) + target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) + draft_probs = torch.stack(draft_probs) + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) + rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() + + if rejected_locations.shape[0] == 0: # All draft tokens have been accepted + accept_length = speculate_k + 1 + last_token = multinomial_sample_one_no_sync(target_probs[-1]) + # fill last token into draft model + model_forward( + draft_model, + draft_tokens[-1].view(1, -1), + orig_input_pos + speculate_k, + ) + return torch.cat([draft_tokens, last_token]) + else: + accept_length = rejected_locations[0].item() + p = draft_probs[accept_length] + q = target_probs[accept_length] + new = q - p + new = torch.where(new > 0, new, 0.0) + new = new / new.sum() + next_token = multinomial_sample_one_no_sync(new) + return torch.cat([draft_tokens[:accept_length], next_token]) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + draft_model: Transformer, + speculate_k: Optional[int] = 8, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + is_speculative = draft_model is not None + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + if interactive: + max_seq_length = 350 + else: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) + if is_speculative: + prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + accept_counts = [0] * (speculate_k + 1) + + if is_speculative: + input_pos = input_pos.item() # for speculative decoding easier to keep on host + while input_pos < T_new - 1: + cur_token = next_token.view(()) + + next_tokens = speculative_decode( + model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs + ) + + accept_counts[len(next_tokens) - 1] += 1 + num_added = min(T_new - input_pos - 1, len(next_tokens)) + seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] + for i in next_tokens[: num_added,]: + callback(i) + input_pos = input_pos + num_added + next_token = next_tokens[-1] + else: + 
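+        # Standard autoregressive path: no draft model, so decode the remaining
+        # max_new_tokens - 1 tokens one at a time with decode_n_tokens.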
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + generate_stats = { + 'accept_counts': accept_counts + } + return seq, generate_stats + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision, use_tp): + use_cuda = 'cuda' in device + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from quantize import WeightOnlyInt8QuantHandler + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 weight-only quantization!") + path_comps = checkpoint_path.name.split(".") + assert path_comps[-3].startswith("g") + assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!" + groupsize = int(path_comps[-3][1:]) + from quantize import WeightOnlyInt4QuantHandler + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime(use_cuda) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from tp import apply_tp + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + draft_checkpoint_path: Optional[Path] = None, + speculate_k: int = 5, + device=default_device, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
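+
+    Args:
+        prompt: Input prompt used to condition generation.
+        interactive: Launch an interactive prompt loop instead of sampling num_samples completions.
+        num_samples: Number of samples to generate.
+        max_new_tokens: Maximum number of new tokens per sample.
+        top_k, temperature: Sampling parameters.
+        checkpoint_path: Model checkpoint to load.
+        compile: Compile the decode step with torch.compile.
+        compile_prefill: Also compile the prefill step (better prefill perf, longer compile time).
+        profile: If set, export a chrome trace of the last sample to this path.
+        draft_checkpoint_path: Optional draft-model checkpoint; enables speculative decoding.
+        speculate_k: Speculative execution depth.
+        device: Device to run on.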
+ """ + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + global print + from tp import maybe_init_dist + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + if rank != 0: + # only print on rank 0 + print = lambda *args, **kwargs: None + + print(f"Using device={device}") + precision = torch.bfloat16 + is_speculative = draft_checkpoint_path is not None + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, use_tp) + + if is_speculative: + draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) + else: + draft_model = None + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + if compile: + if is_speculative and use_tp: # and ("cuda" in device): + torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case + + if is_speculative: + global model_forward, logits_to_prob + model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) + + global decode_one_token, prefill + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + + # Uncomment to squeeze more perf out of prefill + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + 'accept_counts': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, metrics = generate( + model, + encoded, + max_new_tokens, + draft_model=draft_model, + speculate_k=speculate_k, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics['accept_counts'].append(metrics['accept_counts']) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print("==========") + if is_speculative: + counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] + acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] + print(f"Acceptance probs: {acceptance_probs}") + print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") + + print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--profile', 
type=Path, default=None, help='Profile path.') + parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') + parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') + parser.add_argument('--device', type=str, default=default_device, help='Device to use') + + args = parser.parse_args() + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, + args.speculate_k, args.device + ) diff --git a/torchao/prototype/models/llama3/model.py b/torchao/prototype/models/llama3/model.py new file mode 100644 index 0000000000..0660bc2b72 --- /dev/null +++ b/torchao/prototype/models/llama3/model.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). 
Find the best config match, + # take longer name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf + "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), + "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "stories15M": dict(n_layer=6, n_head=6, dim=288), + "stories110M": dict(n_layer=12, n_head=12, dim=768), + "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.output.weight.dtype + # For quantized layers, dtype is encoded in scales + if hasattr(self.output, "scales"): + dtype = self.output.scales.dtype + elif hasattr(self.output, "scales_and_zeros"): + dtype = self.output.scales_and_zeros.dtype + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = 
self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + 
dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/torchao/prototype/models/llama3/quantize.py b/torchao/prototype/models/llama3/quantize.py new file mode 100644 index 0000000000..4ebbe5f57d --- /dev/null +++ b/torchao/prototype/models/llama3/quantize.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import time +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tokenizer import get_tokenizer + +try: + from GPTQ import GenericGPTQRunner, InputRecorder + from eval import get_task_dict, evaluate, lm_eval +except: + pass + +from model import Transformer + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +##### Quantization Primitives ###### + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val 
+ scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + +class GPTQQuantHandler(QuantHandler): + """ + This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. 
+ Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. 
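+
+    For a concrete example, see WeightOnlyInt4GPTQQuantHandler below, which wires these
+    up for 4-bit groupwise quantization roughly as follows (sketch):
+
+        self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
+        self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
+        self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
+        self.combine_qparams_list_func = lambda qparams_list: [torch.cat(x, dim=1) for x in zip(*qparams_list)]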
+ """ + def __init__(self): + assert self.mod is not None + assert self.get_qparams_func is not None + assert self.quantize_func is not None + assert self.dequantize_func is not None + assert self.combine_qparams_list_func is not None + assert self.make_names_and_values_dict_func is not None + + @staticmethod + def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + input_recorder = InputRecorder( + model, + tokenizer, + calibration_seq_length, + pad_calibration_inputs, + ) + + try: + lm_eval.tasks.initialize_tasks() + except: + pass + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + input_recorder, + task_dict, + limit=calibration_limit, + ) + inputs = input_recorder.get_recorded_inputs() + assert inputs is not None, ( + f"No inputs were collected, use a task other than {calibration_tasks}, "+ + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ + f"{calibration_seq_length})" + ) + print(f"Obtained {len(inputs[0].values)} calibration samples") + return inputs + + @torch.no_grad() + def create_quantized_state_dict( + self, + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "StateDict": + inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + self.mod, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, + self.quantize_func, + self.dequantize_func, + self.combine_qparams_list_func, + self.make_names_and_values_dict_func, + self.skip_layer_func + ) + + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def convert_for_runtime(self) -> "nn.Module": + pass + +##### Weight-only int8 per-channel quantized code ###### + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) + else: + replace_linear_weight_only_int8_per_channel(child) + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) + cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu') + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu') + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + 
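+        # One bf16 scale per output channel; forward() multiplies the F.linear output by these scales.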
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + +##### weight only int4 per channel groupwise quantized code ###### + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + return weight_int4pack, scales_and_zeros + +def _calc_padded_size(k, groupsize=1, innner_k_tiles=1): + from model import find_multiple + return find_multiple(k, 1024) + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding_allowed = padding_allowed + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if self.padding_allowed: + from model import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') + + return cur_state_dict + + def convert_for_runtime(self, use_cuda): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda) + return self.mod + +class 
WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + from model import find_multiple + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) + self.quantize_func = lambda w, qparams: \ + group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) + self.dequantize_func = lambda q, qparams: \ + group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() + self.combine_qparams_list_func = lambda qparams_list: \ + [torch.cat(x, dim=1) for x in zip(*qparams_list)] + # skip unless padding=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + ) + # we need to do the padding here, both for q and the qparams if necessary + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + if not _check_linear_int4_k(k, groupsize, inner_k_tiles): + new_k = find_multiple(k, 1024) + else: + new_k = k + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales_and_zeros = pack_scales_and_zeros(*qparams) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + + def convert_for_runtime(self, use_cuda): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) + return self.mod + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, in_features: int, out_features: int, + bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, + ) -> None: + super().__init__() + self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + if self.padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def quantize( + checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + mode: str = 
'int8', + # following arguments only available when setting int4 quantization. + groupsize: int = 128, + # following arguments only used for GPTQ + calibration_tasks: list = ["hellaswag"], + calibration_limit: int = 1000, + calibration_seq_length: int = 100, + pad_calibration_inputs: bool = False, + percdamp: float = .01, + blocksize: int = 128, + label: str = '', + device: str = default_device, +) -> None: + assert checkpoint_path.is_file(), checkpoint_path + device = 'cpu' + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + + if mode == 'int8': + print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f'{label}int8.pth') + + elif mode == 'int4': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") + print(f"Prepacking model weights in {device} optimal layout") + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth") + + elif mode == 'int4-gptq': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") + quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + quantized_state_dict = quant_handler.create_quantized_state_dict( + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs + ) + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth") + else: + raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]") + + quantize_path = dir_name / new_base_name + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + return + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Quantize a model.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') + parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') + parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') + 
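+    # The calibration arguments below (--calibration_tasks through --blocksize) are only used with --mode int4-gptq.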
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') + parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') + parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') + parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') + parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') + parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') + parser.add_argument('--label', type=str, default='_', help='label to add to output filename') + parser.add_argument('--device', type=str, default=default_device, help='device to use') + + args = parser.parse_args() + quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device) diff --git a/torchao/prototype/models/llama3/tokenizer.py b/torchao/prototype/models/llama3/tokenizer.py new file mode 100644 index 0000000000..c62a0c5b3a --- /dev/null +++ b/torchao/prototype/models/llama3/tokenizer.py @@ -0,0 +1,111 @@ +import os +import sentencepiece as spm +import tiktoken +from tiktoken.load import load_tiktoken_bpe +from pathlib import Path +from typing import Dict + +class TokenizerInterface: + def __init__(self, model_path): + self.model_path = model_path + + def encode(self, text): + raise NotImplementedError("This method should be overridden by subclasses.") + + def decode(self, tokens): + raise NotImplementedError("This method should be overridden by subclasses.") + + def bos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + + def eos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + +class SentencePieceWrapper(TokenizerInterface): + def __init__(self, model_path): + super().__init__(model_path) + self.processor = spm.SentencePieceProcessor(str(model_path)) + + def encode(self, text): + return self.processor.EncodeAsIds(text) + + def decode(self, tokens): + return self.processor.DecodeIds(tokens) + + def bos_id(self): + return self.processor.bos_id() + + def eos_id(self): + return self.processor.eos_id() + +class TiktokenWrapper(TokenizerInterface): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
+ """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path): + super().__init__(model_path) + assert os.path.isfile(model_path), str(model_path) + mergeable_ranks = load_tiktoken_bpe(str(model_path)) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + # BOS / EOS token IDs + self._bos_id: int = self.special_tokens["<|begin_of_text|>"] + self._eos_id: int = self.special_tokens["<|end_of_text|>"] + + def encode(self, text): + return self.model.encode(text) + + def decode(self, tokens): + return self.model.decode(tokens) + + def bos_id(self): + return self._bos_id + + def eos_id(self): + return self._eos_id + +def get_tokenizer(tokenizer_model_path, model_name): + """ + Factory function to get the appropriate tokenizer based on the model name. + + Args: + - tokenizer_model_path (str): The file path to the tokenizer model. + - model_name (str): The name of the model, used to determine the tokenizer type. + + Returns: + - TokenizerInterface: An instance of a tokenizer. + """ + if "Llama-3" in str(model_name): + return TiktokenWrapper(tokenizer_model_path) + else: + return SentencePieceWrapper(tokenizer_model_path) From 480d83e14eb1c73cc52c708440be4a0004b45426 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Mon, 29 Apr 2024 15:57:37 -0700 Subject: [PATCH 2/7] gpt_fused --- torchao/prototype/models/gpt_fused/README.md | 13 + .../models/{llama3 => gpt_fused}/model.py | 1 + torchao/prototype/models/llama3/eval.py | 270 -------- torchao/prototype/models/llama3/generate.py | 426 ------------ torchao/prototype/models/llama3/quantize.py | 624 ------------------ torchao/prototype/models/llama3/tokenizer.py | 111 ---- 6 files changed, 14 insertions(+), 1431 deletions(-) create mode 100644 torchao/prototype/models/gpt_fused/README.md rename torchao/prototype/models/{llama3 => gpt_fused}/model.py (99%) delete mode 100644 torchao/prototype/models/llama3/eval.py delete mode 100644 torchao/prototype/models/llama3/generate.py delete mode 100644 torchao/prototype/models/llama3/quantize.py delete mode 100644 torchao/prototype/models/llama3/tokenizer.py diff --git a/torchao/prototype/models/gpt_fused/README.md b/torchao/prototype/models/gpt_fused/README.md new file mode 100644 index 0000000000..40eab89f5b --- /dev/null +++ b/torchao/prototype/models/gpt_fused/README.md @@ -0,0 +1,13 @@ +## gpt-fused + +A more handwritten version of [gpt-fast](https://github.com/pytorch-labs/gpt-fast)'s model.py for us to experiment with. + +Requires the use of gpt-fast to use, but is a drop-in replacement. 
+
+To use it, set the PYTHONPATH environment variable to `PYTHONPATH=/ao/torchao/prototype/models/gpt_fused` and delete gpt-fast's model.py.
+
+For example:
+
+```
+PYTHONPATH=/home/cpuhrsch/local/ao/torchao/prototype/models/gpt_fused CUDA_VISIBLE_DEVICES=0 numactl --membind 0 --cpubind 0 python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
+```
diff --git a/torchao/prototype/models/llama3/model.py b/torchao/prototype/models/gpt_fused/model.py
similarity index 99%
rename from torchao/prototype/models/llama3/model.py
rename to torchao/prototype/models/gpt_fused/model.py
index 0660bc2b72..4c862c0887 100644
--- a/torchao/prototype/models/llama3/model.py
+++ b/torchao/prototype/models/gpt_fused/model.py
@@ -88,6 +88,7 @@ def update(self, input_pos, k_val, v_val):
 
 class Transformer(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
+        print("AJSKDLAJSD")
         super().__init__()
         self.config = config
 
diff --git a/torchao/prototype/models/llama3/eval.py b/torchao/prototype/models/llama3/eval.py
deleted file mode 100644
index d38abf8625..0000000000
--- a/torchao/prototype/models/llama3/eval.py
+++ /dev/null
@@ -1,270 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import sys
-import time
-from pathlib import Path
-from typing import Optional
-
-import torch
-import torch._dynamo.config
-import torch._inductor.config
-
-torch._dynamo.config.automatic_dynamic_shapes = True
-torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.epilogue_fusion = False
-torch._inductor.config.triton.cudagraphs = True
-torch._dynamo.config.cache_size_limit = 100000
-
-from tokenizer import get_tokenizer
-
-from model import Transformer
-
-try:
-    import lm_eval
-    lm_eval_available = True
-except:
-    lm_eval_available = False
-
-from generate import _load_model, encode_tokens, model_forward
-
-if lm_eval_available:
-    try: # lm_eval version 0.4
-        from lm_eval.models.huggingface import HFLM as eval_wrapper
-        from lm_eval.tasks import get_task_dict
-        from lm_eval.evaluator import evaluate
-    except: #lm_eval version 0.3
-        from lm_eval import base
-        from lm_eval import tasks
-        from lm_eval import evaluator
-        eval_wrapper=base.BaseLM
-        get_task_dict=tasks.get_task_dict
-        evaluate=evaluator.evaluate
-
-
-def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
-    model: Transformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    max_seq_length: Optional[int] = None,
-):
-    """
-    Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length
-    that are needed for prefill or model_forward
-
-    Args:
-        model (LLaMA): The model whose cache gets set up
-        prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence.
-        max_new_tokens (int): The desired maximum number of new tokens that can be generated.
-        max_seq_length (Optional[int], optional): The maximum sequence length allowed.
- - Returns: - seq (torch.Tensor): prompt but padded with zeros to size max_seq_length - input_pos (torch.Tensor): tensor of integers in increasing order - max_seq_length (int): The maximum sequence length allowed, updated based on other numbers - """ - T = prompt.size(0) - T_new = T + max_new_tokens - if max_seq_length is None: - max_seq_length = min(T_new, model.config.block_size) - - device, dtype = prompt.device, prompt.dtype - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - return seq, input_pos, max_seq_length - -class GPTFastEvalWrapper(eval_wrapper): - """ - A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. - """ - def __init__( - self, - model: Transformer, - tokenizer, - max_seq_length: Optional[int]=None, - ): - super().__init__() - self._model = model - self._tokenizer = tokenizer - self._device = torch.device('cuda') - self._max_seq_length = 2048 if max_seq_length is None else max_seq_length - - @property - def eot_token_id(self): - return self._tokenizer.eos_id() - - @property - def max_length(self): - return self._max_seq_length - - @property - def max_gen_toks(self): - return 50 - - @property - def batch_size(self): - return 1 - - @property - def device(self): - return self._device - - def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens(self._tokenizer, - string, bos=True, device=self._device) - # encoded is a pytorch tensor, but some internal logic in the - # eval harness expects it to be a list instead - # TODO: verify this for multi-batch as well - encoded = encoded.tolist() - return encoded - - def tok_decode(self, tokens): - decoded = self._tokenizer.decode(tokens) - return decoded - - def _model_call(self, inps): - # TODO: make batches work - inps = inps.squeeze(0) - - max_new_tokens = 1 - seq, input_pos, max_seq_length = \ - setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, - inps, - max_new_tokens, - self.max_length, - ) - x = seq.index_select(0, input_pos).view(1, -1) - logits = model_forward(self._model, x, input_pos) - return logits - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception('unimplemented') - - -@torch.no_grad() -def eval( - model: Transformer, - tokenizer, - tasks: list = ["hellaswag"], - limit: Optional[int] = None, - max_seq_length: Optional[int] = None, -) -> dict: - """ - Evaluates a language model on a specified task using the lm-evaluation-harness library. - - Args: - model (Transformer): The pre-trained language model to evaluate. - tokenizer: The tokenizer to use for encoding/decoding text. - task (str): The name of the evaluation task to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - max_seq_length (Optional[int]): The maximum sequence length allowed for input text. - - Returns: - eval_results (dict): A dictionary of evaluation results for the specified task(s). 
- """ - model_eval_wrapper = GPTFastEvalWrapper( - model, - tokenizer, - max_seq_length, - ) - - try: - lm_eval.tasks.initialize_tasks() - except: - pass - - if 'hendrycks_test' in tasks: - tasks.remove('hendrycks_test') - tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] - task_dict = get_task_dict(tasks) - - eval_results = evaluate( - model_eval_wrapper, - task_dict, - limit=limit, - ) - return eval_results - - -def main( - checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), - compile: bool = False, - tasks: list = ["hellaswag"], - limit: Optional[int] = None, - max_seq_length: Optional[int] = None, -) -> None: - """Evaluates model on a task from the `lm-evaluation-harness` library. - - Args: - checkpoint_path (Path): The path to the model checkpoint file to load. - compile (bool): Whether or not to compile the model for optimization. - task (Optional[str]): The name of the evaluation task or a list of tasks to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - max_seq_length (Optional[int]): The maximum sequence length allowed for input text. - - """ - - assert checkpoint_path.is_file(), checkpoint_path - - tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - - device = 'cuda' - precision = torch.bfloat16 - - print("Loading model ...") - t0 = time.time() - model = _load_model(checkpoint_path, device, precision, False) - - torch.cuda.synchronize() - print(f"Time to load model: {time.time() - t0:.02f} seconds.") - - model.eval() - - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - torch.manual_seed(1234) - - if compile: - global model_forward - model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) - torch._inductor.config.coordinate_descent_tuning = True - - t1 = time.time() - result = eval( - model, - tokenizer, - tasks, - limit, - max_seq_length, - ) - print(f"Time to run eval: {time.time() - t1:.02f} seconds.") - print(f"For model {checkpoint_path}") - for task, res in result["results"].items(): - print(f"{task}: {res}") - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2') - parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate') - parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate') - - args = parser.parse_args() - main( - Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length, - ) diff --git a/torchao/prototype/models/llama3/generate.py b/torchao/prototype/models/llama3/generate.py deleted file mode 100644 index 24ba553d9c..0000000000 --- a/torchao/prototype/models/llama3/generate.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-import itertools -import sys -import time -from pathlib import Path -from typing import Optional, Tuple - -import torch -import torch._dynamo.config -import torch._inductor.config - -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet suppported") - - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future - -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from model import Transformer -from tokenizer import get_tokenizer - -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) - - return new_tokens, new_probs - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - -def speculative_decode( - model: Transformer, - draft_model: Transformer, - cur_token: torch.Tensor, - input_pos: int, - speculate_k: int, - **sampling_kwargs -) -> torch.Tensor: - # draft model inference sequentially - device = cur_token.device - orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) - draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) - - draft_tokens = torch.cat(draft_tokens) - # parallel inference on target model using draft tokens - target_logits 
= model_forward( - model, - torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), - torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) - ) - target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) - draft_probs = torch.stack(draft_probs) - # q: target prob, p: draft prob - # q >= p: always accept draft token - # q < p: q/p prob to accept draft token - p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) - rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() - - if rejected_locations.shape[0] == 0: # All draft tokens have been accepted - accept_length = speculate_k + 1 - last_token = multinomial_sample_one_no_sync(target_probs[-1]) - # fill last token into draft model - model_forward( - draft_model, - draft_tokens[-1].view(1, -1), - orig_input_pos + speculate_k, - ) - return torch.cat([draft_tokens, last_token]) - else: - accept_length = rejected_locations[0].item() - p = draft_probs[accept_length] - q = target_probs[accept_length] - new = q - p - new = torch.where(new > 0, new, 0.0) - new = new / new.sum() - next_token = multinomial_sample_one_no_sync(new) - return torch.cat([draft_tokens[:accept_length], next_token]) - -@torch.no_grad() -def generate( - model: Transformer, - prompt: torch.Tensor, - max_new_tokens: int, - *, - interactive: bool, - draft_model: Transformer, - speculate_k: Optional[int] = 8, - callback = lambda x: x, - **sampling_kwargs -) -> torch.Tensor: - """ - Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. - """ - - is_speculative = draft_model is not None - # create an empty tensor of the expected final shape and fill in the current tokens - T = prompt.size(0) - T_new = T + max_new_tokens - if interactive: - max_seq_length = 350 - else: - max_seq_length = min(T_new, model.config.block_size) - - device, dtype = prompt.device, prompt.dtype - max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - if is_speculative and draft_model is not model: - draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - if is_speculative: - prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token - - input_pos = torch.tensor([T], device=device, dtype=torch.int) - accept_counts = [0] * (speculate_k + 1) - - if is_speculative: - input_pos = input_pos.item() # for speculative decoding easier to keep on host - while input_pos < T_new - 1: - cur_token = next_token.view(()) - - next_tokens = speculative_decode( - model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs - ) - - accept_counts[len(next_tokens) - 1] += 1 - num_added = min(T_new - input_pos - 1, len(next_tokens)) - seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] - for i in next_tokens[: num_added,]: - callback(i) - input_pos = input_pos + num_added - next_token = next_tokens[-1] - else: - 
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) - - generate_stats = { - 'accept_counts': accept_counts - } - return seq, generate_stats - -def encode_tokens(tokenizer, string, bos=True, device=default_device): - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - -def _load_model(checkpoint_path, device, precision, use_tp): - use_cuda = 'cuda' in device - with torch.device('meta'): - model = Transformer.from_name(checkpoint_path.parent.name) - - if "int8" in str(checkpoint_path): - print("Using int8 weight-only quantization!") - from quantize import WeightOnlyInt8QuantHandler - simple_quantizer = WeightOnlyInt8QuantHandler(model) - model = simple_quantizer.convert_for_runtime() - - if "int4" in str(checkpoint_path): - print("Using int4 weight-only quantization!") - path_comps = checkpoint_path.name.split(".") - assert path_comps[-3].startswith("g") - assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!" - groupsize = int(path_comps[-3][1:]) - from quantize import WeightOnlyInt4QuantHandler - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime(use_cuda) - - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - model.load_state_dict(checkpoint, assign=True) - - if use_tp: - from tp import apply_tp - print("Applying tensor parallel to model ...") - apply_tp(model) - - model = model.to(device=device, dtype=precision) - return model.eval() - -B_INST, E_INST = "[INST]", "[/INST]" - -def main( - prompt: str = "Hello, my name is", - interactive: bool = False, - num_samples: int = 5, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), - compile: bool = True, - compile_prefill: bool = False, - profile: Optional[Path] = None, - draft_checkpoint_path: Optional[Path] = None, - speculate_k: int = 5, - device=default_device, -) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. 
- """ - assert checkpoint_path.is_file(), checkpoint_path - - tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - - global print - from tp import maybe_init_dist - rank = maybe_init_dist() - use_tp = rank is not None - if use_tp: - if rank != 0: - # only print on rank 0 - print = lambda *args, **kwargs: None - - print(f"Using device={device}") - precision = torch.bfloat16 - is_speculative = draft_checkpoint_path is not None - is_chat = "chat" in str(checkpoint_path) - - print("Loading model ...") - t0 = time.time() - model = _load_model(checkpoint_path, device, precision, use_tp) - - if is_speculative: - draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) - else: - draft_model = None - - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - prompt_length = encoded.size(0) - - torch.manual_seed(1234) - model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) - if compile: - if is_speculative and use_tp: # and ("cuda" in device): - torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case - - if is_speculative: - global model_forward, logits_to_prob - model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) - - global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) - - # Uncomment to squeeze more perf out of prefill - if compile_prefill: - prefill = torch.compile(prefill, fullgraph=True, dynamic=True) - - - aggregate_metrics = { - 'tokens_per_sec': [], - 'accept_counts': [], - } - start = -1 if compile else 0 - - for i in range(start, num_samples): - device_sync(device=device) # MKG - if i >= 0 and interactive: - prompt = input("What is your prompt? 
") - if is_chat: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - - if interactive and i >= 0: - buffer = [] - period_id = tokenizer.encode('.')[0] - done_generating = False - def callback(x): - nonlocal done_generating - if done_generating: - return - buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) - if x.item() == tokenizer.eos_id(): - done_generating = True - if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) - buffer.clear() - # print(, end='', flush=True) - else: - callback = lambda x : x - t0 = time.perf_counter() - import contextlib - if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): - prof = contextlib.nullcontext() - else: - torch.profiler._utils._init_for_cuda_graphs() - prof = torch.profiler.profile() - with prof: - y, metrics = generate( - model, - encoded, - max_new_tokens, - draft_model=draft_model, - speculate_k=speculate_k, - interactive=interactive, - callback=callback, - temperature=temperature, - top_k=top_k, - ) - aggregate_metrics['accept_counts'].append(metrics['accept_counts']) - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - if hasattr(prof, "export_chrome_trace"): - if use_tp: - prof.export_chrome_trace(f"{profile}_rank_{rank}.json") - else: - prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG - t = time.perf_counter() - t0 - - if not interactive: - print(tokenizer.decode(y.tolist())) - else: - print() - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") - print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") - print("==========") - if is_speculative: - counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] - acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] - print(f"Acceptance probs: {acceptance_probs}") - print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") - - print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") - print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', 
type=Path, default=None, help='Profile path.') - parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') - parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') - - args = parser.parse_args() - main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, - args.speculate_k, args.device - ) diff --git a/torchao/prototype/models/llama3/quantize.py b/torchao/prototype/models/llama3/quantize.py deleted file mode 100644 index 4ebbe5f57d..0000000000 --- a/torchao/prototype/models/llama3/quantize.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -import time -from pathlib import Path - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tokenizer import get_tokenizer - -try: - from GPTQ import GenericGPTQRunner, InputRecorder - from eval import get_task_dict, evaluate, lm_eval -except: - pass - -from model import Transformer - -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' - -##### Quantization Primitives ###### - -def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): - # assumes symmetric quantization - # assumes axis == 0 - # assumes dense memory format - # TODO(future): relax ^ as needed - - # default setup for affine quantization of activations - eps = torch.finfo(torch.float32).eps - - # get min and max - min_val, max_val = torch.aminmax(x, dim=1) - - # calculate scales and zero_points based on min and max - # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - device = min_val_neg.device - - # reference: https://fburl.com/code/4wll53rk - max_val_pos = torch.max(-min_val_neg, max_val_pos) - scales = max_val_pos / (float(quant_max - quant_min) / 2) - # ensure scales is the same dtype as the original tensor - scales = torch.clamp(scales, min=eps).to(x.dtype) - zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) - - # quantize based on qmin/qmax/scales/zp - # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 - x_div = x / scales.unsqueeze(-1) - x_round = torch.round(x_div) - x_zp = x_round + zero_points.unsqueeze(-1) - quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) - - return quant, scales, zero_points - -def get_group_qparams(w, n_bit=4, groupsize=128): - # needed for GPTQ with padding - if groupsize > w.shape[-1]: - groupsize = w.shape[-1] - assert groupsize > 1 - assert w.shape[-1] % groupsize == 0 - assert w.dim() == 2 - - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2**n_bit - 1 - scales = (max_val - min_val).clamp(min=1e-6) / max_int - zeros = min_val + scales * (2 ** (n_bit - 1)) - return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( - torch.bfloat16 - ).reshape(w.shape[0], -1) - - -def pack_scales_and_zeros(scales, zeros): - 
assert scales.shape == zeros.shape - assert scales.dtype == torch.bfloat16 - assert zeros.dtype == torch.bfloat16 - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - -def unpack_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) - - -def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): - assert groupsize > 1 - # needed for GPTQ single column quantize - if groupsize > w.shape[-1] and scales.shape[-1] == 1: - groupsize = w.shape[-1] - - assert w.shape[-1] % groupsize == 0 - assert w.dim() == 2 - - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - min_val = zeros - scales * (2 ** (n_bit - 1)) - max_int = 2**n_bit - 1 - min_int = 0 - w_int32 = ( - to_quant.sub(min_val) - .div(scales) - .round() - .clamp_(min_int, max_int) - .to(torch.int32) - .reshape_as(w) - ) - - return w_int32 - - -def group_quantize_tensor(w, n_bit=4, groupsize=128): - scales, zeros = get_group_qparams(w, n_bit, groupsize) - w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) - scales_and_zeros = pack_scales_and_zeros(scales, zeros) - return w_int32, scales_and_zeros - - -def group_dequantize_tensor_from_qparams( - w_int32, scales, zeros, n_bit=4, groupsize=128 -): - assert groupsize > 1 - # needed for GPTQ single column dequantize - if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: - groupsize = w_int32.shape[-1] - assert w_int32.shape[-1] % groupsize == 0 - assert w_int32.dim() == 2 - - w_int32_grouped = w_int32.reshape(-1, groupsize) - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - - w_dq = ( - w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) - ) - return w_dq - - -def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): - scales, zeros = unpack_scales_and_zeros(scales_and_zeros) - return group_dequantize_tensor_from_qparams( - w_int32, scales, zeros, n_bit, groupsize - ) - -class QuantHandler: - def __init__(self, mod): - self.mod = mod - - def create_quantized_state_dict(self) -> "StateDict": - pass - - def convert_for_runtime(self) -> "nn.Module": - pass - -class GPTQQuantHandler(QuantHandler): - """ - This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. - Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement - __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. - - The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and - create_quantized_state_dict. Here is a description of each function. - - get_qparams_func: - A function that calculates the quantization qparams for an input tensor. - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - qparams: it can have any format but will need to be handled by the other defined functions below. - - quantize_func: - A function that applies quantization to an input tensor. 
It should be noted - that this function needs to be able to handle quantizing the entire weight tensor, a single group, - or a single column. - Args: - weight: A 2d weight tensor with non-integer dtype. - qparams: the output from get_qparams_func - Returns: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - - - dequantize_func: - A function that dequantizes an input quantized weight tensor. It should be noted - that this function needs to be able to handle dequantizing the entire weight tensor, a single group, - or a single column. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - weight: A 2d weight tensor with non-integer dtype. - - combine_qparams_list_func: - A function that combines several qparams into one qparam. - Args: - qparams_list: a list of qparams objects, each obtained by calling get_qparams_func - on a single group from a weight tensor - Returns: - qparams: an object of the same format as the qparams above. - - skip_layer_func: - A function that determines which linear layers should be skipped during GPTQ - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - skip: boolean indicating whether layer should be skipped - - make_names_and_values_dict_func: - A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they - should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the - corresponding quantized weights and qparams. 
- """ - def __init__(self): - assert self.mod is not None - assert self.get_qparams_func is not None - assert self.quantize_func is not None - assert self.dequantize_func is not None - assert self.combine_qparams_list_func is not None - assert self.make_names_and_values_dict_func is not None - - @staticmethod - def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": - input_recorder = InputRecorder( - model, - tokenizer, - calibration_seq_length, - pad_calibration_inputs, - ) - - try: - lm_eval.tasks.initialize_tasks() - except: - pass - task_dict = get_task_dict(calibration_tasks) - print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) - - evaluate( - input_recorder, - task_dict, - limit=calibration_limit, - ) - inputs = input_recorder.get_recorded_inputs() - assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, "+ - f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ - f"{calibration_seq_length})" - ) - print(f"Obtained {len(inputs[0].values)} calibration samples") - return inputs - - @torch.no_grad() - def create_quantized_state_dict( - self, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) -> "StateDict": - inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) - print("Tracing model for GPTQ") - GPTQ_runner = GenericGPTQRunner( - self.mod, - inputs, - blocksize, - percdamp, - groupsize, - ).configure_quantization_mode( - self.get_qparams_func, - self.quantize_func, - self.dequantize_func, - self.combine_qparams_list_func, - self.make_names_and_values_dict_func, - self.skip_layer_func - ) - - print("Applying GPTQ to weights") - GPTQ_runner.run() - return GPTQ_runner.get_quantized_state_dict() - - def convert_for_runtime(self) -> "nn.Module": - pass - -##### Weight-only int8 per-channel quantized code ###### - -def replace_linear_weight_only_int8_per_channel(module): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) - else: - replace_linear_weight_only_int8_per_channel(child) - -class WeightOnlyInt8QuantHandler: - def __init__(self, mod): - self.mod = mod - - @torch.no_grad() - def create_quantized_state_dict(self): - cur_state_dict = self.mod.state_dict() - for fqn, mod in self.mod.named_modules(): - if isinstance(mod, torch.nn.Linear): - int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) - cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu') - cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu') - - return cur_state_dict - - def convert_for_runtime(self): - replace_linear_weight_only_int8_per_channel(self.mod) - return self.mod - - -class WeightOnlyInt8Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) - 
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales - -##### weight only int4 per channel groupwise quantized code ###### - -def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): - weight_int32, scales_and_zeros = group_quantize_tensor( - weight_bf16, n_bit=4, groupsize=groupsize - ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) - return weight_int4pack, scales_and_zeros - -def _calc_padded_size(k, groupsize=1, innner_k_tiles=1): - from model import find_multiple - return find_multiple(k, 1024) - -def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): - origin_x_size = x.size() - x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) - new_shape = origin_x_size[:-1] + (out_features,) - c = c.reshape(new_shape) - return c - - -def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): - return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 - -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda - )) - else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda) - - -class WeightOnlyInt4QuantHandler: - def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True): - self.mod = mod - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.padding_allowed = padding_allowed - assert groupsize in [32, 64, 128, 256] - assert inner_k_tiles in [2, 4, 8] - - @torch.no_grad() - def create_quantized_state_dict(self): - cur_state_dict = self.mod.state_dict() - for fqn, mod in self.mod.named_modules(): - if isinstance(mod, torch.nn.Linear): - assert not mod.bias - out_features = mod.out_features - in_features = mod.in_features - assert out_features % 8 == 0, "require out_features % 8 == 0" - print(f"linear: {fqn}, in={in_features}, out={out_features}") - - weight = mod.weight.data - if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): - if self.padding_allowed: - from model import find_multiple - import torch.nn.functional as F - print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") - padded_in_features = find_multiple(in_features, 1024) - weight = F.pad(weight, pad=(0, padded_in_features - in_features)) - else: - print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + - "and that groupsize and inner_k_tiles*16 evenly divide into it") - continue - weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( - weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles - ) - cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') - - return cur_state_dict - - def convert_for_runtime(self, use_cuda): - replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda) - return self.mod - -class 
WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): - def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): - from model import find_multiple - self.mod = mod - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.padding = padding - self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) - self.quantize_func = lambda w, qparams: \ - group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) - self.dequantize_func = lambda q, qparams: \ - group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() - self.combine_qparams_list_func = lambda qparams_list: \ - [torch.cat(x, dim=1) for x in zip(*qparams_list)] - # skip unless padding=True or its correctly sized - self.skip_layer_func = lambda linear_weight: not ( - _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding - ) - # we need to do the padding here, both for q and the qparams if necessary - def make_names_and_values_dict_func(q, qparams): - k = q.shape[1] - if not _check_linear_int4_k(k, groupsize, inner_k_tiles): - new_k = find_multiple(k, 1024) - else: - new_k = k - # how much we need to pad the weight - delta_k = new_k - q.shape[1] - final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) - scales_and_zeros = pack_scales_and_zeros(*qparams) - # how many new groups we need for padded weight - delta_groups = new_k // groupsize - scales_and_zeros.shape[0] - final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) - return {"weight": final_q, "scales_and_zeros": final_s_and_z} - self.make_names_and_values_dict_func = make_names_and_values_dict_func - super().__init__() - - - def convert_for_runtime(self, use_cuda): - replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) - return self.mod - -class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__( - self, in_features: int, out_features: int, - bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, - ) -> None: - super().__init__() - self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) - if self.padding: - from model import find_multiple - self.origin_in_features = in_features - in_features = find_multiple(in_features, 1024) - - self.in_features = in_features - self.out_features = out_features - assert not bias, "require bias=False" - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - - assert out_features % 8 == 0, "require out_features % 8 == 0" - assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) - ) - self.register_buffer( - "scales_and_zeros", - torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(torch.bfloat16) - if self.padding: - import torch.nn.functional as F - input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_int4( - input, - self.weight, self.scales_and_zeros, self.out_features, self.groupsize - ) - - -def quantize( - checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), - mode: str = 
'int8', - # following arguments only available when setting int4 quantization. - groupsize: int = 128, - # following arguments only used for GPTQ - calibration_tasks: list = ["hellaswag"], - calibration_limit: int = 1000, - calibration_seq_length: int = 100, - pad_calibration_inputs: bool = False, - percdamp: float = .01, - blocksize: int = 128, - label: str = '', - device: str = default_device, -) -> None: - assert checkpoint_path.is_file(), checkpoint_path - device = 'cpu' - precision = torch.bfloat16 - - print("Loading model ...") - t0 = time.time() - - with torch.device('meta'): - model = Transformer.from_name(checkpoint_path.parent.name) - - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - model.load_state_dict(checkpoint, assign=True) - model = model.to(dtype=precision, device=device) - - if mode == 'int8': - print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") - quant_handler = WeightOnlyInt8QuantHandler(model) - quantized_state_dict = quant_handler.create_quantized_state_dict() - - dir_name = checkpoint_path.parent - base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f'{label}int8.pth') - - elif mode == 'int4': - print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") - print(f"Prepacking model weights in {device} optimal layout") - quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) - quantized_state_dict = quant_handler.create_quantized_state_dict() - - dir_name = checkpoint_path.parent - base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth") - - elif mode == 'int4-gptq': - print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") - quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) - - tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - quantized_state_dict = quant_handler.create_quantized_state_dict( - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs - ) - - dir_name = checkpoint_path.parent - base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth") - else: - raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]") - - quantize_path = dir_name / new_base_name - print(f"Writing quantized weights to {quantize_path}") - quantize_path.unlink(missing_ok=True) # remove existing file if one already there - torch.save(quantized_state_dict, quantize_path) - print(f"Quantization complete took {time.time() - t0:.02f} seconds") - return - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Quantize a model.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') - parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') - parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') - 
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') - parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') - parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') - parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') - parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') - parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') - parser.add_argument('--label', type=str, default='_', help='label to add to output filename') - parser.add_argument('--device', type=str, default=default_device, help='device to use') - - args = parser.parse_args() - quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device) diff --git a/torchao/prototype/models/llama3/tokenizer.py b/torchao/prototype/models/llama3/tokenizer.py deleted file mode 100644 index c62a0c5b3a..0000000000 --- a/torchao/prototype/models/llama3/tokenizer.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import sentencepiece as spm -import tiktoken -from tiktoken.load import load_tiktoken_bpe -from pathlib import Path -from typing import Dict - -class TokenizerInterface: - def __init__(self, model_path): - self.model_path = model_path - - def encode(self, text): - raise NotImplementedError("This method should be overridden by subclasses.") - - def decode(self, tokens): - raise NotImplementedError("This method should be overridden by subclasses.") - - def bos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") - - def eos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") - -class SentencePieceWrapper(TokenizerInterface): - def __init__(self, model_path): - super().__init__(model_path) - self.processor = spm.SentencePieceProcessor(str(model_path)) - - def encode(self, text): - return self.processor.EncodeAsIds(text) - - def decode(self, tokens): - return self.processor.DecodeIds(tokens) - - def bos_id(self): - return self.processor.bos_id() - - def eos_id(self): - return self.processor.eos_id() - -class TiktokenWrapper(TokenizerInterface): - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
- """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path): - super().__init__(model_path) - assert os.path.isfile(model_path), str(model_path) - mergeable_ranks = load_tiktoken_bpe(str(model_path)) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [ - f"<|reserved_special_token_{i}|>" - for i in range(5, self.num_reserved_special_tokens - 5) - ] - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - # BOS / EOS token IDs - self._bos_id: int = self.special_tokens["<|begin_of_text|>"] - self._eos_id: int = self.special_tokens["<|end_of_text|>"] - - def encode(self, text): - return self.model.encode(text) - - def decode(self, tokens): - return self.model.decode(tokens) - - def bos_id(self): - return self._bos_id - - def eos_id(self): - return self._eos_id - -def get_tokenizer(tokenizer_model_path, model_name): - """ - Factory function to get the appropriate tokenizer based on the model name. - - Args: - - tokenizer_model_path (str): The file path to the tokenizer model. - - model_name (str): The name of the model, used to determine the tokenizer type. - - Returns: - - TokenizerInterface: An instance of a tokenizer. - """ - if "Llama-3" in str(model_name): - return TiktokenWrapper(tokenizer_model_path) - else: - return SentencePieceWrapper(tokenizer_model_path) From d3620427f8de0b5efc4bc08ebf738b4fe6ea728e Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Mon, 29 Apr 2024 16:00:29 -0700 Subject: [PATCH 3/7] Remove print --- torchao/prototype/models/gpt_fused/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/models/gpt_fused/model.py b/torchao/prototype/models/gpt_fused/model.py index 4c862c0887..0660bc2b72 100644 --- a/torchao/prototype/models/gpt_fused/model.py +++ b/torchao/prototype/models/gpt_fused/model.py @@ -88,7 +88,6 @@ def update(self, input_pos, k_val, v_val): class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: - print("AJSKDLAJSD") super().__init__() self.config = config From f7305252da6253312f14b302afbe2ce5f1342da6 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Mon, 29 Apr 2024 17:02:37 -0700 Subject: [PATCH 4/7] Benchmark copy pasta --- benchmarks/gpt_fused/generate.py | 428 ++++++++++++++++++++ benchmarks/gpt_fused/tokenizer.py | 111 +++++ torchao/prototype/models/gpt_fused/model.py | 14 +- 3 files changed, 550 insertions(+), 3 deletions(-) create mode 100644 benchmarks/gpt_fused/generate.py create mode 100644 benchmarks/gpt_fused/tokenizer.py diff --git a/benchmarks/gpt_fused/generate.py b/benchmarks/gpt_fused/generate.py new file mode 100644 index 0000000000..4e7bb88728 --- /dev/null +++ b/benchmarks/gpt_fused/generate.py @@ -0,0 +1,428 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from torchao.prototype.models.gpt_fused.model import Transformer +from tokenizer import get_tokenizer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +def speculative_decode( + model: Transformer, + draft_model: Transformer, + cur_token: torch.Tensor, + input_pos: int, + speculate_k: int, + **sampling_kwargs +) -> torch.Tensor: + # draft model inference sequentially + device = cur_token.device + orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) + draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), 
orig_input_pos.clone(), speculate_k, **sampling_kwargs) + + draft_tokens = torch.cat(draft_tokens) + # parallel inference on target model using draft tokens + target_logits = model_forward( + model, + torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), + torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) + ) + target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) + draft_probs = torch.stack(draft_probs) + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) + rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() + + if rejected_locations.shape[0] == 0: # All draft tokens have been accepted + accept_length = speculate_k + 1 + last_token = multinomial_sample_one_no_sync(target_probs[-1]) + # fill last token into draft model + model_forward( + draft_model, + draft_tokens[-1].view(1, -1), + orig_input_pos + speculate_k, + ) + return torch.cat([draft_tokens, last_token]) + else: + accept_length = rejected_locations[0].item() + p = draft_probs[accept_length] + q = target_probs[accept_length] + new = q - p + new = torch.where(new > 0, new, 0.0) + new = new / new.sum() + next_token = multinomial_sample_one_no_sync(new) + return torch.cat([draft_tokens[:accept_length], next_token]) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + draft_model: Transformer, + speculate_k: Optional[int] = 8, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
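speculative_decode above implements the standard accept/reject rule: a drafted token with draft probability p and target probability q is kept with probability min(1, q/p), and on rejection the next token is resampled from the renormalized positive part of q - p. A self-contained toy sketch of that rule on a three-token vocabulary; all values are illustrative:

    import torch

    torch.manual_seed(0)
    draft_probs  = torch.tensor([0.6, 0.3, 0.1])   # p: draft model over 3 tokens
    target_probs = torch.tensor([0.2, 0.5, 0.3])   # q: target model over the same tokens

    drafted = torch.multinomial(draft_probs, 1).item()          # token proposed by the draft
    accept_prob = torch.clamp(target_probs[drafted] / draft_probs[drafted], max=1.0)

    if torch.rand(()) < accept_prob:
        next_token = drafted                                    # keep the draft token
    else:
        residual = torch.clamp(target_probs - draft_probs, min=0.0)
        residual = residual / residual.sum()                    # renormalized (q - p)+
        next_token = torch.multinomial(residual, 1).item()      # resample from the residual
    print(next_token)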
+ """ + + is_speculative = draft_model is not None + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + if interactive: + max_seq_length = 350 + else: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) + if is_speculative: + prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + accept_counts = [0] * (speculate_k + 1) + + if is_speculative: + input_pos = input_pos.item() # for speculative decoding easier to keep on host + while input_pos < T_new - 1: + cur_token = next_token.view(()) + + next_tokens = speculative_decode( + model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs + ) + + accept_counts[len(next_tokens) - 1] += 1 + num_added = min(T_new - input_pos - 1, len(next_tokens)) + seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] + for i in next_tokens[: num_added,]: + callback(i) + input_pos = input_pos + num_added + next_token = next_tokens[-1] + else: + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + generate_stats = { + 'accept_counts': accept_counts + } + return seq, generate_stats + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision, use_tp): + use_cuda = 'cuda' in device + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from quantize import WeightOnlyInt8QuantHandler + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 weight-only quantization!") + path_comps = checkpoint_path.name.split(".") + assert path_comps[-3].startswith("g") + assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!" 
+ groupsize = int(path_comps[-3][1:]) + from quantize import WeightOnlyInt4QuantHandler + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime(use_cuda) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from tp import apply_tp + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + draft_checkpoint_path: Optional[Path] = None, + speculate_k: int = 5, + device=default_device, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. + """ + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + # global print + # from tp import maybe_init_dist + # rank = maybe_init_dist() + # use_tp = rank is not None + # if use_tp: + # if rank != 0: + # # only print on rank 0 + # print = lambda *args, **kwargs: None + use_tp = False + + print(f"Using device={device}") + precision = torch.bfloat16 + is_speculative = draft_checkpoint_path is not None + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, use_tp) + + if is_speculative: + draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) + else: + draft_model = None + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + if compile: + if is_speculative and use_tp: # and ("cuda" in device): + torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case + + if is_speculative: + global model_forward, logits_to_prob + model_forward = torch.compile(model_forward, mode="max-autotune", fullgraph=True) + + global decode_one_token, prefill + print("compiling wiht max-autotune") + decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True) + + # Uncomment to squeeze more perf out of prefill + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + 'accept_counts': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, metrics = generate( + model, + encoded, + max_new_tokens, + draft_model=draft_model, + speculate_k=speculate_k, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics['accept_counts'].append(metrics['accept_counts']) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print("==========") + if is_speculative: + counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] + acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] + print(f"Acceptance probs: {acceptance_probs}") + print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") + + print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--profile', 
type=Path, default=None, help='Profile path.') + parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') + parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') + parser.add_argument('--device', type=str, default=default_device, help='Device to use') + + args = parser.parse_args() + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, + args.speculate_k, args.device + ) diff --git a/benchmarks/gpt_fused/tokenizer.py b/benchmarks/gpt_fused/tokenizer.py new file mode 100644 index 0000000000..c62a0c5b3a --- /dev/null +++ b/benchmarks/gpt_fused/tokenizer.py @@ -0,0 +1,111 @@ +import os +import sentencepiece as spm +import tiktoken +from tiktoken.load import load_tiktoken_bpe +from pathlib import Path +from typing import Dict + +class TokenizerInterface: + def __init__(self, model_path): + self.model_path = model_path + + def encode(self, text): + raise NotImplementedError("This method should be overridden by subclasses.") + + def decode(self, tokens): + raise NotImplementedError("This method should be overridden by subclasses.") + + def bos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + + def eos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + +class SentencePieceWrapper(TokenizerInterface): + def __init__(self, model_path): + super().__init__(model_path) + self.processor = spm.SentencePieceProcessor(str(model_path)) + + def encode(self, text): + return self.processor.EncodeAsIds(text) + + def decode(self, tokens): + return self.processor.DecodeIds(tokens) + + def bos_id(self): + return self.processor.bos_id() + + def eos_id(self): + return self.processor.eos_id() + +class TiktokenWrapper(TokenizerInterface): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
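The benchmark CLI assembled above simply forwards its flags into main(). An equivalent programmatic call, assuming benchmarks/gpt_fused is importable; the values mirror the argparse defaults and the checkpoint path is illustrative:

    from pathlib import Path
    from generate import main  # assumes benchmarks/gpt_fused is on sys.path

    main(
        prompt="Hello, my name is",
        interactive=False,
        num_samples=5,
        max_new_tokens=200,
        top_k=200,
        temperature=0.8,
        checkpoint_path=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
        compile=True,
        compile_prefill=False,
        profile=None,
        draft_checkpoint_path=None,
        speculate_k=5,
        device="cuda",
    )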
+ """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path): + super().__init__(model_path) + assert os.path.isfile(model_path), str(model_path) + mergeable_ranks = load_tiktoken_bpe(str(model_path)) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + # BOS / EOS token IDs + self._bos_id: int = self.special_tokens["<|begin_of_text|>"] + self._eos_id: int = self.special_tokens["<|end_of_text|>"] + + def encode(self, text): + return self.model.encode(text) + + def decode(self, tokens): + return self.model.decode(tokens) + + def bos_id(self): + return self._bos_id + + def eos_id(self): + return self._eos_id + +def get_tokenizer(tokenizer_model_path, model_name): + """ + Factory function to get the appropriate tokenizer based on the model name. + + Args: + - tokenizer_model_path (str): The file path to the tokenizer model. + - model_name (str): The name of the model, used to determine the tokenizer type. + + Returns: + - TokenizerInterface: An instance of a tokenizer. 
+ """ + if "Llama-3" in str(model_name): + return TiktokenWrapper(tokenizer_model_path) + else: + return SentencePieceWrapper(tokenizer_model_path) diff --git a/torchao/prototype/models/gpt_fused/model.py b/torchao/prototype/models/gpt_fused/model.py index 0660bc2b72..a303b19ff5 100644 --- a/torchao/prototype/models/gpt_fused/model.py +++ b/torchao/prototype/models/gpt_fused/model.py @@ -206,12 +206,20 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona class FeedForward(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w13 = nn.Linear(config.dim, 2 * config.intermediate_size, bias=False) self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.dim = config.intermediate_size + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "w1.weight" in state_dict: + w1 = state_dict.pop(prefix + "w1.weight") + w3 = state_dict.pop(prefix + "w3.weight") + state_dict[prefix + "w13.weight"] = torch.cat([w1, w3]) def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + x1, x3 = self.w13(x).split([self.dim, self.dim], dim=-1) + return self.w2(F.silu(x1) * x3) class RMSNorm(nn.Module): From 620e2089e32207c7dce144ae81b256c3657a75c3 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Mon, 29 Apr 2024 17:23:41 -0700 Subject: [PATCH 5/7] Simpler benchmark --- benchmarks/gpt_fused/generate.py | 208 ++++--------------------------- 1 file changed, 24 insertions(+), 184 deletions(-) diff --git a/benchmarks/gpt_fused/generate.py b/benchmarks/gpt_fused/generate.py index 4e7bb88728..d21a31526f 100644 --- a/benchmarks/gpt_fused/generate.py +++ b/benchmarks/gpt_fused/generate.py @@ -13,18 +13,11 @@ import torch._dynamo.config import torch._inductor.config -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet suppported") - +from torchao.prototype.models.gpt_fused.model import Transformer +from tokenizer import get_tokenizer torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future default_device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -32,8 +25,6 @@ def device_sync(device): wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao.prototype.models.gpt_fused.model import Transformer -from tokenizer import get_tokenizer def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs_sort).exponential_(1) @@ -65,7 +56,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso logits = model(x, input_pos) return sample(logits, **sampling_kwargs) -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, **sampling_kwargs): new_tokens, new_probs = [], [] for i in range(num_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, 
enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here @@ -74,7 +65,6 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc ) input_pos += 1 new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) new_probs.append(next_prob.clone()) cur_token = next_token.view(1, -1) @@ -84,87 +74,26 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc def model_forward(model, x, input_pos): return model(x, input_pos) -def speculative_decode( - model: Transformer, - draft_model: Transformer, - cur_token: torch.Tensor, - input_pos: int, - speculate_k: int, - **sampling_kwargs -) -> torch.Tensor: - # draft model inference sequentially - device = cur_token.device - orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) - draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) - - draft_tokens = torch.cat(draft_tokens) - # parallel inference on target model using draft tokens - target_logits = model_forward( - model, - torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), - torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) - ) - target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) - draft_probs = torch.stack(draft_probs) - # q: target prob, p: draft prob - # q >= p: always accept draft token - # q < p: q/p prob to accept draft token - p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) - rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() - - if rejected_locations.shape[0] == 0: # All draft tokens have been accepted - accept_length = speculate_k + 1 - last_token = multinomial_sample_one_no_sync(target_probs[-1]) - # fill last token into draft model - model_forward( - draft_model, - draft_tokens[-1].view(1, -1), - orig_input_pos + speculate_k, - ) - return torch.cat([draft_tokens, last_token]) - else: - accept_length = rejected_locations[0].item() - p = draft_probs[accept_length] - q = target_probs[accept_length] - new = q - p - new = torch.where(new > 0, new, 0.0) - new = new / new.sum() - next_token = multinomial_sample_one_no_sync(new) - return torch.cat([draft_tokens[:accept_length], next_token]) - @torch.no_grad() def generate( model: Transformer, prompt: torch.Tensor, max_new_tokens: int, - *, - interactive: bool, - draft_model: Transformer, - speculate_k: Optional[int] = 8, - callback = lambda x: x, **sampling_kwargs ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
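Stepping back to the gpt_fused model.py hunk earlier in this series: FeedForward's w1 and w3 projections are concatenated into a single w13 linear (a load-state-dict hook stacks the old weights), and the output is split after one GEMM instead of two. A minimal numerical check of that equivalence, with illustrative sizes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dim, intermediate = 64, 256

    w1 = nn.Linear(dim, intermediate, bias=False)
    w3 = nn.Linear(dim, intermediate, bias=False)
    w2 = nn.Linear(intermediate, dim, bias=False)

    w13 = nn.Linear(dim, 2 * intermediate, bias=False)
    with torch.no_grad():
        w13.weight.copy_(torch.cat([w1.weight, w3.weight]))   # same layout the load hook builds

    x = torch.randn(2, dim)
    ref = w2(F.silu(w1(x)) * w3(x))                            # original two-GEMM SwiGLU
    x1, x3 = w13(x).split([intermediate, intermediate], dim=-1)
    fused = w2(F.silu(x1) * x3)                                # fused form from the patch
    print(torch.allclose(ref, fused, atol=1e-6))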
""" - is_speculative = draft_model is not None # create an empty tensor of the expected final shape and fill in the current tokens T = prompt.size(0) T_new = T + max_new_tokens - if interactive: - max_seq_length = 350 - else: - max_seq_length = min(T_new, model.config.block_size) + max_seq_length = min(T_new, model.config.block_size) device, dtype = prompt.device, prompt.dtype - max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + max_seq_length = max_seq_length with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - if is_speculative and draft_model is not model: - draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) # create an empty tensor of the expected final shape and fill in the current tokens empty = torch.empty(T_new, dtype=dtype, device=device) @@ -173,32 +102,13 @@ def generate( input_pos = torch.arange(0, T, device=device) next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - if is_speculative: - prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) seq[T] = next_token input_pos = torch.tensor([T], device=device, dtype=torch.int) - accept_counts = [0] * (speculate_k + 1) + accept_counts = [0] * (1) - if is_speculative: - input_pos = input_pos.item() # for speculative decoding easier to keep on host - while input_pos < T_new - 1: - cur_token = next_token.view(()) - - next_tokens = speculative_decode( - model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs - ) - - accept_counts[len(next_tokens) - 1] += 1 - num_added = min(T_new - input_pos - 1, len(next_tokens)) - seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] - for i in next_tokens[: num_added,]: - callback(i) - input_pos = input_pos + num_added - next_token = next_tokens[-1] - else: - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) generate_stats = { 'accept_counts': accept_counts @@ -249,7 +159,6 @@ def _load_model(checkpoint_path, device, precision, use_tp): def main( prompt: str = "Hello, my name is", - interactive: bool = False, num_samples: int = 5, max_new_tokens: int = 100, top_k: int = 200, @@ -257,9 +166,6 @@ def main( checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), compile: bool = True, compile_prefill: bool = False, - profile: Optional[Path] = None, - draft_checkpoint_path: Optional[Path] = None, - speculate_k: int = 5, device=default_device, ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. 
@@ -281,19 +187,12 @@ def main( print(f"Using device={device}") precision = torch.bfloat16 - is_speculative = draft_checkpoint_path is not None - is_chat = "chat" in str(checkpoint_path) print("Loading model ...") t0 = time.time() model = _load_model(checkpoint_path, device, precision, use_tp) - if is_speculative: - draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) - else: - draft_model = None - - device_sync(device=device) # MKG + torch.cuda.synchronize(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) @@ -304,16 +203,9 @@ def main( torch.manual_seed(1234) model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) if compile: - if is_speculative and use_tp: # and ("cuda" in device): - torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case - - if is_speculative: - global model_forward, logits_to_prob - model_forward = torch.compile(model_forward, mode="max-autotune", fullgraph=True) - global decode_one_token, prefill print("compiling wiht max-autotune") - decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True) + model = torch.compile(model, mode="max-autotune", fullgraph=True) # Uncomment to squeeze more perf out of prefill if compile_prefill: @@ -327,77 +219,29 @@ def main( start = -1 if compile else 0 for i in range(start, num_samples): - device_sync(device=device) # MKG - if i >= 0 and interactive: - prompt = input("What is your prompt? ") - if is_chat: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - - if interactive and i >= 0: - buffer = [] - period_id = tokenizer.encode('.')[0] - done_generating = False - def callback(x): - nonlocal done_generating - if done_generating: - return - buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) - if x.item() == tokenizer.eos_id(): - done_generating = True - if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) - buffer.clear() - # print(, end='', flush=True) - else: - callback = lambda x : x + torch.cuda.synchronize(device) t0 = time.perf_counter() - import contextlib - if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): - prof = contextlib.nullcontext() - else: - torch.profiler._utils._init_for_cuda_graphs() - prof = torch.profiler.profile() - with prof: - y, metrics = generate( - model, - encoded, - max_new_tokens, - draft_model=draft_model, - speculate_k=speculate_k, - interactive=interactive, - callback=callback, - temperature=temperature, - top_k=top_k, - ) - aggregate_metrics['accept_counts'].append(metrics['accept_counts']) + y, metrics = generate( + model, + encoded, + max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics['accept_counts'].append(metrics['accept_counts']) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") continue - if hasattr(prof, "export_chrome_trace"): - if use_tp: - prof.export_chrome_trace(f"{profile}_rank_{rank}.json") - else: - prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG + torch.cuda.synchronize(device) t = time.perf_counter() - t0 - if not interactive: - print(tokenizer.decode(y.tolist())) - else: - print() + print(tokenizer.decode(y.tolist())) tokens_generated = y.size(0) - prompt_length tokens_sec = tokens_generated / t 
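The bandwidth figure printed above comes from the model size in bytes (numel times itemsize over parameters and buffers) and the measured tokens/sec; for a weight-bound decode, every generated token has to stream roughly the whole model through memory once. A worked example with illustrative numbers:

    model_size = 7e9 * 2          # bytes: ~7B parameters held in bfloat16 (2 bytes each), assumed
    tokens_generated = 200
    t = 2.5                       # seconds for the run, assumed

    tokens_sec = tokens_generated / t                 # 80 tokens/sec
    bandwidth_gb_s = model_size * tokens_sec / 1e9    # ~1120 GB/s of weight traffic
    print(f"{tokens_sec:.02f} tokens/sec, {bandwidth_gb_s:.02f} GB/s")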
aggregate_metrics['tokens_per_sec'].append(tokens_sec) print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") print("==========") - if is_speculative: - counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] - acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] - print(f"Acceptance probs: {acceptance_probs}") - print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") - print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") @@ -407,7 +251,6 @@ def callback(x): parser = argparse.ArgumentParser(description='Your CLI description.') parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') @@ -415,14 +258,11 @@ def callback(x): parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') - parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') parser.add_argument('--device', type=str, default=default_device, help='Device to use') args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, - args.speculate_k, args.device + args.prompt, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, + args.device ) From a994000e213bcde2a53e5e474fd9065d16d7e89e Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Tue, 30 Apr 2024 17:14:40 -0700 Subject: [PATCH 6/7] Remove tp or speculative decoding --- benchmarks/gpt_fused/generate.py | 96 ++++++++++++-------------------- 1 file changed, 37 insertions(+), 59 deletions(-) diff --git a/benchmarks/gpt_fused/generate.py b/benchmarks/gpt_fused/generate.py index d21a31526f..d94aa4f985 100644 --- a/benchmarks/gpt_fused/generate.py +++ b/benchmarks/gpt_fused/generate.py @@ -50,6 +50,7 @@ def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **samp logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] +@torch.compile(mode='max-autotune', fullgraph=True) def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 @@ -105,15 +106,10 @@ def generate( seq[T] = next_token 
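Patch 5 above compiles the whole model object, while patch 6 (whose diff begins in this hunk) instead decorates decode_one_token directly with torch.compile. A toy sketch contrasting the two forms; the mode and fullgraph flags mirror the diff, but the module here is a stand-in:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)

    # (a) patch 5's approach: compile the module once
    compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)

    # (b) patch 6's approach: compile the per-token decode step itself
    @torch.compile(mode="max-autotune", fullgraph=True)
    def decode_one_token(m, x):
        return m(x)

    x = torch.randn(1, 16)
    print(torch.allclose(compiled_model(x), decode_one_token(model, x), atol=1e-6))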
input_pos = torch.tensor([T], device=device, dtype=torch.int) - accept_counts = [0] * (1) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs) seq[T + 1:] = torch.cat(generated_tokens) - generate_stats = { - 'accept_counts': accept_counts - } - return seq, generate_stats + return seq def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) @@ -121,7 +117,7 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) -def _load_model(checkpoint_path, device, precision, use_tp): +def _load_model(checkpoint_path, device, precision): use_cuda = 'cuda' in device with torch.device('meta'): model = Transformer.from_name(checkpoint_path.parent.name) @@ -147,16 +143,20 @@ def _load_model(checkpoint_path, device, precision, use_tp): checkpoint = checkpoint["model"] model.load_state_dict(checkpoint, assign=True) - if use_tp: - from tp import apply_tp - print("Applying tensor parallel to model ...") - apply_tp(model) - model = model.to(device=device, dtype=precision) return model.eval() B_INST, E_INST = "[INST]", "[/INST]" +def profiler_runner(path, fn, *args, **kwargs): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + return result + def main( prompt: str = "Hello, my name is", num_samples: int = 5, @@ -164,8 +164,6 @@ def main( top_k: int = 200, temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), - compile: bool = True, - compile_prefill: bool = False, device=default_device, ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. 
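The profiler_runner helper added above wraps an arbitrary callable in torch.profiler and writes a Chrome trace, matching the commented-out invocation at the bottom of this patch. A self-contained sketch of the same pattern on a toy workload; unlike the hunk, the CUDA activity is guarded here so the sketch also runs on CPU-only builds, and the trace path is illustrative:

    import torch

    def profiler_runner(path, fn, *args, **kwargs):
        # mirrors the helper in the hunk above
        activities = [torch.profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
            result = fn(*args, **kwargs)
        prof.export_chrome_trace(path)  # open in chrome://tracing or Perfetto
        return result

    def step():
        a = torch.randn(1024, 1024)
        return a @ a

    out = profiler_runner("profile.json", step)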
@@ -175,22 +173,12 @@ def main( tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) - # global print - # from tp import maybe_init_dist - # rank = maybe_init_dist() - # use_tp = rank is not None - # if use_tp: - # if rank != 0: - # # only print on rank 0 - # print = lambda *args, **kwargs: None - use_tp = False - print(f"Using device={device}") precision = torch.bfloat16 print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, device, precision, use_tp) + model = _load_model(checkpoint_path, device, precision) torch.cuda.synchronize(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") @@ -202,45 +190,36 @@ def main( torch.manual_seed(1234) model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) - if compile: - global decode_one_token, prefill - print("compiling wiht max-autotune") - model = torch.compile(model, mode="max-autotune", fullgraph=True) - - # Uncomment to squeeze more perf out of prefill - if compile_prefill: - prefill = torch.compile(prefill, fullgraph=True, dynamic=True) - aggregate_metrics = { 'tokens_per_sec': [], - 'accept_counts': [], } - start = -1 if compile else 0 - - for i in range(start, num_samples): - torch.cuda.synchronize(device) - t0 = time.perf_counter() - y, metrics = generate( - model, - encoded, - max_new_tokens, - temperature=temperature, - top_k=top_k, - ) - aggregate_metrics['accept_counts'].append(metrics['accept_counts']) - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - torch.cuda.synchronize(device) - t = time.perf_counter() - t0 - - print(tokenizer.decode(y.tolist())) + + for i in range(num_samples): + with torch.autograd.profiler.record_function(f"timed region for inference {i}"): + torch.cuda.synchronize(device) + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + if i == 0: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + torch.cuda.synchronize(device) + t = time.perf_counter() - t0 + + # print(tokenizer.decode(y.tolist())) tokens_generated = y.size(0) - prompt_length tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + if i > 0: + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + else: + print("Don't count first inference run.") print("==========") print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") @@ -256,13 +235,12 @@ def main( parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') parser.add_argument('--device', type=str, default=default_device, help='Device to use') args = parser.parse_args() + # 
profiler_runner("profile.json.gz", main, main( args.prompt, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, + args.temperature, args.checkpoint_path, args.device ) From 12535d809519412e16e5f8178d3285f19a935590 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Wed, 1 May 2024 12:25:59 -0700 Subject: [PATCH 7/7] Trying weight only int8 quant --- benchmarks/gpt_fused/generate.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/benchmarks/gpt_fused/generate.py b/benchmarks/gpt_fused/generate.py index d94aa4f985..a9a246e3e2 100644 --- a/benchmarks/gpt_fused/generate.py +++ b/benchmarks/gpt_fused/generate.py @@ -61,6 +61,9 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc new_tokens, new_probs = [], [] for i in range(num_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True): # Actually better for Inductor to codegen attention here + # with torch.autograd.profiler.record_function(f"generate token {i}"): + # torch.cuda.synchronize() next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) @@ -68,6 +71,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc new_tokens.append(next_token.clone()) new_probs.append(next_prob.clone()) cur_token = next_token.view(1, -1) + # torch.cuda.synchronize() return new_tokens, new_probs @@ -96,6 +100,9 @@ def generate( with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + from torchao.quantization import apply_weight_only_int8_quant + apply_weight_only_int8_quant(model) + # create an empty tensor of the expected final shape and fill in the current tokens empty = torch.empty(T_new, dtype=dtype, device=device) empty[:T] = prompt @@ -106,10 +113,14 @@ def generate( seq[T] = next_token input_pos = torch.tensor([T], device=device, dtype=torch.int) + torch.cuda.synchronize(device) + t0 = time.perf_counter() generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs) seq[T + 1:] = torch.cat(generated_tokens) + torch.cuda.synchronize(device) + t = time.perf_counter() - t0 - return seq + return seq, t def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) @@ -197,9 +208,7 @@ def main( for i in range(num_samples): with torch.autograd.profiler.record_function(f"timed region for inference {i}"): - torch.cuda.synchronize(device) - t0 = time.perf_counter() - y = generate( + y, t = generate( model, encoded, max_new_tokens, @@ -208,8 +217,6 @@ def main( ) if i == 0: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - torch.cuda.synchronize(device) - t = time.perf_counter() - t0 # print(tokenizer.decode(y.tolist())) tokens_generated = y.size(0) - prompt_length