From 79e740cb25ad4deef7fe3d021055dd7e45b5cf47 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 2 Jul 2024 11:25:53 -0700
Subject: [PATCH] Renaming `quantize` to `quantize_`

Summary:
Addressing feedback for the `quantize` API from
https://github.com/pytorch/ao/issues/391#issuecomment-2174713094.

This API changes the model in place, so we want the name to reflect that.
In-place model quantization is important, especially for LLMs, since it can be
hard to fit the full model in memory; we typically load the model onto the meta
device and then load the quantized weights.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md                                 |  6 ++--
 test/integration/test_integration.py      | 18 ++++++------
 test/prototype/test_quant_llm.py          |  4 +--
 test/quantization/test_quant_api.py       | 18 ++++++------
 torchao/__init__.py                       |  4 +--
 torchao/_models/llama/eval.py             | 36 +++++++++++------------
 torchao/_models/llama/generate.py         | 14 ++++-----
 torchao/prototype/quant_llm/README.md     |  6 ++--
 torchao/quantization/README.md            | 10 +++----
 torchao/quantization/__init__.py          |  2 +-
 torchao/quantization/quant_api.py         | 14 ++++-----
 tutorials/quantize_vit/run_vit_b_quant.py |  4 +--
 12 files changed, 68 insertions(+), 68 deletions(-)

diff --git a/README.md b/README.md
index e99f565900..bcb3966bff 100644
--- a/README.md
+++ b/README.md
@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
 Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)
 
 ```python
-from torchao.quantization.quant_api import quantize, int4_weight_only
-m = quantize(m, int4_weight_only())
+from torchao.quantization.quant_api import quantize_, int4_weight_only
+quantize_(m, int4_weight_only())
 ```
 
 Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
@@ -70,7 +70,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code.
Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701) -* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())` +* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())` ## Composability diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4d5a2c511c..c21f3a38be 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -23,7 +23,7 @@ int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, - quantize, + quantize_, _replace_with_custom_fn_if_matches_filter, ) # APIs to be deprecated (used for torch 2.2.2 and 2.3) @@ -98,21 +98,21 @@ def _int8wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_weight_only(), set_inductor_config=False) + quantize_(mod, int8_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_woqtensors(mod) def _int8da_int8w_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int4_weight_only(), set_inductor_config=False) + quantize_(mod, int4_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int4_woqtensors(mod) @@ -127,8 +127,8 @@ def _int4wo_api(mod): def undo_recommended_configs(): torch._inductor.config.coordinate_descent_tuning = False torch._inductor.config.coordinate_descent_check_all_directions = False - torch._inductor.config.force_fuse_int_mm_with_mul = False - torch._inductor.config.fx_graph_cache = False + torch._inductor.config.force_fuse_int_mm_with_mul = False + torch._inductor.config.fx_graph_cache = False torch._inductor.config.triton.unique_kernel_names = False torch.set_float32_matmul_precision("highest") @@ -844,7 +844,7 @@ def api(mod): kwargs_copy = kwargs.copy() kwargs_copy["group_size"] = groupsize del kwargs_copy["groupsize"] - quantize(mod, int4_weight_only(**kwargs_copy)) + quantize_(mod, int4_weight_only(**kwargs_copy)) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int4_woqtensors(mod, **kwargs) @@ -865,7 +865,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype): out3 = mod(example_input) sqnr2 = SQNR(out, out3) self.assertTrue(sqnr2 >= 30) - + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, [ diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py index 77eac6f69d..fab2d972b1 100644 --- a/test/prototype/test_quant_llm.py +++ b/test/prototype/test_quant_llm.py @@ -16,7 +16,7 @@ ) from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 -from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import quantize_ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, 
ebits, mbits, bias): linear = torch.nn.Linear(IC, OC, bias=bias, device=device) fpx_linear = copy.deepcopy(linear) - quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) + quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fpx_linear(x) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e8b9d606d7..b137cd22dc 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -31,7 +31,7 @@ Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, ) -from torchao import quantize +from torchao import quantize_ from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, Quantizer, @@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - quantize(model, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight()) return model class ToyLinearModel(torch.nn.Module): @@ -152,7 +152,7 @@ class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - m = quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) quantized = m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self): ) m = ToyLinearModel().eval().cpu() def api(model): - model = quantize(model, int8_weight_only()) + quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) api(m) @@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self): m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size)) + quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self): example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") group_size = 32 - m = quantize(m, int4_weight_only(group_size=group_size)) + quantize_(m, int4_weight_only(group_size=group_size)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - m = quantize(m, int8_weight_only()) + quantize_(m, int8_weight_only()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") - m = quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) assert 
isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self): m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16) - m = quantize(m, int8_weight_only()) + quantize_(m, int8_weight_only()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) diff --git a/torchao/__init__.py b/torchao/__init__.py index 3b5a1b3c0f..104dc5f311 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -30,14 +30,14 @@ from torchao.quantization import ( autoquant, - quantize, + quantize_, ) from . import dtypes __all__ = [ "dtypes", "autoquant", - "quantize", + "quantize_", ] # test-pytorchbot diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 73deafffec..35e35ecf03 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -13,7 +13,7 @@ ) from torchao.quantization.quant_api import ( - quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass + quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass ) from torchao._models._eval import TransformerEvalWrapper, InputRecorder @@ -60,13 +60,13 @@ def run_evaluation( if quantization: if "int8wo" in quantization: - quantize(model, int8_weight_only()) + quantize_(model, int8_weight_only()) if "int8dq" in quantization: - quantize(model, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization and not "gptq" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize(model.to(device), int4_weight_only(group_size=groupsize)) + quantize_(model.to(device), int4_weight_only(group_size=groupsize)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -94,8 +94,8 @@ def run_evaluation( model = torch.compile(model, mode="max-autotune", fullgraph=True) with torch.no_grad(): TransformerEvalWrapper( - model=model.to(device), - tokenizer=tokenizer, + model=model.to(device), + tokenizer=tokenizer, max_seq_length=max_length, input_prep_func=prepare_inputs_for_model, device=device, @@ -122,16 +122,16 @@ def run_evaluation( args = parser.parse_args() run_evaluation( - args.checkpoint_path, - args.tasks, - args.limit, - args.device, - args.precision, - args.quantization, - args.compile, - args.max_length, - args.calibration_tasks, - args.calibration_limit, - args.calibration_seq_length, - args.pad_calibration_inputs, + args.checkpoint_path, + args.tasks, + args.limit, + args.device, + args.precision, + args.quantization, + args.compile, + args.max_length, + args.calibration_tasks, + args.calibration_limit, + args.calibration_seq_length, + args.pad_calibration_inputs, ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 8142f80bb8..34ff9abb12 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -100,7 +100,7 @@ def generate( T_new = T + max_new_tokens seq = torch.empty(T_new, dtype=prompt.dtype, device=device) seq[:T] = prompt.view(-1) - + # setup model cache max_seq_length = min(T_new, model.config.block_size) if not 
interactive else 350 with torch.device(device): @@ -158,7 +158,7 @@ def main( """ torchao.quantization.utils.recommended_inductor_config_setter() - + assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) @@ -180,11 +180,11 @@ def main( prompt_length = encoded.size(0) torch.manual_seed(1234) - + if quantization: from torchao.quantization.quant_api import ( - quantize, + quantize_, int8_weight_only, int8_dynamic_activation_int8_weight, int4_weight_only, @@ -193,13 +193,13 @@ def main( ) if "int8wo" in quantization: - quantize(model, int8_weight_only()) + quantize_(model, int8_weight_only()) if "int8dq" in quantization: - quantize(model, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize(model, int4_weight_only(group_size=groupsize)) + quantize_(model, int4_weight_only(group_size=groupsize)) if "autoquant" == quantization: model = autoquant(model, manual=True) diff --git a/torchao/prototype/quant_llm/README.md b/torchao/prototype/quant_llm/README.md index 631df30817..f0ecd38d5a 100644 --- a/torchao/prototype/quant_llm/README.md +++ b/torchao/prototype/quant_llm/README.md @@ -5,15 +5,15 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F ## Usage ```python -from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import quantize_ from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only model = ... model.half() # not necessary, but recommeneded to maintain accuracy -quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place +quantize_(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place # for generic FPx EyMz where x = 1 + y + z -# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead +# quantize_(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 76e7cd9ff2..4765d6a5fc 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.dtypes import to_affine_quantized import copy from torchao.quantization.quant_api import ( - quantize, + quantize_, int4_weight_only, ) @@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) group_size = 32 # only works for torch 2.4+ -m = quantize(m, int4_weight_only(group_size=group_size)) +quantize_(m, int4_weight_only(group_size=group_size)) # temporary workaround for tensor subclass + torch.compile from torchao.utils import unwrap_tensor_subclass @@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True # for torch 2.4+ from torchao.quantization import quantize, int8_dynamic_activation_int8_weight -quantize(model, int8_dynamic_activation_int8_weight()) +quantize_(model, int8_dynamic_activation_int8_weight()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import 
change_linear_weights_to_int8_dqtensors @@ -180,7 +180,7 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.4+ from torchao.quantization import quantize, int8_weight_only -quantize(model, int8_weight_only()) +quantize_(model, int8_weight_only()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors @@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ```python # for torch 2.4+ from torchao.quantization import quantize, int4_weight_only -quantize(model, int4_weight_only()) +quantize_(model, int4_weight_only()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 115062c8f6..a1cf1bf034 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -29,7 +29,7 @@ "quantize_affine", "dequantize_affine", "choose_qprams_affine", - "quantize", + "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int4_weight_only", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 31ab71f385..3da530b940 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -54,7 +54,7 @@ "Int4WeightOnlyQuantizer", "autoquant", "_get_subclass_inserter", - "quantize", + "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int4_weight_only", @@ -259,8 +259,8 @@ def insert_subclass(lin): return insert_subclass -def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module: - """Convert the weight of linear modules in the model with `apply_tensor_subclass` +def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True): + """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace Args: model (torch.nn.Module): input model @@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens import torch import torch.nn as nn - from torchao import quantize + from torchao import quantize_ # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to # optimized execution paths or kernels (e.g. int4 tinygemm kernel) @@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens from torchao.quantization.quant_api import int4_weight_only m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - m = quantize(m, int4_weight_only(group_size=32)) + quantize_(m, int4_weight_only(group_size=32)) # 2. 
write your own new apply_tensor_subclass # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor @@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: return isinstance(module, nn.Linear) m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - m = quantize(m, apply_weight_quant, filter_fn) + quantize_(m, apply_weight_quant, filter_fn) """ if set_inductor_config: @@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: _get_linear_subclass_inserter(apply_tensor_subclass), _is_linear if filter_fn is None else filter_fn, ) - return model + def int8_dynamic_activation_int4_weight(group_size=32): """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 07e0118d20..a082cfe53a 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -19,9 +19,9 @@ # for APIs for earlier torch version and other quantization techniques # for torch 2.4+ -from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import quantize_ from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight -quantize(model, int8_dynamic_activation_int8_weight()) +quantize_(model, int8_dynamic_activation_int8_weight()) ## Quantization code - end ## compilation configs
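
Usage note (not part of the diff above): a minimal sketch of the renamed API under the assumption of a toy two-layer model; only `quantize_`, `int8_weight_only`, and the in-place semantics come from the patch itself.

```python
# Minimal sketch (toy model and shapes are illustrative, not from the patch).
# `quantize_` mutates the module in place and returns None, unlike the old
# `quantize`, which returned the model.
import torch
import torch.nn as nn

from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32)).eval()
quantize_(model, int8_weight_only())  # note: no `model = ...` reassignment

out = model(torch.randn(1, 32))

# Per the commit summary, the in-place API is what makes it practical to build
# a large model on the meta device first and then load already-quantized
# weights; that loading flow is outside the scope of this sketch.
```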