Add bias support for Int8DynActInt4WeightLinear #1845

Merged (1 commit, Mar 10, 2025)
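
In short, this change drops the bias=False restriction from Int8DynActInt4WeightLinear: linear_forward_8da4w gains a bias argument that is forwarded to torch.nn.functional.linear, the module registers a bias buffer when requested, and both the eager 8da4w path (_replace_linear_8da4w, _create_quantized_state_dict) and the QAT convert path (_convert_qat_linear_8da4w) stop filtering out bias-carrying linears and propagate the original bias. A minimal sketch of the user-facing effect, distilled from the new test_8da4w_quantizer_linear_bias test below; the toy module here is illustrative, not part of the PR:

    import torch
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    class TinyModel(torch.nn.Module):
        # Illustrative stand-in for the test's ToyLinearModel(bias=True).
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(64, 32, bias=True)
            self.linear2 = torch.nn.Linear(32, 64, bias=True)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    m = TinyModel().eval()
    m = Int8DynActInt4WeightQuantizer(groupsize=32).quantize(m)
    # Before this PR, filter_fn required child.bias is None, so bias-carrying
    # layers were left as float nn.Linear; now both become
    # Int8DynActInt4WeightLinear and the bias is applied in the forward.
    m(torch.randn(1, 64))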

test/quantization/test_qat.py: 5 additions, 16 deletions

@@ -1043,22 +1043,10 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     )
     def test_replace_linear_8da4w(self):
         module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=True)]
-        )
-        _replace_linear_8da4w(
-            module,
-            256,
-            False,
-            torch.float32,
-            torch.float32,
-            Int8DynActInt4WeightQATLinear,
-            copy_weights=True,
-        )
-        assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(
-            module[0], torch.nn.Linear
-        )
-        module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=False)]
+            [
+                torch.nn.Linear(in_features=256, out_features=50, bias=True),
+                torch.nn.Linear(in_features=256, out_features=50, bias=False),
+            ]
         )
         _replace_linear_8da4w(
             module,
@@ -1070,6 +1058,7 @@ def test_replace_linear_8da4w(self):
             copy_weights=True,
         )
         assert isinstance(module[0], Int8DynActInt4WeightQATLinear)
+        assert isinstance(module[1], Int8DynActInt4WeightQATLinear)

     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
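
For reference, the positional arguments passed to _replace_linear_8da4w here (256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear) map to groupsize, padding_allowed, precision, scales_precision, and the replacement linear class, assuming the signature of _replace_linear_8da4w in GPTQ.py below. The updated test keeps that call and simply widens the ModuleList so that both a bias and a no-bias Linear are expected to be swapped.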

test/quantization/test_quant_api.py: 18 additions, 3 deletions

@@ -115,10 +115,10 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:


 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    def __init__(self, m=64, n=32, k=64, bias=False):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

     def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
         return (
@@ -272,6 +272,21 @@ def test_8da4w_quantizer(self):
         assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
         m(*example_inputs)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
+    )
+    def test_8da4w_quantizer_linear_bias(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
+        m = ToyLinearModel(bias=True).eval()
+        example_inputs = m.example_inputs()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+        m(*example_inputs)
+
     # TODO: save model weights as artifacts and re-enable in CI
     # For now, to run this test, you will need to download the weights from HF
     # and run this script to convert them:

torchao/quantization/GPTQ.py: 14 additions, 10 deletions

@@ -923,6 +923,7 @@ def quantize(
 def linear_forward_8da4w(
     x,
     weight_int8,
+    bias,
     scales,
     zeros,
     out_features,
@@ -956,7 +957,7 @@ def linear_forward_8da4w(

     # x = x.to(torch.float16)
     # w_dq = w_dq.to(torch.float16)
-    c = torch.nn.functional.linear(x, w_dq)
+    c = torch.nn.functional.linear(x, w_dq, bias)

     # new_shape = origin_x_size[:-1] + (out_features,)
     # c = c.reshape(new_shape)
@@ -970,6 +971,7 @@ class Int8DynActInt4WeightLinear(torch.nn.Module):
     in_features: int
     out_features: int
     weight: torch.Tensor
+    bias: torch.Tensor

     """
     This module implements a dynamic quantized linear layer with int4 weight.
@@ -1003,7 +1005,6 @@ def __init__(
         # )
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         # TODO: align groupsize naming
         self.groupsize = groupsize
         # Precision of the activation which also indicates
@@ -1034,13 +1035,19 @@ def __init__(
             ),
         )

+        if bias:
+            self.register_buffer("bias", torch.zeros(out_features, dtype=precision))
+        else:
+            self.bias = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(self.precision)
         # padding is removed for perf
         # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_8da4w(
             input,
             self.weight,
+            self.bias,
             self.scales,
             self.zeros,
             self.out_features,
@@ -1062,18 +1069,15 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        # TODO: support linear bias
-        return (
-            isinstance(child, nn.Linear)
-            and child.bias is None
-            and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        return isinstance(child, nn.Linear) and (
+            _check_linear_int4_k(child.in_features, groupsize) or padding_allowed
         )

     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(
             child.in_features,
             child.out_features,
-            bias=False,
+            bias=child.bias is not None,
             device=child.weight.device,
             groupsize=groupsize,
             precision=precision,
@@ -1084,6 +1088,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         # copy the weights, and doing so will result in an error
         if copy_weights and child.weight.device != torch.device("meta"):
             new_linear.weight = child.weight
+            new_linear.bias = child.bias
         return new_linear

     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
@@ -1130,7 +1135,7 @@ def _create_quantized_state_dict(
     ) -> Dict[str, torch.Tensor]:
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Linear) and mod.bias is None:
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 # assert out_features % 8 == 0, "require out_features % 8 == 0"
@@ -1172,7 +1177,6 @@ def _create_quantized_state_dict(
                 cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device)
                 cur_state_dict[f"{fqn}.scales"] = scales.to(self.device)
                 cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device)
-                # TODO: support bias?

         return cur_state_dict
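
Conceptually, the forward path of Int8DynActInt4WeightLinear after this change is per-group weight dequantization followed by a standard biased linear. A rough sketch of that math; dynamic int8 activation quantization is omitted, and the helper below is illustrative rather than the upstream implementation:

    import torch
    import torch.nn.functional as F

    def linear_8da4w_sketch(x, weight_int8, scales, zeros, bias, groupsize):
        # weight_int8: (out_features, in_features) int4 values stored as int8
        # scales, zeros: (out_features, in_features // groupsize) per-group parameters
        out_features, in_features = weight_int8.shape
        w = weight_int8.to(x.dtype).reshape(out_features, -1, groupsize)
        w = (w - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)  # dequantize each group
        w_dq = w.reshape(out_features, in_features)
        # The PR's only change at this point: bias is forwarded straight into F.linear.
        return F.linear(x, w_dq, bias)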

torchao/quantization/qat/linear.py: 3 additions, 1 deletion

@@ -208,7 +208,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear = Int8DynActInt4WeightLinear(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    child.bias is not None,
                     groupsize=config.group_size,
                     precision=child.weight.dtype,
                     scales_precision=config.scale_precision,
@@ -237,6 +237,8 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear.weight = q_weight
                 quantized_linear.scales = s
                 quantized_linear.zeros = zp
+                if child.bias is not None:
+                    quantized_linear.bias = child.bias
             else:
                 self._convert_qat_linear_8da4w(child)
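
On the QAT side, convert now constructs Int8DynActInt4WeightLinear with bias enabled whenever the trained child module has one, and copies the float bias over unchanged; only the weight is quantized. A minimal round-trip sketch, assuming the torchao QAT quantizer API (Int8DynActInt4WeightQATQuantizer with prepare/convert); the model and flow here are illustrative and not part of this PR:

    import torch
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

    model = torch.nn.Sequential(torch.nn.Linear(256, 50, bias=True))
    qat = Int8DynActInt4WeightQATQuantizer(groupsize=32)
    model = qat.prepare(model)   # swaps in Int8DynActInt4WeightQATLinear (fake quant)
    # ... fine-tune the model here ...
    model = qat.convert(model)   # lowers to Int8DynActInt4WeightLinear
    assert model[0].bias is not None  # bias survives conversion after this PR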