diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 4d685169a1..3c29028898 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1043,22 +1043,10 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     )
     def test_replace_linear_8da4w(self):
         module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=True)]
-        )
-        _replace_linear_8da4w(
-            module,
-            256,
-            False,
-            torch.float32,
-            torch.float32,
-            Int8DynActInt4WeightQATLinear,
-            copy_weights=True,
-        )
-        assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(
-            module[0], torch.nn.Linear
-        )
-        module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=False)]
+            [
+                torch.nn.Linear(in_features=256, out_features=50, bias=True),
+                torch.nn.Linear(in_features=256, out_features=50, bias=False),
+            ]
         )
         _replace_linear_8da4w(
             module,
@@ -1070,6 +1058,7 @@ def test_replace_linear_8da4w(self):
             copy_weights=True,
         )
         assert isinstance(module[0], Int8DynActInt4WeightQATLinear)
+        assert isinstance(module[1], Int8DynActInt4WeightQATLinear)
 
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 4af429940f..2caaef7745 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -115,10 +115,10 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    def __init__(self, m=64, n=32, k=64, bias=False):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)
 
     def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
         return (
@@ -272,6 +272,21 @@ def test_8da4w_quantizer(self):
         assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
         m(*example_inputs)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower"
+    )
+    def test_8da4w_quantizer_linear_bias(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
+        m = ToyLinearModel(bias=True).eval()
+        example_inputs = m.example_inputs()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+        m(*example_inputs)
+
     # TODO: save model weights as artifacts and re-enable in CI
     # For now, to run this test, you will need to download the weights from HF
     # and run this script to convert them:
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index b405d96af6..6c63937051 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -923,6 +923,7 @@ def quantize(
 def linear_forward_8da4w(
     x,
     weight_int8,
+    bias,
     scales,
     zeros,
     out_features,
@@ -956,7 +957,7 @@ def linear_forward_8da4w(
 
     # x = x.to(torch.float16)
     # w_dq = w_dq.to(torch.float16)
-    c = torch.nn.functional.linear(x, w_dq)
+    c = torch.nn.functional.linear(x, w_dq, bias)
 
     # new_shape = origin_x_size[:-1] + (out_features,)
     # c = c.reshape(new_shape)
@@ -970,6 +971,7 @@ class Int8DynActInt4WeightLinear(torch.nn.Module):
     in_features: int
     out_features: int
     weight: torch.Tensor
+    bias: torch.Tensor
 
     """
     This module implements a dynamic quantized linear layer with int4 weight.
@@ -1003,7 +1005,6 @@ def __init__(
         # )
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         # TODO: align groupsize naming
         self.groupsize = groupsize
         # Precision of the activation which also indicates
@@ -1034,6 +1035,11 @@ def __init__(
             ),
         )
 
+        if bias:
+            self.register_buffer("bias", torch.zeros(out_features, dtype=precision))
+        else:
+            self.bias = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(self.precision)
         # padding is removed for perf
@@ -1041,6 +1047,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return linear_forward_8da4w(
             input,
             self.weight,
+            self.bias,
             self.scales,
             self.zeros,
             self.out_features,
@@ -1062,18 +1069,15 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        # TODO: support linear bias
-        return (
-            isinstance(child, nn.Linear)
-            and child.bias is None
-            and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        return isinstance(child, nn.Linear) and (
+            _check_linear_int4_k(child.in_features, groupsize) or padding_allowed
         )
 
     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(
             child.in_features,
             child.out_features,
-            bias=False,
+            bias=child.bias is not None,
             device=child.weight.device,
             groupsize=groupsize,
             precision=precision,
@@ -1084,6 +1088,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         # copy the weights, and doing so will result in an error
         if copy_weights and child.weight.device != torch.device("meta"):
             new_linear.weight = child.weight
+            new_linear.bias = child.bias
         return new_linear
 
     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
@@ -1130,7 +1135,7 @@ def _create_quantized_state_dict(
     ) -> Dict[str, torch.Tensor]:
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Linear) and mod.bias is None:
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 # assert out_features % 8 == 0, "require out_features % 8 == 0"
@@ -1172,7 +1177,6 @@ def _create_quantized_state_dict(
                 cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device)
                 cur_state_dict[f"{fqn}.scales"] = scales.to(self.device)
                 cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device)
-                # TODO: support bias?
 
         return cur_state_dict
 
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index 716634fe9d..12584fade8 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -208,7 +208,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear = Int8DynActInt4WeightLinear(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    child.bias is not None,
                     groupsize=config.group_size,
                     precision=child.weight.dtype,
                     scales_precision=config.scale_precision,
@@ -237,6 +237,8 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear.weight = q_weight
                 quantized_linear.scales = s
                 quantized_linear.zeros = zp
+                if child.bias is not None:
+                    quantized_linear.bias = child.bias
             else:
                 self._convert_qat_linear_8da4w(child)
 
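
The QAT flow touched by the `qat/linear.py` and `_replace_linear_8da4w` changes can be sketched end to end as follows. This is an illustrative snippet, not part of the patch: it assumes a torchao build that includes these changes, that `Int8DynActInt4WeightQATQuantizer` is importable from `torchao.quantization.qat` (the package modified above), and `TinyModel` is a hypothetical stand-in for any model whose linear layers carry a bias.

import torch

from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer


class TinyModel(torch.nn.Module):
    # Hypothetical toy model: two linear layers that both carry a bias.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=True)
        self.linear2 = torch.nn.Linear(32, 64, bias=True)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TinyModel()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare() swaps nn.Linear for Int8DynActInt4WeightQATLinear; with this patch,
# layers with a bias are no longer filtered out by _replace_linear_8da4w.
model = quantizer.prepare(model)

# ... fine-tune with fake quantization here ...

# convert() produces Int8DynActInt4WeightLinear modules; the original bias is
# carried over and applied by linear_forward_8da4w.
model = quantizer.convert(model)
assert isinstance(model.linear1, Int8DynActInt4WeightLinear)
assert model.linear1.bias is not None
model(torch.randn(1, 64))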