
Commit 3015548 ("up")
1 parent be447a9

File tree: 3 files changed, +9 -8 lines


torchao/quantization/GPTQ.py: 4 additions, 1 deletion

@@ -935,8 +935,11 @@ def linear_forward_8da4w(
 ):
     # to match torchao.quantization.quant_api._int8_asymm_per_token_quant
     x = per_token_dynamic_quant(
-        x, scale_dtype=torch.float64, zero_point_dtype=torch.int64
+        x, scale_dtype=torch.float32, zero_point_dtype=torch.float32
     )
+    # x = per_token_dynamic_quant(
+    #     x, scale_dtype=torch.float64, zero_point_dtype=torch.int64
+    # )
     # TODO: verify and remove following reshape code
     # origin_x_size = x.size()
     # x = x.reshape(-1, origin_x_size[-1])
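
For context, per-token dynamic quantization computes one (scale, zero_point) pair per token, and the two dtype arguments only control how those qparams are stored; the commit switches them from float64/int64 to float32 for both. Below is a minimal sketch of the computation in plain PyTorch, assuming an asymmetric int8 mapping; per_token_quant_sketch is a hypothetical illustration, not torchao's per_token_dynamic_quant.

import torch

def per_token_quant_sketch(x, scale_dtype=torch.float32,
                           zero_point_dtype=torch.float32,
                           qmin=-128, qmax=127):
    # one (scale, zero_point) pair per token, i.e. per slice of the last dim
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    eps = torch.finfo(torch.float32).eps
    scale = ((xmax - xmin) / (qmax - qmin)).clamp(min=eps)
    zero_point = qmin - xmin / scale  # a float zero point needs no rounding
    xq = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax).to(torch.int8)
    # the commit only changes these two casts: float64/int64 -> float32/float32
    return xq, scale.to(scale_dtype), zero_point.to(zero_point_dtype)

Note that a float32 zero point needs no rounding step, which is what makes zero_point_dtype=torch.float32 above well-defined.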

torchao/quantization/quant_api.py: 4 additions, 6 deletions

@@ -568,14 +568,16 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     """This is defined here instead of local function to support serialization"""
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
+    scale_dtype = torch.float32
+    zero_point_dtype = torch.int32
     if TORCH_VERSION_AT_LEAST_2_6:
         return to_affine_quantized_intx(
             x,
             mapping_type,
             _get_per_token_block_size(x),
             target_dtype,
-            scale_dtype=torch.float64,
-            zero_point_dtype=torch.int64,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
         )
     else:
         return to_affine_quantized_intx(
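
Hoisting the dtypes into locals keeps both branches of the version check consistent. The per-token granularity itself comes from the block size passed to to_affine_quantized_intx; here is a hedged sketch of what _get_per_token_block_size is assumed to return (the helper name is real, the body is a guess): block size 1 on every leading dimension, one block spanning the full feature dimension.

import torch

def per_token_block_size_sketch(x: torch.Tensor):
    # assumed behavior: one qparam pair per token, block covering the
    # whole last dimension
    return [1] * (x.dim() - 1) + [x.shape[-1]]

assert per_token_block_size_sketch(torch.empty(4, 16, 64)) == [1, 1, 64]
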
@@ -649,7 +651,6 @@ def _int8_dynamic_activation_int4_weight_transform(
     # weight settings
     block_size = (1, group_size)
     target_dtype = torch.int8
-    eps = torch.finfo(torch.float32).eps
     quant_min = -8
     quant_max = 7
 

@@ -680,7 +681,6 @@ def _int8_dynamic_activation_int4_weight_transform(
         target_dtype,
         quant_min,
         quant_max,
-        eps,
         _layout=layout,
     )
     weight = to_linear_activation_quantized(weight, input_quant_func)

@@ -793,7 +793,6 @@ def _int8_dynamic_activation_intx_weight_transform(
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
-        eps=torch.finfo(torch.float32).eps,
         scale_dtype=weight_scale_dtype,
         zero_point_dtype=torch.int8,
         preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC),

@@ -1830,7 +1829,6 @@ def _intx_weight_only_transform(
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
-        eps=torch.finfo(torch.float32).eps,
         scale_dtype=scale_dtype,
         zero_point_dtype=torch.int8,
         preserve_zero=(mapping_type == MappingType.SYMMETRIC),
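
The last four hunks all drop an explicit eps argument, falling back to the callee's default instead. The role of eps in qparam selection is to keep the chosen scale strictly positive; a minimal sketch of the mechanism, assuming the usual asymmetric scale formula (choose_scale_sketch is illustrative, not torchao's choose_qparams code):

import torch

def choose_scale_sketch(w: torch.Tensor, quant_min: int, quant_max: int,
                        eps: float = torch.finfo(torch.float32).eps):
    # a constant tensor gives max == min, hence scale == 0; clamping by eps
    # avoids the divide-by-zero that dequantization (x / scale) would hit
    scale = (w.max() - w.min()) / (quant_max - quant_min)
    return torch.clamp(scale, min=eps)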

torchao/quantization/utils.py: 1 addition, 1 deletion

@@ -556,7 +556,7 @@ def get_group_qparams_symmetric(
         quant_max=quant_max,
         eps=eps,
         scale_dtype=precision,
-        zero_point_dtype=torch.int32,
+        zero_point_dtype=precision,
     )
     return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)
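
Under a symmetric mapping the zero point is numerically zero, so this change affects only its storage dtype, which now follows precision (e.g. torch.float32) instead of being hard-coded to torch.int32, consistent with the float zero points introduced above. A rough sketch of symmetric per-group qparams, assuming the scale maps the per-group max |w| onto half the quant range (not the actual get_group_qparams_symmetric body):

import torch

def group_qparams_symmetric_sketch(w, group_size, precision=torch.float32,
                                   quant_min=-8, quant_max=7,
                                   eps=torch.finfo(torch.float32).eps):
    # view as (rows, n_groups, group_size) and take the per-group max |w|
    wg = w.reshape(w.shape[0], -1, group_size)
    max_abs = wg.abs().amax(dim=-1)
    # symmetric scale: map max |w| onto half the quant range
    scale = (max_abs / ((quant_max - quant_min) / 2)).clamp(min=eps).to(precision)
    # the zero point is identically zero; the commit makes its dtype follow
    # `precision` instead of hard-coding torch.int32
    zero_point = torch.zeros_like(scale)
    return scale, zero_point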
