[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases #8224

Merged (4 commits, Feb 6, 2025)
24 changes: 18 additions & 6 deletions backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -39,11 +39,12 @@ def __init__(
         from torchao.utils import find_multiple
 
         self.origin_in_features = in_features
-        in_features = find_multiple(in_features, (1024,))
+        # pyre-ignore[6]: Incompatible parameter type
+        in_features = find_multiple(in_features, 1024)
 
+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
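For context on the padding logic touched above: torchao's find_multiple(n, k) rounds n up to the next multiple of k, so in_features is padded up to a multiple of 1024 before packing. A minimal sketch of that arithmetic; the local find_multiple below is an illustrative stand-in for torchao.utils.find_multiple, not the library code itself:

    def find_multiple(n: int, k: int) -> int:
        # Smallest multiple of k that is >= n.
        return n if n % k == 0 else n + k - (n % k)

    assert find_multiple(4096, 1024) == 4096  # already aligned, no padding needed
    assert find_multiple(4000, 1024) == 4096  # in_features would be padded by 96 columns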
@@ -80,20 +81,28 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         # The forward method is replaced. In the original implementation, the forward
         # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r
 
 
 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
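The new forward path stores the custom op's output in r and adds the bias afterwards. A small sketch of why that is sufficient, using plain float tensors rather than the Vulkan op: bias addition in F.linear is just an elementwise add on the matmul output, so it can be applied after the quantized matmul without touching the packed weights.

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)
    w = torch.randn(4, 8)
    b = torch.randn(4)
    # F.linear(x, w, b) equals F.linear(x, w) + b, up to float rounding.
    assert torch.allclose(F.linear(x, w, b), F.linear(x, w) + b, atol=1e-6)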
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
             new_linear = linear_class(
                 child.in_features,
                 child.out_features,
-                bias=False,
+                bias=child.bias is not None,
                 device=child.weight.device,
                 groupsize=groupsize,
                 inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@
             if copy_weights and child.weight.device != torch.device("meta"):
                 # pyre-fixme[16]: `Module` has no attribute `weight`.
                 new_linear.weight = child.weight
+            if child.bias is not None:
+                # pyre-fixme[16]: `Module` has no attribute `bias`.
+                new_linear.bias = child.bias
             setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
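The replacement helper now mirrors each source module's bias setting instead of hard-coding bias=False, and copies the bias tensor when copy_weights is set. A toy sketch of the per-child decision; the model and names below are illustrative, not taken from the PR:

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(8, 4, bias=True), nn.Linear(4, 2, bias=False))
    for name, child in toy.named_children():
        if isinstance(child, nn.Linear):
            # Matches the new call: bias=child.bias is not None
            print(f"{name}: replacement linear gets bias={child.bias is not None}")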
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                     logging.warn(
                         f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                     )
-                    padded_in_features = find_multiple(in_features, (1024,))
+                    # pyre-ignore[6]: Incompatible parameter type
+                    padded_in_features = find_multiple(in_features, 1024)
                     weight = F.pad(
                         weight, pad=(0, padded_in_features - in_features)
                     )
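With the assert removed from _create_quantized_state_dict, linears that carry a bias can now flow through the whole source transformation. A rough end-to-end usage sketch; the import path, the class name VkInt4WeightOnlyQuantizer, its groupsize argument, and the quantize() method are assumptions based on the file being patched, not details shown in this diff:

    import torch.nn as nn

    # Assumed import path for the quantizer class defined in the patched file.
    from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
        VkInt4WeightOnlyQuantizer,
    )

    model = nn.Sequential(nn.Linear(2048, 2048, bias=True))  # bias=True is now accepted
    model = VkInt4WeightOnlyQuantizer(groupsize=128).quantize(model)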