From d63e541aea9da10b4c92a378345484b522f80117 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 18 Jun 2024 10:42:00 -0700
Subject: [PATCH 1/4] Additional script fixes

Summary:

int4wo had an issue with device swap after quantization
int4wo-gptq had an issue with....

Test Plan:

python eval.py -q int4wo-64 --compile

wikitext: {'word_perplexity,none': 12.842987954345306, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611855472207904, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6887223897240059, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

python eval.py -q int4wo-64-gptq --compile

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/eval.py | 2 +-
 torchao/quantization/GPTQ.py  | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 7842c3e66c..2b22efef06 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -66,7 +66,7 @@ def run_evaluation(
         if "int4wo" in quantization and not "gptq" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4wo(groupsize=groupsize))
+            quantize(model.to(device), int4wo(groupsize=groupsize))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 12910c89bf..590bf5e94e 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -672,9 +672,9 @@ def quantize(
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
-        blocksize,
-        percdamp,
-        groupsize,
+        blocksize=128,
+        percdamp=0.01,
+        groupsize=64,
         inner_k_tiles=8,
         padding_allowed=True,
         device: torch.device = torch.device("cuda"),

From 4b70e56fd2c86bc0858c04df5bcffc4e216283a5 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 18 Jun 2024 12:31:46 -0700
Subject: [PATCH 2/4] fix use_index_put_for_kv_cache

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/eval.py | 3 ++-
 torchao/quantization/GPTQ.py  | 8 +++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 2b22efef06..e70dba1ed3 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -71,6 +71,7 @@ def run_evaluation(
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
+            assert "cuda" in device, "int4 gptq quantization only works on cuda"
             inputs = InputRecorder(
                 tokenizer,
                 calibration_seq_length,
@@ -83,7 +84,7 @@ def run_evaluation(
                 calibration_limit,
             ).get_inputs()
 
-            quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize)
+            quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
             model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
             model = quantizer.quantize(model, inputs).to(device)
         else:
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 590bf5e94e..144645375b 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -39,6 +39,11 @@
 )
 aten = torch.ops.aten
 
+# need this to fix the model so it works for GPTQ
+import torchao
+from torchao._models.llama.model import use_index_put_for_kv_cache
+torchao._models.llama.model.use_index_put_for_kv_cache = True
+
 if not _lm_eval_available:
     logging.info("lm_eval is not installed, GPTQ may not be usable")
@@ -81,6 +86,7 @@ def __init__(
         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16]
+        model.cpu()(*one_input)
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
@@ -95,7 +101,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
     def configure_quantization_mode(
         self,

From cf8a3622bfc285fb28f39c60f236ef80b77bab1e Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Fri, 21 Jun 2024 13:42:25 -0700
Subject: [PATCH 3/4] final tests

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py |  2 --
 torchao/_models/_eval.py            | 27 ++++++++++++++++-----------
 torchao/_models/llama/eval.py       |  2 +-
 torchao/_models/llama/generate.py   |  4 ++--
 torchao/_models/llama/tokenizer.py  |  4 ++++
 torchao/quantization/GPTQ.py        |  9 +++------
 torchao/quantization/__init__.py    |  2 --
 torchao/quantization/autoquant.py   | 16 +++++++++++-----
 8 files changed, 37 insertions(+), 29 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index b22a157568..e8b9d606d7 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -238,7 +238,6 @@ def test_8da4w_quantizer(self):
     def test_8da4w_gptq_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -338,7 +337,6 @@ def test_8da4w_quantizer_eval(self):
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py
index 2d53865659..858196776f 100644
--- a/torchao/_models/_eval.py
+++ b/torchao/_models/_eval.py
@@ -52,8 +52,13 @@ def __init__(
         pad_token=0,
         device="cpu",
     ):
-        super().__init__()
-        self._tokenizer = tokenizer
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cpu")
+
+        self.tokenizer = tokenizer
         self._device = torch.device(device)
         self.vocab_size = vocab_size
         self._max_seq_length = calibration_seq_length
@@ -74,9 +79,9 @@ def __init__(
     @property
     def eot_token_id(self):
         try:
-            return self._tokenizer.eos_id()
+            return self.tokenizer.eos_id()
         except:
-            return self._tokenizer.eos_id
+            return self.tokenizer.eos_id
 
     @property
     def max_length(self):
@@ -96,16 +101,16 @@ def device(self):
 
     def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
-        tokens = self._tokenizer.encode(string)
-        if hasattr(self._tokenizer, "bos_id"):
+        tokens = self.tokenizer.encode(string)
+        if hasattr(self.tokenizer, "bos_id"):
             try:
-                tokens = [self._tokenizer.bos_id()] + tokens
+                tokens = [self.tokenizer.bos_id()] + tokens
             except:
-                tokens = [self._tokenizer.bos_id] + tokens
+                tokens = [self.tokenizer.bos_id] + tokens
         return tokens
 
     def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
+        decoded = self.tokenizer.decode(tokens)
         return decoded
 
     def add_input(self, args):
@@ -185,9 +190,9 @@ def __init__(
         input_prep_func=None,
         device="cuda"
     ):
-        super().__init__(None, None)
+        super().__init__(tokenizer, None)
         self._model = model
-        self._tokenizer = tokenizer
+        # self.tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = max_seq_length
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index e70dba1ed3..862cd19841 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -21,7 +21,7 @@ from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model
 
 torch._inductor.config.fx_graph_cache = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 7f7cfab885..ce19a7be24 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -35,8 +35,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import Transformer, prepare_inputs_for_model
-from tokenizer import get_tokenizer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer
 
 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
diff --git a/torchao/_models/llama/tokenizer.py b/torchao/_models/llama/tokenizer.py
index 15082068db..2d6be2dcf9 100644
--- a/torchao/_models/llama/tokenizer.py
+++ b/torchao/_models/llama/tokenizer.py
@@ -30,6 +30,8 @@ class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
         self.processor = spm.SentencePieceProcessor(str(model_path))
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()
 
     def encode(self, text):
         return self.processor.EncodeAsIds(text)
@@ -86,6 +88,8 @@ def __init__(self, model_path):
         # BOS / EOS token IDs
         self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()
 
     def encode(self, text):
         return self.model.encode(text)
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 144645375b..275e716aa1 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -39,11 +39,6 @@
 )
 aten = torch.ops.aten
 
-# need this to fix the model so it works for GPTQ
-import torchao
-from torchao._models.llama.model import use_index_put_for_kv_cache
-torchao._models.llama.model.use_index_put_for_kv_cache = True
-
 if not _lm_eval_available:
     logging.info("lm_eval is not installed, GPTQ may not be usable")
@@ -86,7 +81,9 @@ def __init__(
         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16]
-        model.cpu()(*one_input)
+        # needed for GPTQ on the torchao llama model
+        import torchao
+        torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index f1bb82921e..b4139f0c1d 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -17,8 +17,6 @@
     "swap_conv2d_1x1_to_linear"
     "safe_int_mm",
     "autoquant",
-    "change_linears_to_autoquantizable",
-    "change_autoquantizable_to_quantized",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 1a63fed57c..18a58cd17f 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -17,6 +17,12 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench
 
+__all__ = [
+    "AutoQuantizableLinearWeight",
+    "autoquant",
+]
+
+
 aten = torch.ops.aten
 
 AUTOQUANT_CACHE = {}
@@ -382,11 +388,11 @@ def from_float(cls, weight):
     AQInt8DynamicallyQuantizedLinearWeight,
 ]
 
-def change_linears_to_autoquantizable(model, **kwargs):
+def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
-    by running the model and then calling change_autoquantizable_to_quantized
+    by running the model and then calling _change_autoquantizable_to_quantized
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
@@ -401,7 +407,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )
 
-def change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending
@@ -490,7 +496,7 @@ def autoquant(
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
-    change_linears_to_autoquantizable(
+    _change_linears_to_autoquantizable(
         model,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
@@ -531,7 +537,7 @@ def autoquant_prehook(module, args, kwargs):
     # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned
     # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model
     def finalize_autoquant():
-        change_autoquantizable_to_quantized(
+        _change_autoquantizable_to_quantized(
             real_model,
             **aq_kwargs,
         )

From 79b0c1d520e7d1c3699dce79e2aef4a6496d5a7c Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Fri, 21 Jun 2024 13:59:50 -0700
Subject: [PATCH 4/4] updating quantize apis

Summary:

Test Plan:

two python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
two python eval.py -q int4wo

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/eval.py     |  8 ++++----
 torchao/_models/llama/generate.py | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 862cd19841..36e5085018 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -13,7 +13,7 @@
 )
 from torchao.quantization.quant_api import (
-    quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass
+    quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
@@ -60,13 +60,13 @@ def run_evaluation(
     if quantization:
         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8da_int8w())
+            quantize(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization and not "gptq" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model.to(device), int4wo(groupsize=groupsize))
+            quantize(model.to(device), int4_weight_only(group_size=groupsize))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index ce19a7be24..34e7ca82b2 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -189,21 +189,21 @@ def main(
     if quantization:
         from torchao.quantization.quant_api import (
             quantize,
-            int8wo,
-            int8da_int8w,
-            int4wo,
+            int8_weight_only,
+            int8_dynamic_activation_int8_weight,
+            int4_weight_only,
             autoquant,
             unwrap_tensor_subclass
         )
         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8da_int8w())
+            quantize(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4wo(groupsize=groupsize))
+            quantize(model, int4_weight_only(groupsize=groupsize))
         if "autoquant" == quantization:
             model = autoquant(model, manual=True)
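
Usage note (not part of the patches): a minimal sketch of calling the renamed quantize APIs from PATCH 4/4, using quantize and int4_weight_only imported from torchao.quantization.quant_api as in the eval.py hunk above. The small Sequential model here is a hypothetical stand-in for the llama Transformer that eval.py/generate.py quantize; it assumes a CUDA device and bfloat16 weights, and spells the keyword group_size as the eval.py hunk does (the generate.py hunk in this series still passes groupsize).

    import torch
    from torchao.quantization.quant_api import quantize, int4_weight_only

    # Hypothetical stand-in model; the real scripts quantize the llama Transformer.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024),
    ).to(device="cuda", dtype=torch.bfloat16)

    # Swap each linear weight for an int4 weight-only quantized tensor, group size 64.
    quantize(model, int4_weight_only(group_size=64))

    # The quantized module is then used like any other nn.Module.
    out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))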