@@ -568,14 +568,16 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     """This is defined here instead of local function to support serialization"""
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
+    scale_dtype = torch.float32
+    zero_point_dtype = torch.int32
     if TORCH_VERSION_AT_LEAST_2_6:
         return to_affine_quantized_intx(
             x,
             mapping_type,
             _get_per_token_block_size(x),
             target_dtype,
-            scale_dtype=torch.float64,
-            zero_point_dtype=torch.int64,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
         )
     else:
         return to_affine_quantized_intx(
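For reference, here is a minimal standalone sketch of what asymmetric per-token int8 quantization with the newly pinned dtypes (float32 scales, int32 zero points) computes. The helper name and the eps clamp are illustrative assumptions, not torchao's actual implementation:

```python
import torch


def int8_asymm_per_token_quant_sketch(x: torch.Tensor):
    """Hypothetical standalone sketch; NOT torchao's implementation."""
    qmin, qmax = -128, 127
    # One (min, max) pair per token: reduce over the last dimension.
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)
    # Include zero in the range so exact zeros stay exactly representable.
    min_val = torch.minimum(min_val, torch.zeros_like(min_val))
    max_val = torch.maximum(max_val, torch.zeros_like(max_val))
    # float32 scale and int32 zero point, matching the dtypes this PR pins.
    scale = ((max_val - min_val) / (qmax - qmin)).to(torch.float32)
    scale = scale.clamp(min=torch.finfo(torch.float32).eps)  # avoid zero scale
    zero_point = (qmin - torch.round(min_val / scale)).to(torch.int32)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point
```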
@@ -649,7 +651,6 @@ def _int8_dynamic_activation_int4_weight_transform(
     # weight settings
     block_size = (1, group_size)
     target_dtype = torch.int8
-    eps = torch.finfo(torch.float32).eps
     quant_min = -8
     quant_max = 7

@@ -680,7 +681,6 @@ def _int8_dynamic_activation_int4_weight_transform(
         target_dtype,
         quant_min,
         quant_max,
-        eps,
         _layout=layout,
     )
     weight = to_linear_activation_quantized(weight, input_quant_func)
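The hunks above stop passing an explicit `eps=torch.finfo(torch.float32).eps`, deferring to the callee's default instead. Below is an illustrative sketch (not torchao's `choose_qparams` code) of the role eps plays when picking symmetric scales; the fallback to a dtype-derived default eps is an assumption about the callee's behavior:

```python
import torch


def symmetric_scale_with_eps_sketch(w: torch.Tensor, quant_min=-8, quant_max=7):
    # Why an eps floor exists: an all-zero weight block has max_abs == 0,
    # which would make the scale 0 and the quantization step (x / scale)
    # divide by zero.
    max_abs = w.abs().amax()
    scale = max_abs / max(abs(quant_min), quant_max)
    # With the explicit argument removed, the assumption is that a default
    # eps derived from the working dtype takes over, as sketched here.
    eps = torch.finfo(scale.dtype).eps
    return torch.clamp(scale, min=eps)
```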
@@ -793,7 +793,6 @@ def _int8_dynamic_activation_intx_weight_transform(
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
-        eps=torch.finfo(torch.float32).eps,
         scale_dtype=weight_scale_dtype,
         zero_point_dtype=torch.int8,
         preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC),
@@ -1830,7 +1829,6 @@ def _intx_weight_only_transform(
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
-        eps=torch.finfo(torch.float32).eps,
         scale_dtype=scale_dtype,
         zero_point_dtype=torch.int8,
         preserve_zero=(mapping_type == MappingType.SYMMETRIC),
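For context on where these transforms sit, a hedged usage example follows. The `quantize_` entry point and the `Int8DynamicActivationInt4WeightConfig` config name are taken from recent torchao releases and are assumptions relative to this diff; verify them against your installed version:

```python
import torch

# Assumed public API names; check your torchao version before relying on them.
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256))
# Routes through _int8_dynamic_activation_int4_weight_transform: dynamic
# per-token int8 activation quantization plus group-wise int4 weights.
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```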