diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 36743bb3b79..2ef016de097 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -136,14 +136,14 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = unwrap_tensor_subclass(model)
+
+        # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
             print("quantized model:", model)
@@ -698,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
 
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
 
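
For context, a minimal sketch of the torchao flow this patch switches to, applied to a toy model (not part of the patch). The Sequential model and group_size=32 are illustrative assumptions; the calls themselves, quantize_, int8_dynamic_activation_int4_weight, and unwrap_tensor_subclass, are the ones the first hunk introduces, assuming a torchao version that exports them from torchao.quantization and torchao.utils as shown above.

    import torch.nn as nn
    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
    from torchao.utils import unwrap_tensor_subclass

    # Toy stand-in for the llama model; group-wise int4 quantization expects
    # in_features to be divisible by group_size (256 % 32 == 0 here).
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))

    # quantize_ mutates the model in place, swapping each Linear weight for a
    # tensor subclass carrying int4 weights with dynamic int8 activation quant.
    quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

    # unwrap_tensor_subclass flattens those subclass tensors back into plain
    # parameters plus explicit (de)quant ops so the module can be exported.
    model = unwrap_tensor_subclass(model)

The second hunk is related: load_state_dict(..., assign=True) assigns the tensors from the quantized state dict to the module rather than copying them into the existing parameters in place, which avoids dtype/layout mismatches against the freshly converted module.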