Fix numeric mismatches #2085
Changes from all commits: 04ca743, ba50fec, c57544f, c7fb2d7, 05e2d00, 4d62fc3, 26fbfd2, 808372f, a64e271, 24e4f7c, 718e650, 2b07718, 5eab1cf
@@ -177,6 +177,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
                 scale_precision=self.scale_precision,
                 zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
+                dtype=child.weight.dtype,
             )
             # In distributed training, the model may be instantiated
             # on the meta device, in which case there is no need to
@@ -227,14 +228,19 @@ def _convert_helper(self, module: torch.nn.Module):
                 scale_precision=scale_precision,
                 zero_point_precision=zero_point_precision,
                 device=child.weight.device,
+                output_dtype=child.weight.dtype,
             )
             setattr(module, name, quantized_embedding)

             # Load weights and qparams into quantized embedding
             (qmin, qmax) = _get_qmin_qmax(self.bit_width)
             (s, zp) = get_group_qparams_symmetric(
-                child.weight, self.bit_width, group_size
+                child.weight,
+                self.bit_width,
+                group_size,
+                precision=scale_precision,
             )
+            zp = zp.to(zero_point_precision)
             q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
                 child.weight,
                 s,

Review thread on `precision=scale_precision,`:

- Reviewer: We probably need to set the zero point precision separately to be more correct? I think the function doesn't let you set them separately, so maybe we need to do an extra cast after calling it here, like …
- Author: For numeric purposes it doesn't matter, because the zero point is an int8 number, which can be represented as a floating point number. But I'll add the extra cast.
- Reviewer: Yeah, but it would be better to be consistent and have the zero point dtype specified separately. For example, if you were to quantize bias to int32, then you would want the zero point in int32, although usually it is just zero. Besides, if this is for affine quantization, then maybe have the function name reflect that.
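As a concrete illustration of the pattern discussed in the thread above (compute both qparams at the scale precision, then cast the zero point to its own dtype afterwards), here is a minimal sketch. The tensor shape, dtypes, and the import path for `get_group_qparams_symmetric` are assumptions for illustration, not taken from this PR.

```python
import torch

# Import path is an assumption; in this PR the helper is referenced unqualified.
from torchao.quantization.utils import get_group_qparams_symmetric

# Hypothetical bf16 embedding weight and quantization settings
weight = torch.randn(32, 256, dtype=torch.bfloat16)
bit_width = 4
group_size = 64
scale_precision = torch.bfloat16
zero_point_precision = torch.int32

# The helper takes a single `precision` argument for both qparams,
# so the zero point is cast separately after the call, as discussed above.
s, zp = get_group_qparams_symmetric(
    weight, bit_width, group_size, precision=scale_precision
)
zp = zp.to(zero_point_precision)
```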
@@ -324,6 +330,7 @@ def __init__(
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
+        output_dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
@@ -341,6 +348,7 @@ def __init__(
         self.group_size = group_size
         self.scale_precision = scale_precision
         self.zero_point_precision = zero_point_precision
+        self.output_dtype = output_dtype

         # currently storing unpacked int8 weights
         self.register_buffer(
@@ -367,20 +375,24 @@ def __init__(
         )

     def forward(self, x):
-        from torchao._executorch_ops import (
-            _quantized_decomposed_dequantize_per_channel_group_wrapper,
+        from torchao.quantization.quant_primitives import (
+            dequantize_affine,
         )

         qmin, qmax = _get_qmin_qmax(self.bit_width)
-        w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
+
+        # dequantize_affine casts to output_dtype before scaling
+        # dequantize_per_channel_group scales and then casts to output_dtype
+        # The two do not agree when dtype != torch.float32
+        w_dq = dequantize_affine(
             self.weight,
+            [1, self.group_size],
             self.scale,
             self.zero_point,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
-            self.group_size,
-            x.dtype,
+            output_dtype=self.output_dtype,
         )
         return F.embedding(
             x,

Review thread on `output_dtype=self.output_dtype,`:

- Reviewer: Should the output dtype be the same as `x.dtype`?
- Author: No, this is an embedding op, so x.dtype is an int (the index). But the weights should not be cast to integers. This was one of the existing numerical issues in the embedding convert step.
- Reviewer: Oh I see, thanks for catching that!
@@ -219,8 +219,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             n_bit = 4
             (qmin, qmax) = _get_qmin_qmax(n_bit)
             (s, zp) = get_group_qparams_symmetric(
-                child.weight, n_bit, config.group_size
+                child.weight,
+                n_bit,
+                config.group_size,
+                precision=config.scale_precision,
             )
+            zp = zp.to(config.zero_point_precision)
             from torchao._executorch_ops import (
                 _quantized_decomposed_quantize_per_channel_group_wrapper,
             )

(andrewor14 marked a review conversation on `precision=config.scale_precision,` as resolved.)
@@ -258,6 +262,10 @@ class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
         groupsize: the number of elements in each quantized group for weights
         precision: precision of weights
         scales_precision: precision of per group scales and zero points
+
+    Note: we hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32).
+    To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights
+    and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales.
     """

     def __init__(
@@ -270,7 +278,9 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = _get_8da4w_activation_config(scales_precision)
+        # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
+        # which is used in PTQ routines
+        activation_config = _get_8da4w_activation_config(torch.float32)
         weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,

Review thread on `activation_config = _get_8da4w_activation_config(torch.float32)`:

- Reviewer: Is this change needed? Looks like …
- Author: In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.
- Reviewer: Why is that?
- Reviewer: I see, makes sense. We should definitely update the docstring to mention this then, something like "Note: We hardcode activation scales to use …"
- Author: I don't know if there is a great reason for that; it's just how LinearActivationQuantizedTensor works (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24). Unlike the weights, torchao does not serialize parameters related to activation quantization right now. Instead it serializes a function that does the activation quantization (which is opaque). To support serialization, the function is added to torch.serialization.add_safe_globals. I think to support specifying information about the activation quantization, we would have to serialize this information in the LinearActivationQuantizedTensor (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24), or register many variants of _int8_asymm_per_token_quant with different dtypes to torch.serialization.add_safe_globals. The former probably makes more sense. cc @jerryzh168
- Reviewer: Is the question around how we hardcoded the args with LinearActivationQuantizedTensor? We can use … it allows passing around some args for input_quant_function: …
Additional review thread:

- Reviewer: Looks like in _int8_asymm_per_token_quant you set the zero point to torch.int8, so this is different? A bit confused here.
- Author: That difference doesn't actually matter, because the zero point domain is int, so the zero points will be int8 numbers (which can be represented in fp32). But maybe it makes sense to just use quant_api._int8_asymm_per_token_quant directly here, rather than per_token_dynamic_quant?
- Reviewer: Yeah, I would prefer us to be consistent in the scale and zero point dtypes. In fact, if we could make the scale fp32 and the zero point int32, that would be fine. The zero point should actually be the same as the quant dtype.
- Reviewer: Yeah, agree that we should just use torch.int8 here for the zero points to be consistent, even if it doesn't change the numerics in practice. We probably shouldn't use _int8_asymm_per_token_quant here because that would introduce subclasses into a module-swap-only flow.
- Author: I meant to use _int8_asymm_per_token_quant followed by dequantize, e.g., _int8_asymm_per_token_quant(x).dequantize()
- Reviewer: This would still introduce AQT into this path, right? I think I would prefer to keep these separate. IMO just calling the same primitives with unit tests seems enough.
- Author: It also creates a circular import. Changing back.
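As a rough illustration of what "calling the same primitives" could look like for the activations, here is a minimal, self-contained sketch of asymmetric per-token int8 quantize-then-dequantize with fp32 scales and int8 zero points. It is not torchao's actual _int8_asymm_per_token_quant or per_token_dynamic_quant implementation, just the general recipe the thread is discussing.

```python
import torch

def per_token_int8_quant_dequant(x: torch.Tensor) -> torch.Tensor:
    """Asymmetric per-token int8 fake quantization (illustrative only)."""
    qmin, qmax = -128, 127
    x_fp32 = x.to(torch.float32)
    x_min = x_fp32.amin(dim=-1, keepdim=True)
    x_max = x_fp32.amax(dim=-1, keepdim=True)
    # fp32 scales and int8 zero points, as the thread above prefers
    scale = (x_max - x_min).clamp(min=1e-6) / (qmax - qmin)
    zero_point = (qmin - torch.round(x_min / scale)).clamp(qmin, qmax).to(torch.int8)
    x_q = torch.clamp(
        torch.round(x_fp32 / scale) + zero_point.to(torch.float32), qmin, qmax
    ).to(torch.int8)
    # Dequantize back to the original activation dtype
    return ((x_q.to(torch.float32) - zero_point.to(torch.float32)) * scale).to(x.dtype)

x = torch.randn(2, 16, dtype=torch.bfloat16)
print(per_token_int8_quant_dequant(x).dtype)  # torch.bfloat16
```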