From 920b0840cd1643080b602e049e2df0e384a0adbc Mon Sep 17 00:00:00 2001
From: Jack <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 25 Feb 2025 13:41:10 -0800
Subject: [PATCH 1/2] Switch to new ao quant api for 8da4w (#8501)

---
 .../models/llama/source_transformation/quantize.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 17cff7c63fd..af075efaff8 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -138,14 +138,12 @@ def quantize( # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+
+        # TODO: deal with checkpoint / computation dtype decoupling.

         if verbose:
             print("quantized model:", model)
@@ -700,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:

     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod


From 3e848e9b2cb3c843b60e11d0350745a6a6beac55 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 25 Mar 2025 10:07:05 -0700
Subject: [PATCH 2/2] Fix export with unwrap_tensor_subclass

---
 examples/models/llama/source_transformation/quantize.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index abf32ffda90..2ef016de097 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -138,8 +138,10 @@ def quantize( # noqa C901
             raise Exception("For 8da4w quantization, group size must be specified.")

         from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass

         quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = unwrap_tensor_subclass(model)

         # TODO: deal with checkpoint / computation dtype decoupling.
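
For reviewers unfamiliar with the new torchao flow, below is a minimal, self-contained sketch of what the patched quantize() path now does: quantize_ rewrites the model's linear layers in place with the 8da4w scheme, and unwrap_tensor_subclass flattens the resulting tensor subclasses so torch.export can trace the model. The toy two-layer model, the group size of 32, and the final torch.export call are illustrative assumptions for this sketch, not code from these patches.

# Sketch only: a toy model standing in for the Llama modules this PR targets.
import torch
import torch.nn as nn

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).eval()

# 8-bit dynamic activations / 4-bit grouped weights ("8da4w"). quantize_
# mutates the model in place, swapping each nn.Linear weight for a quantized
# tensor subclass. group_size (32 here, an illustrative choice) must divide
# the linear layers' input dimension.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# Replace the tensor subclasses with plain tensors behind parametrizations so
# that export can trace the model (the fix in PATCH 2/2).
model = unwrap_tensor_subclass(model)

# Illustrative export call, showing why the unwrap step matters.
exported = torch.export.export(model, (torch.randn(1, 256),))

One behavioral difference worth flagging: the old Int8DynActInt4WeightQuantizer took an explicit precision argument, while quantize_ operates in the model's current dtype, which is why PATCH 1/2 leaves the checkpoint/computation dtype decoupling as a TODO.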