# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

- from enum import Enum
+ from enum import Enum, auto
from typing import List, Optional, Tuple, Dict
import torch

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
- from torchao.utils import TORCH_VERSION_AFTER_2_3
+ from torchao.utils import (
+     TORCH_VERSION_AFTER_2_3,
+     TORCH_VERSION_AFTER_2_5,
+ )


__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
    based on this mapping
    e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
    """
-     SYMMETRIC = 0
-     ASYMMETRIC = 1
+     SYMMETRIC = auto()
+     ASYMMETRIC = auto()

class ZeroPointDomain(Enum):
    """Enum that indicates whether zero_point is in the integer domain or the floating point domain

    integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
    float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
    """
-     INT = 0
-     FLOAT = 1
+     INT = auto()
+     FLOAT = auto()

"""
Map from dtype to the bound value of integers
@@ -69,6 +72,53 @@ class ZeroPointDomain(Enum):
})


+ quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+ def register_custom_op(lib):
+     """This decorator is used to preserve some high level operators for torch.export.export
+     while still allowing them to be decomposed for the inductor path
+
+     requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+     NOTE: This should be applied at the top, after all other decorators have been applied
+     NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+     e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would
+     make sense for downstream systems (like executorch) to accept as well
+
+     Example:
+         lib = torch.library.Library("my_namespace", "FRAGMENT")
+
+         @register_custom_op(lib)
+         def _the_op_that_needs_to_be_preserved(...):
+             ...
+
+         # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
+         # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+         # torch.export.export / torch._export.capture_pre_autograd_graph
+     """
+     from torch._inductor.decomposition import register_decomposition
+
+     def decorator(fn):
+         if TORCH_VERSION_AFTER_2_5:
+             from torch._library.infer_schema import infer_schema
+
+             # assuming fn.__name__ starts with `_` and we want to take the rest
+             # to be the name of the custom op
+             op_name = fn.__name__[1:]
+             schema = op_name + infer_schema(fn)
+             lib.define(schema)
+             lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+             lib_namespace = lib.ns
+             op = getattr(getattr(torch.ops, lib_namespace), op_name)
+             register_decomposition([op])(fn)
+             return op
+         else:
+             return fn
+
+     return decorator

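A usage sketch for the decorator above (illustrative only, assuming torch >= 2.5 and that this module is importable as torchao.quantization.quant_primitives; `_identity_copy` is a hypothetical op name):

    import torch
    from torchao.quantization.quant_primitives import quant_lib, register_custom_op

    @register_custom_op(quant_lib)
    def _identity_copy(x: torch.Tensor) -> torch.Tensor:
        # the leading `_` is stripped, so this registers torch.ops.quant.identity_copy
        return x.clone()

    y = torch.ops.quant.identity_copy(torch.randn(2, 2))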
# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
    """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +190,7 @@ def quantize_affine(
    quant_min: Optional[int] = None,
    quant_max: Optional[int] = None,
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
- ):
+ ) -> torch.Tensor:
    """
    Args:
        input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +224,31 @@ def quantize_affine(
    Output:
        quantized tensor with requested dtype
    """
+     return _quantize_affine(
+         input,
+         block_size,
+         scale,
+         zero_point,
+         output_dtype,
+         quant_min,
+         quant_max,
+         zero_point_domain.name,
+     )
+
+
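A quick call sketch for the public wrapper above (the shape, scale and zero_point values are illustrative assumptions):

    import torch
    from torchao.quantization.quant_primitives import quantize_affine

    x = torch.randn(4, 8)
    scale = torch.tensor(0.05)
    zero_point = torch.tensor(0)
    # block_size equal to the full shape -> one (scale, zero_point) pair for the whole tensor
    q = quantize_affine(x, (4, 8), scale, zero_point, torch.int8,
                        quant_min=-128, quant_max=127)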
+ @register_custom_op(quant_lib)
+ def _quantize_affine(
+     input: torch.Tensor,
+     block_size: List[int],
+     scale: torch.Tensor,
+     zero_point: Optional[torch.Tensor],
+     output_dtype: torch.dtype,
+     quant_min: Optional[int] = None,
+     quant_max: Optional[int] = None,
+     zero_point_domain: str = "INT",
+ ) -> torch.Tensor:
+     """op definition that has compatible signatures with custom op library
+     """
    # TODO: validations
    # TODO: validate scale/zero_point dimensions are compatible with block_size
    assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +263,12 @@ def quantize_affine(
    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

-     if zero_point_domain == ZeroPointDomain.INT:
+     if zero_point_domain == ZeroPointDomain.INT.name:
        quant = torch.clamp(
            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
        ).to(output_dtype)
    else:
-         assert zero_point_domain == ZeroPointDomain.FLOAT
+         assert zero_point_domain == ZeroPointDomain.FLOAT.name
        mid_point = (quant_max + quant_min + 1) / 2
        min_val = zero_point - scale * mid_point
        quant = (
@@ -216,7 +291,7 @@ def dequantize_affine(
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
    *,
    output_dtype: torch.dtype = torch.float32,
- ):
+ ) -> torch.Tensor:
    """
    Args:
        input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +313,34 @@ def dequantize_affine(
    Output:
        dequantized Tensor, with requested dtype or fp32
    """
+     return _dequantize_affine(
+         input,
+         block_size,
+         scale,
+         zero_point,
+         input_dtype,
+         quant_min,
+         quant_max,
+         zero_point_domain.name,
+         output_dtype=output_dtype,
+     )
+
+
+ # @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
+ @register_custom_op(quant_lib)
+ def _dequantize_affine(
+     input: torch.Tensor,
+     block_size: List[int],
+     scale: torch.Tensor,
+     zero_point: Optional[torch.Tensor],
+     input_dtype: torch.dtype,
+     quant_min: Optional[int] = None,
+     quant_max: Optional[int] = None,
+     zero_point_domain: str = "INT",
+     output_dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+     """op definition that has compatible signatures with custom op library
+     """

    # TODO: validations
    # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +358,16 @@ def dequantize_affine(
    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

-     if zero_point_domain == ZeroPointDomain.INT:
+     if zero_point_domain == ZeroPointDomain.INT.name:
        # Force a copy to avoid input modification due
        # to upcoming in-place operations.
        dequant = input.to(torch.int32, copy=True)
        if zero_point is not None:
-             dequant -= zero_point.to(torch.int32)
+             dequant = dequant - zero_point.to(torch.int32)
        dequant = dequant.to(output_dtype)
-         dequant *= scale
+         dequant = dequant * scale
    else:
-         assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+         assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
        mid_point = (quant_max + quant_min + 1) / 2
        # This should allocate new memory and avoid input modification
        dequant = input - mid_point
@@ -320,8 +423,39 @@ def choose_qparams_affine(
    Output:
        Tuple of scales and zero_points Tensor with requested dtype
    """
+     return _choose_qparams_affine(
+         input,
+         mapping_type.name,
+         block_size,
+         target_dtype,
+         quant_min,
+         quant_max,
+         eps,
+         scale_dtype,
+         zero_point_dtype,
+         preserve_zero,
+         zero_point_domain.name,
+     )
+
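A sketch of deriving qparams with the wrapper above (per-row blocks on a 4x8 tensor are an illustrative choice):

    import torch
    from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

    x = torch.randn(4, 8)
    # block_size (1, 8): each block spans one full row, so one qparam pair per row
    scale, zero_point = choose_qparams_affine(
        x, MappingType.ASYMMETRIC, (1, 8), torch.int8,
        quant_min=-128, quant_max=127,
    )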
+ # @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
+ @register_custom_op(quant_lib)
+ def _choose_qparams_affine(
+     input: torch.Tensor,
+     mapping_type: str,
+     block_size: List[int],
+     target_dtype: torch.dtype,
+     quant_min: Optional[int] = None,
+     quant_max: Optional[int] = None,
+     eps: Optional[float] = None,
+     scale_dtype: Optional[torch.dtype] = None,
+     zero_point_dtype: Optional[torch.dtype] = None,
+     preserve_zero: bool = True,
+     zero_point_domain: str = "INT",
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """op definition that has compatible signatures with custom op library
+     """
    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-     assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+     assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

    if scale_dtype is None:
        scale_dtype = input.dtype
@@ -342,21 +476,22 @@ def choose_qparams_affine(
    min_val_neg = min_val
    max_val_pos = max_val

-     if mapping_type == MappingType.SYMMETRIC:
+     if mapping_type == MappingType.SYMMETRIC.name:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        if not preserve_zero:
            raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-         if zero_point_domain != ZeroPointDomain.INT:
+         if zero_point_domain != ZeroPointDomain.INT.name:
            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
    else:
+         assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        if preserve_zero:
            zero_point = quant_min - torch.round(min_val_neg / scale)
            zero_point = torch.clamp(zero_point, quant_min, quant_max)
        else:
-             assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+             assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
            mid_point = (quant_max + quant_min + 1) / 2
            zero_point = min_val_neg + scale * mid_point

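End to end, the three wrappers compose as in this sketch (same illustrative setup as the earlier snippets; the defaults fill in quant_min, quant_max and eps):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType, choose_qparams_affine, dequantize_affine, quantize_affine,
    )

    x = torch.randn(4, 8)
    block_size = (4, 8)  # per-tensor
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, block_size, torch.int8)
    # SYMMETRIC with int8 defaults: scale = max(|min|, max) / 127.5, zero_point = 0
    q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    x_hat = dequantize_affine(q, block_size, scale, zero_point, torch.int8)
    assert (x - x_hat).abs().max() <= scale  # off by at most one quantization step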