44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from enum import Enum
7+ from enum import Enum , auto
88from typing import List , Optional , Tuple , Dict
99import torch
1010
1111from torchao .kernel .intmm import int_scaled_matmul
1212from torchao .kernel .intmm import safe_int_mm
13- from torchao .utils import TORCH_VERSION_AFTER_2_3
13+ from torchao .utils import (
14+ TORCH_VERSION_AFTER_2_3 ,
15+ TORCH_VERSION_AFTER_2_5 ,
16+ )
1417
1518
1619__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
3437 based on this mapping
3538 e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
3639 """
37- SYMMETRIC = 0
38- ASYMMETRIC = 1
40+ SYMMETRIC = auto ()
41+ ASYMMETRIC = auto ()
3942
4043class ZeroPointDomain (Enum ):
4144 """Enum that indicate whether zero_point is in integer domain or floating point domain
4245
4346 integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
4447 float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
4548 """
46- INT = 0
47- FLOAT = 1
49+ INT = auto ()
50+ FLOAT = auto ()
4851
4952"""
5053Map from dtype to the bound value of integers
@@ -69,6 +72,41 @@ class ZeroPointDomain(Enum):
6972 })
7073
7174
75+ # def register_custom_op(name: str):
76+ # from torch._inductor.decomposition import register_decomposition
77+
78+ # def decorator(fn):
79+ # if TORCH_VERSION_AFTER_2_5:
80+ # opdef = torch.library.custom_op(name, mutates_args=())(fn)
81+ # opdef.register_fake(fn)
82+ # register_decomposition([opdef._opoverload])(fn)
83+ # return opdef
84+ # else:
85+ # return fn
86+
87+ # return decorator
88+
89+ quant_lib = torch .library .Library ("quant" , "FRAGMENT" )
90+
91+ def register_custom_op (lib , schema : str ):
92+ from torch ._inductor .decomposition import register_decomposition
93+
94+ def decorator (fn ):
95+ if TORCH_VERSION_AFTER_2_5 :
96+ # TODO: change order
97+ lib_namespace = lib .ns
98+ op_name = schema .split ("(" )[0 ]
99+ lib .define (schema )
100+ lib .impl (op_name , fn , "CompositeImplicitAutograd" )
101+ op = getattr (getattr (torch .ops , lib_namespace ), op_name )
102+ register_decomposition ([op ])(fn )
103+ return fn
104+ else :
105+ return fn
106+
107+ return decorator
108+
109+
72110# TODO: decide on if we want to allow custom quant_min/quant_max here
73111def _get_and_check_qmin_qmax (dtype , quant_min , quant_max ):
74112 """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +178,7 @@ def quantize_affine(
140178 quant_min : Optional [int ] = None ,
141179 quant_max : Optional [int ] = None ,
142180 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
143- ):
181+ ) -> torch . Tensor :
144182 """
145183 Args:
146184 input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +212,31 @@ def quantize_affine(
174212 Output:
175213 quantized tensor with requested dtype
176214 """
215+ return _quantize_affine (
216+ input ,
217+ block_size ,
218+ scale ,
219+ zero_point ,
220+ output_dtype ,
221+ quant_min ,
222+ quant_max ,
223+ zero_point_domain .name ,
224+ )
225+
226+
227+ @register_custom_op (quant_lib , 'quantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor? zero_point, ScalarType output_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT") -> Tensor' )
228+ def _quantize_affine (
229+ input : torch .Tensor ,
230+ block_size : List [int ],
231+ scale : torch .Tensor ,
232+ zero_point : Optional [torch .Tensor ],
233+ output_dtype : torch .dtype ,
234+ quant_min : Optional [int ] = None ,
235+ quant_max : Optional [int ] = None ,
236+ zero_point_domain : str = "INT" ,
237+ ) -> torch .Tensor :
238+ """op definition that has compatible signatures with custom op library
239+ """
177240 # TODO: validations
178241 # TODO: validate scale/zero_point dimensions are compatible with block_size
179242 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], f"Unsupported input dtype: { input .dtype } "
@@ -188,12 +251,12 @@ def quantize_affine(
188251 if zero_point is not None :
189252 zero_point = zero_point .view (shape_after_reduction )
190253
191- if zero_point_domain == ZeroPointDomain .INT :
254+ if zero_point_domain == ZeroPointDomain .INT . name :
192255 quant = torch .clamp (
193256 torch .round (input * (1.0 / scale )) + zero_point , quant_min , quant_max
194257 ).to (output_dtype )
195258 else :
196- assert zero_point_domain == ZeroPointDomain .FLOAT
259+ assert zero_point_domain == ZeroPointDomain .FLOAT . name
197260 mid_point = (quant_max + quant_min + 1 ) / 2
198261 min_val = zero_point - scale * mid_point
199262 quant = (
@@ -216,7 +279,7 @@ def dequantize_affine(
216279 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
217280 * ,
218281 output_dtype : torch .dtype = torch .float32 ,
219- ):
282+ ) -> torch . Tensor :
220283 """
221284 Args:
222285 input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +301,34 @@ def dequantize_affine(
238301 Output:
239302 dequantized Tensor, with requested dtype or fp32
240303 """
304+ return _dequantize_affine (
305+ input ,
306+ block_size ,
307+ scale ,
308+ zero_point ,
309+ input_dtype ,
310+ quant_min ,
311+ quant_max ,
312+ zero_point_domain .name ,
313+ output_dtype = output_dtype ,
314+ )
315+
316+
317+ @register_custom_op (quant_lib , 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor' )
318+ def _dequantize_affine (
319+ input : torch .Tensor ,
320+ block_size : List [int ],
321+ scale : torch .Tensor ,
322+ zero_point : Optional [torch .Tensor ],
323+ input_dtype : torch .dtype ,
324+ quant_min : Optional [int ] = None ,
325+ quant_max : Optional [int ] = None ,
326+ zero_point_domain : str = "INT" ,
327+ * ,
328+ output_dtype : torch .dtype = torch .float32 ,
329+ ) -> torch .Tensor :
330+ """op definition that has compatible signatures with custom op library
331+ """
241332
242333 # TODO: validations
243334 # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +346,16 @@ def dequantize_affine(
255346 if zero_point is not None :
256347 zero_point = zero_point .view (shape_after_reduction )
257348
258- if zero_point_domain == ZeroPointDomain .INT :
349+ if zero_point_domain == ZeroPointDomain .INT . name :
259350 # Force a copy to avoid input modification due
260351 # to upcoming in-place operations.
261352 dequant = input .to (torch .int32 , copy = True )
262353 if zero_point is not None :
263- dequant -= zero_point .to (torch .int32 )
354+ dequant = dequant - zero_point .to (torch .int32 )
264355 dequant = dequant .to (output_dtype )
265- dequant *= scale
356+ dequant = dequant * scale
266357 else :
267- assert zero_point_domain == ZeroPointDomain .FLOAT , f"Unexpected zero point domain: { zero_point_domain } "
358+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , f"Unexpected zero point domain: { zero_point_domain } "
268359 mid_point = (quant_max + quant_min + 1 ) / 2
269360 # This should allocate new memory and avoid input modification
270361 dequant = input - mid_point
@@ -320,8 +411,38 @@ def choose_qparams_affine(
320411 Output:
321412 Tuple of scales and zero_points Tensor with requested dtype
322413 """
414+ return _choose_qparams_affine (
415+ input ,
416+ mapping_type .name ,
417+ block_size ,
418+ target_dtype ,
419+ quant_min ,
420+ quant_max ,
421+ eps ,
422+ scale_dtype ,
423+ zero_point_dtype ,
424+ preserve_zero ,
425+ zero_point_domain .name
426+ )
427+
428+ @register_custom_op (quant_lib , 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)' )
429+ def _choose_qparams_affine (
430+ input : torch .Tensor ,
431+ mapping_type : str ,
432+ block_size : List [int ],
433+ target_dtype : torch .dtype ,
434+ quant_min : Optional [int ] = None ,
435+ quant_max : Optional [int ] = None ,
436+ eps : Optional [float ] = None ,
437+ scale_dtype : Optional [torch .dtype ] = None ,
438+ zero_point_dtype : Optional [torch .dtype ] = None ,
439+ preserve_zero : bool = True ,
440+ zero_point_domain : str = "INT" ,
441+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
442+ """op definition that has compatible signatures with custom op library
443+ """
323444 quant_min , quant_max = _get_and_check_qmin_qmax (target_dtype , quant_min , quant_max )
324- assert mapping_type in [MappingType .SYMMETRIC , MappingType .ASYMMETRIC ], f"Unsupported mapping type: { mapping_type } "
445+ assert mapping_type in [MappingType .SYMMETRIC . name , MappingType .ASYMMETRIC . name ], f"Unsupported mapping type: { mapping_type } "
325446
326447 if scale_dtype is None :
327448 scale_dtype = input .dtype
@@ -342,21 +463,22 @@ def choose_qparams_affine(
342463 min_val_neg = min_val
343464 max_val_pos = max_val
344465
345- if mapping_type == MappingType .SYMMETRIC :
466+ if mapping_type == MappingType .SYMMETRIC . name :
346467 max_val_pos = torch .max (- min_val_neg , max_val_pos )
347468 scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
348469 if not preserve_zero :
349470 raise ValueError ("preserve_zero == False is not supported for symmetric quantization" )
350- if zero_point_domain != ZeroPointDomain .INT :
471+ if zero_point_domain != ZeroPointDomain .INT . name :
351472 raise ValueError ("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" )
352473 zero_point = torch .full_like (scale , int ((quant_max + quant_min + 1 ) / 2 ))
353474 else :
475+ assert mapping_type == MappingType .ASYMMETRIC .name
354476 scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
355477 if preserve_zero :
356478 zero_point = quant_min - torch .round (min_val_neg / scale )
357479 zero_point = torch .clamp (zero_point , quant_min , quant_max )
358480 else :
359- assert zero_point_domain == ZeroPointDomain .FLOAT , "if not preserve_zero, zero_point must be in FLOAT domain"
481+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , "if not preserve_zero, zero_point must be in FLOAT domain"
360482 mid_point = (quant_max + quant_min + 1 ) / 2
361483 zero_point = min_val_neg + scale * mid_point
362484
0 commit comments