diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index 4821b613405..409cbb4b755 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -39,11 +39,12 @@ def __init__(
             from torchao.utils import find_multiple
 
             self.origin_in_features = in_features
-            in_features = find_multiple(in_features, (1024,))
+            # pyre-ignore[6]: Incompatible parameter type
+            in_features = find_multiple(in_features, 1024)
 
+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -80,6 +81,11 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
@@ -87,13 +93,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
         # The forward method is replaced. In the original implementation, the forward
         # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r
 
 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
             new_linear = linear_class(
                 child.in_features,
                 child.out_features,
-                bias=False,
+                bias=child.bias is not None,
                 device=child.weight.device,
                 groupsize=groupsize,
                 inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
             if copy_weights and child.weight.device != torch.device("meta"):
                 # pyre-fixme[16]: `Module` has no attribute `weight`.
                 new_linear.weight = child.weight
+                if child.bias is not None:
+                    # pyre-fixme[16]: `Module` has no attribute `bias`.
+                    new_linear.bias = child.bias
             setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                         logging.warn(
                             f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                         )
-                        padded_in_features = find_multiple(in_features, (1024,))
+                        # pyre-ignore[6]: Incompatible parameter type
+                        padded_in_features = find_multiple(in_features, 1024)
                         weight = F.pad(
                             weight, pad=(0, padded_in_features - in_features)
                         )
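
Note on the behavioral change: WeightOnlyInt4Linear now accepts bias=True, registers a float32 "bias" buffer in __init__, and adds it to the result of the Vulkan custom op in forward. Below is a minimal, self-contained sketch of that control flow. It uses F.linear as a stand-in for torch.ops.et_vk.linear_weight_int4 (which is only available once the ExecuTorch Vulkan ops are registered), and the helper name forward_with_optional_bias is illustrative, not part of the patch.

import torch
import torch.nn.functional as F
from typing import Optional

def forward_with_optional_bias(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # Stand-in for torch.ops.et_vk.linear_weight_int4(...); the patch binds the
    # op's result to a local instead of returning it directly.
    r = F.linear(x, weight)
    # New in the patch: when the module was constructed with bias=True, the
    # registered bias buffer is added to the op's output before returning.
    if bias is not None:
        return r + bias
    return r

# Sanity check: the bias path matches an ordinary biased linear.
x = torch.randn(2, 8)
w = torch.randn(4, 8)
b = torch.randn(4)
assert torch.allclose(forward_with_optional_bias(x, w, b), F.linear(x, w, b))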