From d3db3db715a90ab90a67d5e71fa546ba30f1107f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Thu, 15 Aug 2024 16:52:06 -0400 Subject: [PATCH 01/69] init --- torchao/dtypes/affine_quantized_tensor.py | 6 + torchao/prototype/awq/core.py | 139 ++++++++++++++++++++++ torchao/prototype/awq/test.py | 97 +++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 torchao/prototype/awq/core.py create mode 100644 torchao/prototype/awq/test.py diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index a8a56b0d22..f0f2cc02ef 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -858,7 +858,12 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y +def _linear_awq_check(input_tensor, weight_tensor, bias): + from torchao.prototype.awq.core import AWQ_AQTLayout + return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) +def _linear_awq_impl(input_tensor, weight_tensor, bias): + return torch.nn.functional.linear(input_tensor / weight_tensor.layout_tensor.layout_type.scales, weight_tensor.dequantize(), bias) def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), @@ -866,6 +871,7 @@ def _register_quantized_linear_dispatches(): (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), + (_linear_awq_check, _linear_awq_impl), ]: _register_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py new file mode 100644 index 0000000000..edf224ee15 --- /dev/null +++ b/torchao/prototype/awq/core.py @@ -0,0 +1,139 @@ +import torch +import torch.nn.functional as F +from dataclasses import dataclass +from typing import Optional, Tuple +from copy import deepcopy +from torchao.dtypes.utils import ( + LayoutType, +) +from torchao.dtypes.affine_quantized_tensor import ( + PlainAQTLayout, + register_layout_cls, + to_affine_quantized + +) + +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) +from torchao.quantization.observer import ( + PerAxis, + AffineQuantizedObserverBase, +) + +class AWQObserver(AffineQuantizedObserverBase): + def __init__(self, + weight_shape: Tuple[int, int], + mapping_type: MappingType, + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: Optional[bool] = True, + zero_point_domain = ZeroPointDomain.INT, + ): + self.block_size = (1, -1) + super().__init__( + mapping_type, + target_dtype, + block_size = self.block_size, + quant_min = quant_min, + quant_max = quant_max, + eps = eps, + scale_dtype = scale_dtype, + zero_point_dtype = zero_point_dtype, + preserve_zero = preserve_zero, + zero_point_domain = zero_point_domain, + ) + self.average = torch.zeros(weight_shape[-1], dtype=torch.float32) + self.counter = 0 + self.output_sum = torch.zeros(weight_shape[0], dtype=torch.float32) + + def forward(self, input: torch.Tensor, output: torch.Tensor): + self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=0) / (self.counter + 
input.shape[0]) + self.counter += 1 + self.output_sum += output.sum(dim=0) + + + def calculate_qparams(self, weight, calibration_data): + best_error = float("inf") + best_ratio = -1 + best_scales = None + + n_grid = 20 + history = [] + x_max = self.average + for ratio in range(n_grid): + ratio = ratio * 1 / n_grid + scales = x_max.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + weight.mul_(scales) + quantized_weight = to_affine_quantized( + weight, + self.mapping_type, + self.block_size, + self.target_dtype, + quant_min = self.quant_min, + quant_max = self.quant_max, + eps = self.eps, + scale_dtype = self.scale_dtype, + zero_point_dtype = self.zero_point_dtype, + preserve_zero = self.preserve_zero, + zero_point_domain = self.zero_point_domain, + layout_type = AWQLayoutType(scales) + ) + scaled_activation = (calibration_data).to(torch.bfloat16) + out = F.linear(scaled_activation, quantized_weight).sum(dim=0) + + loss = ( + (self.output_sum - out).float().pow(2).mean().item() + ) # float prevents overflow + # print(f"ratio: {ratio}, loss: {loss}") + history.append(loss) + + is_best = loss < best_error + if is_best: + best_error = loss + best_ratio = ratio + best_scales = scales + print(f"best scale: {best_scales}, best error: {best_error}") + + if best_ratio == -1: + print(history) + raise Exception + # print(best_ratio) + best_scales = best_scales.view(-1) + + assert torch.isnan(best_scales).sum() == 0, best_scales + return best_scales.detach() + +@dataclass(frozen=True) +class AWQLayoutType(LayoutType): + scales: torch.Tensor + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + return input / (self.scales.view(1, -1)) + +@register_layout_cls(AWQLayoutType) +class AWQ_AQTLayout(PlainAQTLayout): + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, AWQLayoutType) + return cls(int_data, scale, zero_point, layout_type) + + + + + + + + \ No newline at end of file diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/test.py new file mode 100644 index 0000000000..6c7fa709d4 --- /dev/null +++ b/torchao/prototype/awq/test.py @@ -0,0 +1,97 @@ +from copy import deepcopy +import torch +import torch.nn.functional as F +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.prototype.awq.core import AWQ_AQTLayout, AWQLayoutType, AWQObserver +from torchao.quantization import quantize_, int8_weight_only +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) +from torchao.dtypes import to_affine_quantized + +# class ObservedLinear(torch.nn.Linear): +# def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): +# super().__init__(in_features, out_features, bias, device, dtype) +# self.act_obs = act_obs + +# def forward(self, input: torch.Tensor): +# output = F.linear(input, self.weight, self.bias) +# self.act_obs(input, output) +# return output + +# @classmethod +# def from_float(cls, float_linear, act_obs): +# observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) +# observed_linear.weight = float_linear.weight +# observed_linear.bias = float_linear.bias +# return observed_linear + +# def insert_awq_observer(model): +# _is_linear = lambda m, fqn: isinstance(m, 
torch.nn.Linear) +# def replace_with_observer(layer): +# observer = AWQObserver((layer.weight.shape), MappingType.ASYMMETRIC, torch.int8) +# return ObservedLinear.from_float(layer, observer) +# _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) + +# # converting observed linear module to linear module with quantzied weights +# # with tensor subclasses +# def awq_quant(observed_linear, calibration_data): +# target_dtype = torch.int8 +# block_size = (1, observed_linear.weight.shape[1]) +# mapping_type = MappingType.ASYMMETRIC +# # weight quantization +# equalization_scale = observed_linear.act_obs.calculate_qparams(observed_linear.weight.detach(), calibration_data) +# layout_type = AWQLayoutType(equalization_scale) +# def weight_quant_func(weight): +# return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) + +# linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) +# linear.weight = observed_linear.weight +# linear.bias = observed_linear.bias +# linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False) + +# return linear + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): + return torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +torch.manual_seed(0) +dtype = torch.bfloat16 +m = ToyLinearModel().eval().to(dtype) +m_bf16 = deepcopy(m) +example_inputs = m.example_inputs(batch_size = 1024, dtype=dtype) +bf16_out = m_bf16(example_inputs) + +m_int8wo = deepcopy(m) +quantize_(m_int8wo, int8_weight_only) +int8wo_out = m_int8wo(example_inputs) + +# # calibrate +# insert_awq_observer(m) +# m(example_inputs) + +# # quantize +# is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) +# apply_awq_quant = lambda m: awq_quant(m, example_inputs) +# quantize_(m, apply_awq_quant, is_observed_linear) +# awq_out = m(example_inputs) + + +# # compare accuracy +# awq_err = torch.sum(torch.abs(awq_out - bf16_out)) +# int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)) +# print(f"AWQ error: {awq_err}") +# print(f"Int8WO error: {int8wo_err}") \ No newline at end of file From ed864a26e1fb94a438e447c279bf3845f0af432f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 16 Aug 2024 14:33:06 -0400 Subject: [PATCH 02/69] fixed implementation --- torchao/dtypes/affine_quantized_tensor.py | 3 +- torchao/prototype/awq/core.py | 31 +++--- torchao/prototype/awq/test.py | 127 +++++++++++----------- 3 files changed, 84 insertions(+), 77 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 83cae2b1e9..cc08e86e3d 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -879,7 +879,8 @@ def _linear_awq_check(input_tensor, weight_tensor, bias): return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) def _linear_awq_impl(input_tensor, weight_tensor, bias): - return torch.nn.functional.linear(input_tensor / weight_tensor.layout_tensor.layout_type.scales, weight_tensor.dequantize(), bias) + # print('awq inp, scales: 
',input_tensor.shape, weight_tensor.layout_tensor.layout_type.equalization_scale.shape) + return torch.nn.functional.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index edf224ee15..f8fd90dd93 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -50,25 +50,28 @@ def __init__(self, ) self.average = torch.zeros(weight_shape[-1], dtype=torch.float32) self.counter = 0 - self.output_sum = torch.zeros(weight_shape[0], dtype=torch.float32) + self.calibration_data = [] - def forward(self, input: torch.Tensor, output: torch.Tensor): + def forward(self, input: torch.Tensor): self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=0) / (self.counter + input.shape[0]) self.counter += 1 - self.output_sum += output.sum(dim=0) + self.calibration_data.append(input) - def calculate_qparams(self, weight, calibration_data): + def calculate_qparams(self, orig_weight): best_error = float("inf") best_ratio = -1 best_scales = None n_grid = 20 history = [] + calibration_data = torch.cat(self.calibration_data, dim=0) + unquantized_result = F.linear(calibration_data, orig_weight).sum(dim=0) x_max = self.average for ratio in range(n_grid): ratio = ratio * 1 / n_grid - scales = x_max.pow(ratio).clamp(min=1e-4).view(-1) + weight = deepcopy(orig_weight) + scales = x_max.pow(ratio) scales = scales / (scales.max() * scales.min()).sqrt() weight.mul_(scales) quantized_weight = to_affine_quantized( @@ -85,13 +88,14 @@ def calculate_qparams(self, weight, calibration_data): zero_point_domain = self.zero_point_domain, layout_type = AWQLayoutType(scales) ) - scaled_activation = (calibration_data).to(torch.bfloat16) + scaled_activation = (calibration_data) out = F.linear(scaled_activation, quantized_weight).sum(dim=0) - + loss = ( - (self.output_sum - out).float().pow(2).mean().item() + (unquantized_result - out).pow(2).mean().item() ) # float prevents overflow # print(f"ratio: {ratio}, loss: {loss}") + # print(ratio, loss) history.append(loss) is_best = loss < best_error @@ -99,23 +103,20 @@ def calculate_qparams(self, weight, calibration_data): best_error = loss best_ratio = ratio best_scales = scales - print(f"best scale: {best_scales}, best error: {best_error}") + # print(f"best error: {best_error}") - if best_ratio == -1: - print(history) - raise Exception + # print(best_ratio) - best_scales = best_scales.view(-1) assert torch.isnan(best_scales).sum() == 0, best_scales return best_scales.detach() @dataclass(frozen=True) class AWQLayoutType(LayoutType): - scales: torch.Tensor + equalization_scale: torch.Tensor def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input / (self.scales.view(1, -1)) + return input * self.equalization_scale @register_layout_cls(AWQLayoutType) class AWQ_AQTLayout(PlainAQTLayout): diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/test.py index 6c7fa709d4..ee06db64ee 100644 --- a/torchao/prototype/awq/test.py +++ b/torchao/prototype/awq/test.py @@ -10,55 +10,55 @@ ) from torchao.dtypes import to_affine_quantized -# class ObservedLinear(torch.nn.Linear): -# def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): -# 
super().__init__(in_features, out_features, bias, device, dtype) -# self.act_obs = act_obs +class ObservedLinear(torch.nn.Linear): + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + self.act_obs = act_obs -# def forward(self, input: torch.Tensor): -# output = F.linear(input, self.weight, self.bias) -# self.act_obs(input, output) -# return output + def forward(self, input: torch.Tensor): + self.act_obs(input) + return F.linear(input, self.weight, self.bias) -# @classmethod -# def from_float(cls, float_linear, act_obs): -# observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) -# observed_linear.weight = float_linear.weight -# observed_linear.bias = float_linear.bias -# return observed_linear + @classmethod + def from_float(cls, float_linear, act_obs): + observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) + observed_linear.weight = float_linear.weight + observed_linear.bias = float_linear.bias + return observed_linear -# def insert_awq_observer(model): -# _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) -# def replace_with_observer(layer): -# observer = AWQObserver((layer.weight.shape), MappingType.ASYMMETRIC, torch.int8) -# return ObservedLinear.from_float(layer, observer) -# _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) +def insert_awq_observer(model): + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) + def replace_with_observer(layer): + observer = AWQObserver((layer.weight.shape), MappingType.ASYMMETRIC, torch.int8) + return ObservedLinear.from_float(layer, observer) + _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) -# # converting observed linear module to linear module with quantzied weights -# # with tensor subclasses -# def awq_quant(observed_linear, calibration_data): -# target_dtype = torch.int8 -# block_size = (1, observed_linear.weight.shape[1]) -# mapping_type = MappingType.ASYMMETRIC -# # weight quantization -# equalization_scale = observed_linear.act_obs.calculate_qparams(observed_linear.weight.detach(), calibration_data) -# layout_type = AWQLayoutType(equalization_scale) -# def weight_quant_func(weight): -# return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) +# converting observed linear module to linear module with quantzied weights +# with tensor subclasses +def awq_quant(observed_linear): + assert len(observed_linear.act_obs.calibration_data) > 0, "Calibrate the observer first" + target_dtype = torch.int8 + block_size = (1, -1) + mapping_type = MappingType.ASYMMETRIC + # weight quantization + equalization_scale = observed_linear.act_obs.calculate_qparams(observed_linear.weight.detach()) + layout_type = AWQLayoutType(equalization_scale) + def weight_quant_func(weight): + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) -# linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) -# linear.weight = observed_linear.weight -# linear.bias = observed_linear.bias -# linear.weight = 
torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False) - -# return linear + linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias + linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False) + return linear class ToyLinearModel(torch.nn.Module): - def __init__(self, m=64, n=32, k=64): + def __init__(self, m=512, n=256, k=128): super().__init__() self.linear1 = torch.nn.Linear(m, n, bias=False) self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): return torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device) @@ -66,32 +66,37 @@ def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): def forward(self, x): x = self.linear1(x) x = self.linear2(x) + x = self.linear3(x) return x - -torch.manual_seed(0) -dtype = torch.bfloat16 -m = ToyLinearModel().eval().to(dtype) -m_bf16 = deepcopy(m) -example_inputs = m.example_inputs(batch_size = 1024, dtype=dtype) -bf16_out = m_bf16(example_inputs) -m_int8wo = deepcopy(m) -quantize_(m_int8wo, int8_weight_only) -int8wo_out = m_int8wo(example_inputs) +for i in range(10): + torch.manual_seed(i) + dataset_size = 200 + dtype = torch.bfloat16 + m = ToyLinearModel().eval().to(dtype) + m_bf16 = deepcopy(m) + + dataset = m.example_inputs(batch_size = dataset_size, dtype=dtype) + calibration_data = dataset[:200] + bf16_out = m_bf16(dataset) + + m_int8wo = deepcopy(m) + quantize_(m_int8wo, int8_weight_only()) + int8wo_out = m_int8wo(dataset) -# # calibrate -# insert_awq_observer(m) -# m(example_inputs) + # calibrate + insert_awq_observer(m) + m(calibration_data) + # print('calibrated') -# # quantize -# is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) -# apply_awq_quant = lambda m: awq_quant(m, example_inputs) -# quantize_(m, apply_awq_quant, is_observed_linear) -# awq_out = m(example_inputs) + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(m, awq_quant, is_observed_linear) + awq_out = m(dataset) -# # compare accuracy -# awq_err = torch.sum(torch.abs(awq_out - bf16_out)) -# int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)) -# print(f"AWQ error: {awq_err}") -# print(f"Int8WO error: {int8wo_err}") \ No newline at end of file + # compare accuracy + awq_err = torch.sum(torch.abs(awq_out - bf16_out)).item() / dataset_size + int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).item() / dataset_size + print(f"AWQ error: {awq_err}") + print(f"Int8WO error: {int8wo_err}") \ No newline at end of file From 0e690f9251d86227dd43c40a1bb745c0ba7cf421 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 21 Aug 2024 16:16:04 -0400 Subject: [PATCH 03/69] reduced vmem req --- torchao/prototype/awq/core.py | 79 +++++++++++------------------------ torchao/prototype/awq/test.py | 70 ++++++++++++++++--------------- 2 files changed, 62 insertions(+), 87 deletions(-) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index f8fd90dd93..811da51423 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -24,9 +24,11 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, - weight_shape: Tuple[int, int], + weight: torch.Tensor, mapping_type: 
MappingType, target_dtype: torch.dtype, + device: str, + scale_search_space_size: int = 20, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, @@ -48,34 +50,22 @@ def __init__(self, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, ) - self.average = torch.zeros(weight_shape[-1], dtype=torch.float32) + self.weight = weight + self.scale_options = scale_search_space_size + self.losses = [0] * self.scale_options + self.average = torch.zeros(weight.shape[-1], dtype=torch.float32).to(device) self.counter = 0 - self.calibration_data = [] def forward(self, input: torch.Tensor): - self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=0) / (self.counter + input.shape[0]) - self.counter += 1 - self.calibration_data.append(input) - - - def calculate_qparams(self, orig_weight): - best_error = float("inf") - best_ratio = -1 - best_scales = None - - n_grid = 20 - history = [] - calibration_data = torch.cat(self.calibration_data, dim=0) - unquantized_result = F.linear(calibration_data, orig_weight).sum(dim=0) - x_max = self.average - for ratio in range(n_grid): - ratio = ratio * 1 / n_grid - weight = deepcopy(orig_weight) - scales = x_max.pow(ratio) + self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=1).squeeze(0) / (self.counter + input.shape[0]) + self.counter += input.shape[0] + for i in range(self.scale_options): + unquantized_result = F.linear(input, self.weight) + ratio = i *1.0 / self.scale_options + scales = self.average.pow(ratio).clamp(min=1e-4) scales = scales / (scales.max() * scales.min()).sqrt() - weight.mul_(scales) quantized_weight = to_affine_quantized( - weight, + self.weight.data * scales, self.mapping_type, self.block_size, self.target_dtype, @@ -87,29 +77,17 @@ def calculate_qparams(self, orig_weight): preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, layout_type = AWQLayoutType(scales) - ) - scaled_activation = (calibration_data) - out = F.linear(scaled_activation, quantized_weight).sum(dim=0) - - loss = ( - (unquantized_result - out).pow(2).mean().item() - ) # float prevents overflow - # print(f"ratio: {ratio}, loss: {loss}") - # print(ratio, loss) - history.append(loss) - - is_best = loss < best_error - if is_best: - best_error = loss - best_ratio = ratio - best_scales = scales - # print(f"best error: {best_error}") - - - # print(best_ratio) - - assert torch.isnan(best_scales).sum() == 0, best_scales - return best_scales.detach() + ) + scaled_activation = (input / scales) + out = F.linear(scaled_activation, quantized_weight) + self.losses[i] += (unquantized_result - out).pow(2).mean().item() + + def calculate_qparams(self): + losses = torch.tensor(self.losses) + ratio = torch.argmin(losses) * 1.0 / self.scale_options + scales = self.average.pow(ratio).clamp(min=1e-4) + scales = scales / (scales.max() * scales.min()).sqrt() + return scales.detach() @dataclass(frozen=True) class AWQLayoutType(LayoutType): @@ -130,11 +108,4 @@ def from_plain( ): assert isinstance(layout_type, AWQLayoutType) return cls(int_data, scale, zero_point, layout_type) - - - - - - - \ No newline at end of file diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/test.py index ee06db64ee..c8bc401ed3 100644 --- a/torchao/prototype/awq/test.py +++ b/torchao/prototype/awq/test.py @@ -26,22 +26,22 @@ def from_float(cls, float_linear, act_obs): observed_linear.bias = float_linear.bias return observed_linear -def 
insert_awq_observer(model): +def insert_awq_observer(model, device): _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) def replace_with_observer(layer): - observer = AWQObserver((layer.weight.shape), MappingType.ASYMMETRIC, torch.int8) + observer = AWQObserver((layer.weight), MappingType.ASYMMETRIC, torch.int8, device) return ObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) # converting observed linear module to linear module with quantzied weights # with tensor subclasses def awq_quant(observed_linear): - assert len(observed_linear.act_obs.calibration_data) > 0, "Calibrate the observer first" + assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" target_dtype = torch.int8 block_size = (1, -1) mapping_type = MappingType.ASYMMETRIC # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams(observed_linear.weight.detach()) + equalization_scale = observed_linear.act_obs.calculate_qparams() layout_type = AWQLayoutType(equalization_scale) def weight_quant_func(weight): return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) @@ -60,43 +60,47 @@ def __init__(self, m=512, n=256, k=128): self.linear2 = torch.nn.Linear(n, k, bias=False) self.linear3 = torch.nn.Linear(k, 1, bias=False) - def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device) + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = self.linear3(x) return x + +if __name__ == "__main__": + for i in range(10): + device = ("cpu") + torch.manual_seed(i) + dataset_size = 1000 + dtype = torch.bfloat16 + l1,l2,l3 = 512,256,128 + m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) + m_bf16 = deepcopy(m) -for i in range(10): - torch.manual_seed(i) - dataset_size = 200 - dtype = torch.bfloat16 - m = ToyLinearModel().eval().to(dtype) - m_bf16 = deepcopy(m) - - dataset = m.example_inputs(batch_size = dataset_size, dtype=dtype) - calibration_data = dataset[:200] - bf16_out = m_bf16(dataset) - - m_int8wo = deepcopy(m) - quantize_(m_int8wo, int8_weight_only()) - int8wo_out = m_int8wo(dataset) + dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) + calibration_data = dataset[:100] + bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + - # calibrate - insert_awq_observer(m) - m(calibration_data) - # print('calibrated') + m_int8wo = deepcopy(m) + quantize_(m_int8wo, int8_weight_only()) + int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(m, awq_quant, is_observed_linear) - awq_out = m(dataset) + # calibrate + insert_awq_observer(m, device) + for example in calibration_data: + m(example.to(device)) + # print('calibrated') + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(m, awq_quant, is_observed_linear) + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - # compare accuracy - awq_err = torch.sum(torch.abs(awq_out - bf16_out)).item() / dataset_size - int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).item() / dataset_size - print(f"AWQ error: 
{awq_err}") - print(f"Int8WO error: {int8wo_err}") \ No newline at end of file + # compare accuracy + awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size + int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size + print(f"AWQ error: {awq_err}") + print(f"Int8WO error: {int8wo_err}") \ No newline at end of file From 4519792db9cac46ebb300143189bdeaea6a7a54e Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Fri, 23 Aug 2024 21:47:01 -0400 Subject: [PATCH 04/69] eval on LLMs --- scripts/hf_eval.py | 27 +++++++++++-- scripts/hfconfig.py | 33 ++++++++++++++++ torchao/_models/_eval.py | 39 +++++-------------- torchao/_models/llama/eval.py | 21 +++++++++- torchao/prototype/awq/core.py | 72 ++++++++++++----------------------- torchao/prototype/awq/test.py | 69 +++++++++++++++++---------------- 6 files changed, 146 insertions(+), 115 deletions(-) create mode 100644 scripts/hfconfig.py diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index 27b1b568df..c475d357c9 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -43,7 +43,7 @@ def format_value(value): def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length): tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device="cpu", dtype=precision) if quantization == "autoquant" and compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) @@ -57,9 +57,28 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi quantize_(model.to(device=device), int4_weight_only()) elif quantization == "autoquant": model = autoquant(model.to(device=device)) + elif quantization == "awq": + from torchao.prototype.awq.test import ObservedLinear, insert_awq_observer, awq_quant + insert_awq_observer(model, device) + from datasets import load_dataset + wikitext103 = load_dataset("wikitext", "wikitext-103-v1") + wikitext103_train = wikitext103["train"] + wikitext103_calibration = wikitext103_train.select(range(100)) + calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] + print(len(calibration_input_ids)) + model.to(device) + print("running awq calibration") + for i, ids in enumerate(calibration_input_ids): + if ids.shape[-1] == 0: + continue + model(ids.to(device)) + + + is_observed_linear = lambda m, fqn: isinstance(model, ObservedLinear) + quantize_(model, awq_quant, is_observed_linear) if quantization != "autoquant" and compile: - model = torch.compile(model, mode="max-autotune", fullgraph=True) + model = torch.compile(model, fullgraph=True) with torch.no_grad(): result = evaluate( @@ -84,12 +103,12 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Run HF Model Evaluation') - parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.') + parser.add_argument('--repo_id', type=str, default="meta-llama/Llama-2-7b-hf", help='Repository ID to download from HF.') parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') parser.add_argument('--limit', type=int, default=None, help='Number of eval 
samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply') + parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "awq", "None"], help='Which quantization technique to apply') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--save', action='store_true', help='Whether to save the model.') parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes') diff --git a/scripts/hfconfig.py b/scripts/hfconfig.py new file mode 100644 index 0000000000..148e2d61c5 --- /dev/null +++ b/scripts/hfconfig.py @@ -0,0 +1,33 @@ +import json +import torch +from transformers import AutoModel + +def create_weight_map(model_name): + # Load the model + model = AutoModel.from_pretrained(model_name) + + # Get the state dict + state_dict = model.state_dict() + + # Create the weight map + weight_map = {} + for key, tensor in state_dict.items(): + # In this example, we're assuming all weights are in a single file + # You may need to adjust this if your model uses sharded weights + weight_map[key] = "pytorch_model.bin" + + # Create the index dictionary + index_dict = { + "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())}, + "weight_map": weight_map + } + + # Save the index dictionary to a JSON file + with open("pytorch_model.bin.index.json", "w") as f: + json.dump(index_dict, f, indent=2) + + print("Created pytorch_model.bin.index.json") + +# Usage +model_name = "checkpoints/Xenova/llama2.c-stories15M" +create_weight_map(model_name) \ No newline at end of file diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index 858196776f..c58a1869de 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -45,6 +45,7 @@ class InputRecorder(eval_wrapper): def __init__( self, tokenizer, + model, calibration_seq_length, input_prep_func=None, pad_calibration_inputs=False, @@ -63,7 +64,7 @@ def __init__( self.vocab_size = vocab_size self._max_seq_length = calibration_seq_length self.calibration_seq_length = calibration_seq_length - + self.model_ = model # need to take inps and convert to corrent input # for model self.input_prep_func = ( @@ -145,35 +146,13 @@ def get_inputs(self): return self.inputs def _model_call(self, inps): - inps = inps.squeeze(0) - T = len(inps) - if ( - # can't use inputs that are too short when padding disabled - (T < self.calibration_seq_length and not self.pad_calibration_inputs) - or - # can't use inputs that actually use token we use for padding - (self.pad_calibration_inputs and self.pad_token in inps) - ): - # give random output - return torch.randn( - (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device - ) - - # pad or truncate to the right size - if T >= self.calibration_seq_length: - inps = inps[: self.calibration_seq_length] - else: - inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) - - inps = inps.unsqueeze(0) - model_in = self.input_prep_func(inps) + input = 
self.input_prep_func(inps.to(self._device)) - self.add_input(model_in) - - # output `something` with correct shape to keep eval going - return torch.randn( - (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device - ) + max_seq_length = min(max(inps.size()), self.max_length) + with torch.device(self._device): + self.model_.setup_caches(self.batch_size, max_seq_length) + logits = self.model_(*input) + return logits def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") @@ -190,7 +169,7 @@ def __init__( input_prep_func=None, device="cuda" ): - super().__init__(tokenizer, None) + super().__init__(tokenizer, model, None) self._model = model # self.tokenizer = tokenizer self._device = torch.device(device) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index fc8634dd06..d34b75e5b9 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -49,7 +49,8 @@ def run_evaluation( print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, "cpu", precision) + model = _load_model(checkpoint_path, "cuda", precision).to(device) + print(model) if max_length is None: max_length = model.config.block_size @@ -75,6 +76,7 @@ def run_evaluation( assert "cuda" in device, "int4 gptq quantization only works on cuda" inputs = InputRecorder( tokenizer, + model, calibration_seq_length, prepare_inputs_for_model, pad_calibration_inputs, @@ -88,6 +90,23 @@ def run_evaluation( quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) + elif "awq" in quantization: + from torchao.prototype.awq.test import ObservedLinear, insert_awq_observer, awq_quant + insert_awq_observer(model, device) + InputRecorder( + tokenizer, + model, + calibration_seq_length, + prepare_inputs_for_model, + pad_calibration_inputs, + model.config.vocab_size, + device=device + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(model, awq_quant, is_observed_linear) else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index f8fd90dd93..87102c7aaa 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -24,9 +24,11 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, - weight_shape: Tuple[int, int], + weight: torch.Tensor, mapping_type: MappingType, target_dtype: torch.dtype, + device: str, + scale_search_space_size: int = 20, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, @@ -48,34 +50,22 @@ def __init__(self, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, ) - self.average = torch.zeros(weight_shape[-1], dtype=torch.float32) + self.weight = weight + self.scale_options = scale_search_space_size + self.losses = [0] * self.scale_options + self.average = torch.zeros(weight.shape[-1], dtype=torch.float32).to(device) self.counter = 0 - self.calibration_data = [] def forward(self, input: torch.Tensor): - self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=0) / (self.counter + input.shape[0]) - self.counter += 1 - self.calibration_data.append(input) - - - def calculate_qparams(self, orig_weight): - best_error = float("inf") - best_ratio = -1 - 
best_scales = None - - n_grid = 20 - history = [] - calibration_data = torch.cat(self.calibration_data, dim=0) - unquantized_result = F.linear(calibration_data, orig_weight).sum(dim=0) - x_max = self.average - for ratio in range(n_grid): - ratio = ratio * 1 / n_grid - weight = deepcopy(orig_weight) - scales = x_max.pow(ratio) + self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=1).squeeze(0) / (self.counter + input.shape[0]) + self.counter += input.shape[0] + for i in range(self.scale_options): + unquantized_result = F.linear(input, self.weight) + ratio = i *1.0 / self.scale_options + scales = self.average.pow(ratio).clamp(min=1e-4) scales = scales / (scales.max() * scales.min()).sqrt() - weight.mul_(scales) quantized_weight = to_affine_quantized( - weight, + self.weight.data * scales, self.mapping_type, self.block_size, self.target_dtype, @@ -87,29 +77,17 @@ def calculate_qparams(self, orig_weight): preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, layout_type = AWQLayoutType(scales) - ) - scaled_activation = (calibration_data) - out = F.linear(scaled_activation, quantized_weight).sum(dim=0) - - loss = ( - (unquantized_result - out).pow(2).mean().item() - ) # float prevents overflow - # print(f"ratio: {ratio}, loss: {loss}") - # print(ratio, loss) - history.append(loss) - - is_best = loss < best_error - if is_best: - best_error = loss - best_ratio = ratio - best_scales = scales - # print(f"best error: {best_error}") - - - # print(best_ratio) - - assert torch.isnan(best_scales).sum() == 0, best_scales - return best_scales.detach() + ) + scaled_activation = (input / scales) + out = F.linear(scaled_activation, quantized_weight) + self.losses[i] += (unquantized_result - out).pow(2).mean().item() + + def calculate_qparams(self): + losses = torch.tensor(self.losses) + ratio = torch.argmin(losses) * 1.0 / self.scale_options + scales = self.average.pow(ratio).clamp(min=1e-4) + scales = scales / (scales.max() * scales.min()).sqrt() + return scales.detach() @dataclass(frozen=True) class AWQLayoutType(LayoutType): diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/test.py index ee06db64ee..bc60894c71 100644 --- a/torchao/prototype/awq/test.py +++ b/torchao/prototype/awq/test.py @@ -26,22 +26,22 @@ def from_float(cls, float_linear, act_obs): observed_linear.bias = float_linear.bias return observed_linear -def insert_awq_observer(model): +def insert_awq_observer(model, device): _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) def replace_with_observer(layer): - observer = AWQObserver((layer.weight.shape), MappingType.ASYMMETRIC, torch.int8) + observer = AWQObserver((layer.weight), MappingType.ASYMMETRIC, torch.int8, device) return ObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) # converting observed linear module to linear module with quantzied weights # with tensor subclasses def awq_quant(observed_linear): - assert len(observed_linear.act_obs.calibration_data) > 0, "Calibrate the observer first" + assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" target_dtype = torch.int8 block_size = (1, -1) mapping_type = MappingType.ASYMMETRIC # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams(observed_linear.weight.detach()) + equalization_scale = observed_linear.act_obs.calculate_qparams() layout_type = AWQLayoutType(equalization_scale) def weight_quant_func(weight): 
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) @@ -60,43 +60,46 @@ def __init__(self, m=512, n=256, k=128): self.linear2 = torch.nn.Linear(n, k, bias=False) self.linear3 = torch.nn.Linear(k, 1, bias=False) - def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device) + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = self.linear3(x) return x +if __name__ == "__main__": + for i in range(10): + device = ("cpu") + torch.manual_seed(i) + dataset_size = 1000 + dtype = torch.bfloat16 + l1,l2,l3 = 512,256,128 + m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) + m_bf16 = deepcopy(m) -for i in range(10): - torch.manual_seed(i) - dataset_size = 200 - dtype = torch.bfloat16 - m = ToyLinearModel().eval().to(dtype) - m_bf16 = deepcopy(m) + dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) + calibration_data = dataset[:100] + bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + - dataset = m.example_inputs(batch_size = dataset_size, dtype=dtype) - calibration_data = dataset[:200] - bf16_out = m_bf16(dataset) + m_int8wo = deepcopy(m) + quantize_(m_int8wo, int8_weight_only()) + int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) - m_int8wo = deepcopy(m) - quantize_(m_int8wo, int8_weight_only()) - int8wo_out = m_int8wo(dataset) + # calibrate + insert_awq_observer(m, device) + for example in calibration_data: + m(example.to(device)) + # print('calibrated') - # calibrate - insert_awq_observer(m) - m(calibration_data) - # print('calibrated') + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(m, awq_quant, is_observed_linear) + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(m, awq_quant, is_observed_linear) - awq_out = m(dataset) - - - # compare accuracy - awq_err = torch.sum(torch.abs(awq_out - bf16_out)).item() / dataset_size - int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).item() / dataset_size - print(f"AWQ error: {awq_err}") - print(f"Int8WO error: {int8wo_err}") \ No newline at end of file + # compare accuracy + awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size + int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size + print(f"AWQ error: {awq_err}") + print(f"Int8WO error: {int8wo_err}") \ No newline at end of file From 0096a83e67d20c6aee4d3647a24f735a8463d7d4 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Fri, 23 Aug 2024 22:16:37 -0400 Subject: [PATCH 05/69] eval on llm --- torchao/_models/llama/eval.py | 4 ++-- torchao/prototype/awq/core.py | 20 ++++++++------------ torchao/prototype/awq/test.py | 9 ++++----- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index d34b75e5b9..ea81a8eccf 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -49,7 +49,7 @@ def run_evaluation( print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, "cuda", precision).to(device) + model = _load_model(checkpoint_path, 
"cpu", precision).to(device) print(model) if max_length is None: @@ -92,7 +92,7 @@ def run_evaluation( model = quantizer.quantize(model, inputs).to(device) elif "awq" in quantization: from torchao.prototype.awq.test import ObservedLinear, insert_awq_observer, awq_quant - insert_awq_observer(model, device) + insert_awq_observer(model, precision, device) InputRecorder( tokenizer, model, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 87102c7aaa..7099f2bdea 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -25,6 +25,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, + input_dtype: torch.dtype, mapping_type: MappingType, target_dtype: torch.dtype, device: str, @@ -52,12 +53,14 @@ def __init__(self, ) self.weight = weight self.scale_options = scale_search_space_size - self.losses = [0] * self.scale_options - self.average = torch.zeros(weight.shape[-1], dtype=torch.float32).to(device) + self.losses = torch.zeros(self.scale_options) + self.average = torch.zeros(weight.shape[-1], dtype=input_dtype).to(device) self.counter = 0 def forward(self, input: torch.Tensor): - self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=1).squeeze(0) / (self.counter + input.shape[0]) + if input.dim() == 3: + input = input.squeeze(0) + self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=0) / (self.counter + input.shape[0]) self.counter += input.shape[0] for i in range(self.scale_options): unquantized_result = F.linear(input, self.weight) @@ -81,10 +84,10 @@ def forward(self, input: torch.Tensor): scaled_activation = (input / scales) out = F.linear(scaled_activation, quantized_weight) self.losses[i] += (unquantized_result - out).pow(2).mean().item() + # print(self.losses[0]) def calculate_qparams(self): - losses = torch.tensor(self.losses) - ratio = torch.argmin(losses) * 1.0 / self.scale_options + ratio = torch.argmin(self.losses) * 1.0 / self.scale_options scales = self.average.pow(ratio).clamp(min=1e-4) scales = scales / (scales.max() * scales.min()).sqrt() return scales.detach() @@ -108,11 +111,4 @@ def from_plain( ): assert isinstance(layout_type, AWQLayoutType) return cls(int_data, scale, zero_point, layout_type) - - - - - - - \ No newline at end of file diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/test.py index bc60894c71..76c86da578 100644 --- a/torchao/prototype/awq/test.py +++ b/torchao/prototype/awq/test.py @@ -26,18 +26,17 @@ def from_float(cls, float_linear, act_obs): observed_linear.bias = float_linear.bias return observed_linear -def insert_awq_observer(model, device): +def insert_awq_observer(model, input_dtype, device): _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) def replace_with_observer(layer): - observer = AWQObserver((layer.weight), MappingType.ASYMMETRIC, torch.int8, device) + observer = AWQObserver(layer.weight, input_dtype, MappingType.ASYMMETRIC, torch.int8, device) return ObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) # converting observed linear module to linear module with quantzied weights # with tensor subclasses -def awq_quant(observed_linear): +def awq_quant(observed_linear, target_dtype=torch.int8): assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" - target_dtype = torch.int8 block_size = (1, -1) mapping_type = MappingType.ASYMMETRIC # weight 
quantization @@ -88,7 +87,7 @@ def forward(self, x): int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) # calibrate - insert_awq_observer(m, device) + insert_awq_observer(m, dtype, device) for example in calibration_data: m(example.to(device)) # print('calibrated') From 7614d512c223ca66143b042dd652bfd6d88e9318 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Fri, 23 Aug 2024 22:34:45 -0400 Subject: [PATCH 06/69] convert list to tensor --- torchao/prototype/awq/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index cf83835e1e..f068c6b27a 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -53,7 +53,7 @@ def __init__(self, ) self.weight = weight self.scale_options = scale_search_space_size - self.losses = [0] * self.scale_options + self.losses = torch.zeros(self.scale_options, dtype= input_dtype) self.average = torch.zeros(weight.shape[-1], dtype=torch.float32).to(device) self.counter = 0 From 33a28dd8964f2b18dd0ed42073d3810e72de554e Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 23 Aug 2024 22:57:34 -0400 Subject: [PATCH 07/69] restructuring --- scripts/{hfconfig.py => create_weight_map.py} | 4 ++ test/prototype/test_awq.py | 57 +++++++++++++++++++ torchao/_models/_eval.py | 32 +++++++++++ torchao/_models/llama/eval.py | 2 +- torchao/prototype/awq/{test.py => api.py} | 55 +----------------- 5 files changed, 95 insertions(+), 55 deletions(-) rename scripts/{hfconfig.py => create_weight_map.py} (84%) create mode 100644 test/prototype/test_awq.py rename torchao/prototype/awq/{test.py => api.py} (53%) diff --git a/scripts/hfconfig.py b/scripts/create_weight_map.py similarity index 84% rename from scripts/hfconfig.py rename to scripts/create_weight_map.py index 148e2d61c5..79d611f738 100644 --- a/scripts/hfconfig.py +++ b/scripts/create_weight_map.py @@ -1,3 +1,7 @@ +""" +This file produces a file named pytorch_model.bin.index.json based on the downloaded model weights. +It was primarily used to create run evals on llama2.c-stories15M model. 
+""" import json import torch from transformers import AutoModel diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py new file mode 100644 index 0000000000..3c0d054ef9 --- /dev/null +++ b/test/prototype/test_awq.py @@ -0,0 +1,57 @@ +from copy import deepcopy +import torch +import torch.nn.functional as F +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization import quantize_, int8_weight_only +from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=512, n=256, k=128): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) + + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + +device = ("cpu") +torch.manual_seed(34) +dataset_size = 1000 +dtype = torch.bfloat16 +l1,l2,l3 = 512,256,128 +m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) +m_bf16 = deepcopy(m) + +dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) +calibration_data = dataset[:100] +bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + + +m_int8wo = deepcopy(m) +quantize_(m_int8wo, int8_weight_only()) +int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) + +# calibrate +insert_awq_observer(m, dtype, device) +for example in calibration_data: + m(example.to(device)) +# print('calibrated') + +# quantize +is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) +quantize_(m, awq_quant, is_observed_linear) +awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + +# compare accuracy +awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size +int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size +print(f"AWQ error: {awq_err}") +print(f"Int8WO error: {int8wo_err}") \ No newline at end of file diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index c58a1869de..a4d4342fff 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -145,6 +145,38 @@ def record_inputs( def get_inputs(self): return self.inputs + # TODO: dispatch based on quant method + def _GPTQ_model_call(self, inps): + inps = inps.squeeze(0) + T = len(inps) + if ( + # can't use inputs that are too short when padding disabled + (T < self.calibration_seq_length and not self.pad_calibration_inputs) + or + # can't use inputs that actually use token we use for padding + (self.pad_calibration_inputs and self.pad_token in inps) + ): + # give random output + return torch.randn( + (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device + ) + + # pad or truncate to the right size + if T >= self.calibration_seq_length: + inps = inps[: self.calibration_seq_length] + else: + inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) + + inps = inps.unsqueeze(0) + model_in = self.input_prep_func(inps) + + self.add_input(model_in) + + # output `something` with correct shape to keep eval going + return torch.randn( + (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device + ) + def _model_call(self, inps): input = self.input_prep_func(inps.to(self._device)) diff --git 
a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index ea81a8eccf..02c161e8f6 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -91,7 +91,7 @@ def run_evaluation( model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) elif "awq" in quantization: - from torchao.prototype.awq.test import ObservedLinear, insert_awq_observer, awq_quant + from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant insert_awq_observer(model, precision, device) InputRecorder( tokenizer, diff --git a/torchao/prototype/awq/test.py b/torchao/prototype/awq/api.py similarity index 53% rename from torchao/prototype/awq/test.py rename to torchao/prototype/awq/api.py index 2113e95773..86989dff76 100644 --- a/torchao/prototype/awq/test.py +++ b/torchao/prototype/awq/api.py @@ -1,9 +1,7 @@ -from copy import deepcopy import torch import torch.nn.functional as F from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.prototype.awq.core import AWQ_AQTLayout, AWQLayoutType, AWQObserver -from torchao.quantization import quantize_, int8_weight_only from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -43,7 +41,7 @@ def awq_quant(observed_linear, target_dtype=torch.int8): equalization_scale = observed_linear.act_obs.calculate_qparams() layout_type = AWQLayoutType(equalization_scale) def weight_quant_func(weight): - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type) + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type, zero_point_domain = ZeroPointDomain.INT) linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) linear.weight = observed_linear.weight @@ -52,54 +50,3 @@ def weight_quant_func(weight): return linear -class ToyLinearModel(torch.nn.Module): - def __init__(self, m=512, n=256, k=128): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) - - def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): - return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - return x - -if __name__ == "__main__": - for i in range(10): - device = ("cpu") - torch.manual_seed(i) - dataset_size = 1000 - dtype = torch.bfloat16 - l1,l2,l3 = 512,256,128 - m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) - m_bf16 = deepcopy(m) - - dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) - calibration_data = dataset[:100] - bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - - - m_int8wo = deepcopy(m) - quantize_(m_int8wo, int8_weight_only()) - int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) - - # calibrate - insert_awq_observer(m, dtype, device) - for example in calibration_data: - m(example.to(device)) - # print('calibrated') - - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(m, awq_quant, is_observed_linear) - awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - - # compare accuracy - awq_err = 
torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size - int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size - print(f"AWQ error: {awq_err}") - print(f"Int8WO error: {int8wo_err}") \ No newline at end of file From 7d389b5a9cfd7c13e37c9ad54078e034c3f44c3f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 23 Aug 2024 23:08:38 -0400 Subject: [PATCH 08/69] revert unecessary hf_eval changes --- scripts/hf_eval.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index ae627ddc2e..d6081acb0f 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -62,14 +62,14 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars elif quantization == "autoquant": model = autoquant(model.to(device=device)) elif quantization == "awq": - from torchao.prototype.awq.test import ObservedLinear, insert_awq_observer, awq_quant - insert_awq_observer(model, device) from datasets import load_dataset + from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant + + insert_awq_observer(model, precision, device) wikitext103 = load_dataset("wikitext", "wikitext-103-v1") wikitext103_train = wikitext103["train"] wikitext103_calibration = wikitext103_train.select(range(100)) calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] - print(len(calibration_input_ids)) model.to(device) print("running awq calibration") for i, ids in enumerate(calibration_input_ids): @@ -82,7 +82,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars quantize_(model, awq_quant, is_observed_linear) if quantization != "autoquant" and compile: - model = torch.compile(model, fullgraph=True) + model = torch.compile(model, mode= "max-autotune", fullgraph=True) if sparsity == "semi_sparse": def all_linear(mod, name): @@ -125,7 +125,7 @@ def all_linear(mod, name): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Run HF Model Evaluation') - parser.add_argument('--repo_id', type=str, default="meta-llama/Llama-2-7b-hf", help='Repository ID to download from HF.') + parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.') parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') From 5913b989d2a8b9f02833c68a7d924d22cdf149fd Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Sun, 25 Aug 2024 15:46:09 -0400 Subject: [PATCH 09/69] added wikitext eval test --- scripts/hf_eval.py | 7 +- test/prototype/test_awq.py | 173 ++++++++++++++++++++++++++----------- 2 files changed, 127 insertions(+), 53 deletions(-) diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index d6081acb0f..7b043762b1 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -47,7 +47,7 @@ def format_value(value): def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length): tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id, 
torch_dtype=precision).to(device="cpu") + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) if quantization == "autoquant" and compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) @@ -63,16 +63,17 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars model = autoquant(model.to(device=device)) elif quantization == "awq": from datasets import load_dataset + from tqdm import tqdm from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant insert_awq_observer(model, precision, device) wikitext103 = load_dataset("wikitext", "wikitext-103-v1") wikitext103_train = wikitext103["train"] - wikitext103_calibration = wikitext103_train.select(range(100)) + wikitext103_calibration = wikitext103_train.select(range(1)) calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] model.to(device) print("running awq calibration") - for i, ids in enumerate(calibration_input_ids): + for i, ids in tqdm(enumerate(calibration_input_ids)): if ids.shape[-1] == 0: continue model(ids.to(device)) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 3c0d054ef9..7c49925da7 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -5,53 +5,126 @@ from torchao.quantization import quantize_, int8_weight_only from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant -class ToyLinearModel(torch.nn.Module): - def __init__(self, m=512, n=256, k=128): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) - - def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): - return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - return x - - -device = ("cpu") -torch.manual_seed(34) -dataset_size = 1000 -dtype = torch.bfloat16 -l1,l2,l3 = 512,256,128 -m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) -m_bf16 = deepcopy(m) - -dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) -calibration_data = dataset[:100] -bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - - -m_int8wo = deepcopy(m) -quantize_(m_int8wo, int8_weight_only()) -int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) - -# calibrate -insert_awq_observer(m, dtype, device) -for example in calibration_data: - m(example.to(device)) -# print('calibrated') - -# quantize -is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) -quantize_(m, awq_quant, is_observed_linear) -awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - -# compare accuracy -awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size -int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size -print(f"AWQ error: {awq_err}") -print(f"Int8WO error: {int8wo_err}") \ No newline at end of file +def simple_test(): + class ToyLinearModel(torch.nn.Module): + def __init__(self, m=512, n=256, k=128): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) + + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): + 
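# A minimal sketch (not part of the patch series) of the calibrate-then-quantize flow that
# the hf_eval AWQ branch and the prototype test follow at this stage of the series:
# (1) wrap each nn.Linear in an ObservedLinear, (2) push calibration batches through the
# model so the observers record activation statistics, (3) quantize only observed modules.
# The helper name `calibrate_and_quantize` and the calibration batches are illustrative;
# the torchao names match the prototype API as it exists in these patches.
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant

def calibrate_and_quantize(model, calibration_batches, precision=torch.bfloat16, device="cuda"):
    # step 1: replace every nn.Linear with an observed variant
    insert_awq_observer(model, precision, device)
    # step 2: run calibration data; observers accumulate per-channel activation stats
    with torch.no_grad():
        for batch in calibration_batches:
            model(batch.to(device))
    # step 3: convert only the modules that were observed
    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
    quantize_(model, awq_quant, is_observed_linear)
    return model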
return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + + device = ("cpu") + torch.manual_seed(34) + dataset_size = 1000 + dtype = torch.bfloat16 + l1,l2,l3 = 512,256,128 + m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) + m_bf16 = deepcopy(m) + + dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) + calibration_data = dataset[:100] + bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + + + m_int8wo = deepcopy(m) + quantize_(m_int8wo, int8_weight_only()) + int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) + + # calibrate + insert_awq_observer(m, dtype, device) + print(m) + for example in calibration_data: + m(example.to(device)) + # print('calibrated') + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(m, awq_quant, is_observed_linear) + print(m) + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + m = torch.compile(m, fullgraph=True) + # compare accuracy + awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size + int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size + print(f"AWQ error: {awq_err}") + print(f"Int8WO error: {int8wo_err}") + + +from transformers import AutoModelForCausalLM, AutoTokenizer +from lm_eval.models.huggingface import HFLM +from lm_eval.evaluator import evaluate +from lm_eval.tasks import get_task_dict +from datasets import load_dataset +from tqdm import tqdm +import time + +def create_batches_generator(data, batch_size): + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): + print("Loading model ...") + t0 = time.time() + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + wikitext103 = load_dataset("wikitext", "wikitext-103-v1") + wikitext103_train = wikitext103["train"] + + if quant =="awq": + print("running awq calibration") + insert_awq_observer(model, precision, device) + print(model) + wikitext103_calibration = wikitext103_train.select(range(calibrate_size)) + calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] + + + for ids in tqdm(calibration_input_ids): + model(ids.to(device)) + + is_observed_linear = lambda m, fqn: isinstance(model, ObservedLinear) + quantize_(model, int8_weight_only(), is_observed_linear) + print(model) + + elif quant=="int8": + print("running int8 quantization") + quantize_(model, int8_weight_only()) + + if compile: + model = torch.compile(model) + + eval_data = wikitext103["train"].select(range(calibrate_size, min(calibrate_size+eval_size,len(wikitext103["train"])))) + total_loss = 0.0 + total_tokens = 0 + print("Evaluating...") + for example in tqdm(eval_data): + inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=max_length) + input_ids = inputs.input_ids.to(device) + + with torch.no_grad(): + outputs = model(input_ids, labels=input_ids) + + loss = 0 if torch.isnan(outputs.loss) else outputs.loss.item() + total_loss += loss * input_ids.size(1) + total_tokens += input_ids.size(1) + + ppl = 
torch.tensor(total_loss / total_tokens).exp().item() + print(f"Perplexity: {ppl:.5f}") + # int8 100,100: 5505.30371 + # awq int8 100,100: 5546.76807 + # bf16 100,100: 5546.76807 + +# wikitext_eval("Xenova/llama2.c-stories15M","awq", 1, 1000, compile=False) +simple_test() + From 7d045f9266fe9bc9a287e27f86f79191d6717cba Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Tue, 27 Aug 2024 09:23:23 -0400 Subject: [PATCH 10/69] added tinygemm integration --- test/prototype/test_awq.py | 81 ++++++++++------ torchao/dtypes/affine_quantized_tensor.py | 71 +++++++++++++- torchao/prototype/awq/api.py | 113 ++++++++++++++++++---- torchao/prototype/awq/core.py | 42 +++----- 4 files changed, 224 insertions(+), 83 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 7c49925da7..e53091c8d1 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.quantization import quantize_, int8_weight_only +from torchao.quantization import quantize_, int4_weight_only, int8_weight_only from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant def simple_test(): @@ -13,7 +13,7 @@ def __init__(self, m=512, n=256, k=128): self.linear2 = torch.nn.Linear(n, k, bias=False) self.linear3 = torch.nn.Linear(k, 1, bias=False) - def example_inputs(self, batch_size, sequence_length=10, dtype=torch.half, device="cpu"): + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] def forward(self, x): @@ -23,7 +23,7 @@ def forward(self, x): return x - device = ("cpu") + device = ("cuda") torch.manual_seed(34) dataset_size = 1000 dtype = torch.bfloat16 @@ -36,28 +36,30 @@ def forward(self, x): bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - m_int8wo = deepcopy(m) - quantize_(m_int8wo, int8_weight_only()) - int8wo_out = torch.cat([m_int8wo(i.squeeze(0)) for i in dataset]) + m_int4wo = deepcopy(m) + quantize_(m_int4wo, int8_weight_only()) + int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) # calibrate - insert_awq_observer(m, dtype, device) - print(m) + quant_dtype = "int4" + group_size = 128 + insert_awq_observer(m, quant_dtype, group_size, dtype, device) for example in calibration_data: m(example.to(device)) # print('calibrated') # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(m, awq_quant, is_observed_linear) - print(m) + scales = [] + quantize_(m, awq_quant(quant_dtype = quant_dtype, scale_list=scales, group_size = group_size), is_observed_linear) + print(scales) awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - m = torch.compile(m, fullgraph=True) + # m = torch.compile(m, fullgraph=True) # compare accuracy awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size - int8wo_err = torch.sum(torch.abs(int8wo_out - bf16_out)).sum().item() / dataset_size + int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size print(f"AWQ error: {awq_err}") - print(f"Int8WO error: {int8wo_err}") + print(f"Int4WO error: {int4wo_err}") from transformers import AutoModelForCausalLM, AutoTokenizer @@ -78,29 +80,46 @@ def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cu tokenizer = 
AutoTokenizer.from_pretrained(repo_id) model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") - + # print(model) wikitext103 = load_dataset("wikitext", "wikitext-103-v1") wikitext103_train = wikitext103["train"] - if quant =="awq": - print("running awq calibration") - insert_awq_observer(model, precision, device) - print(model) + if quant.startswith("awq"): + quant_dtype = quant.split("-")[1] + print(f"running {quant} calibration") + t0 = time.time() + quant_dtype = quant.split("-")[1] + group_size = 128 if quant_dtype == "int4" else -1 + insert_awq_observer(model, quant_dtype, group_size, precision, device) wikitext103_calibration = wikitext103_train.select(range(calibrate_size)) - calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] - - - for ids in tqdm(calibration_input_ids): - model(ids.to(device)) + calibration_input_ids = [tokenizer.encode(text, return_tensors="pt").to(device=device) for text in wikitext103_calibration["text"]] - is_observed_linear = lambda m, fqn: isinstance(model, ObservedLinear) - quantize_(model, int8_weight_only(), is_observed_linear) - print(model) + for example in tqdm(wikitext103_calibration): + inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=max_length) + input_ids = inputs.input_ids.to(device=device) + model(input_ids) + + print(f"time for calibration: {time.time() - t0:.02f} seconds") + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + t0 = time.time() + scales = [] + quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size, scale_list=scales), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + + # print("scale distributions:") + # for scale in scales: + # print(f"min: {scale.min().item():.02f}, max: {scale.max().item():.02f}, avg: {scale.mean().item():.02f}") + # print(model) elif quant=="int8": print("running int8 quantization") quantize_(model, int8_weight_only()) + # print(model) + elif quant=="int4": + print("running int4 quantization") + quantize_(model, int4_weight_only()) + # print(model) if compile: model = torch.compile(model) @@ -110,8 +129,7 @@ def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cu print("Evaluating...") for example in tqdm(eval_data): inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=max_length) - input_ids = inputs.input_ids.to(device) - + input_ids = inputs.input_ids.to(device=device) with torch.no_grad(): outputs = model(input_ids, labels=input_ids) @@ -121,10 +139,13 @@ def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cu ppl = torch.tensor(total_loss / total_tokens).exp().item() print(f"Perplexity: {ppl:.5f}") + return ppl # int8 100,100: 5505.30371 # awq int8 100,100: 5546.76807 # bf16 100,100: 5546.76807 -# wikitext_eval("Xenova/llama2.c-stories15M","awq", 1, 1000, compile=False) -simple_test() +awq = wikitext_eval("Xenova/llama2.c-stories15M","awq-int4", 100, 1000, compile=False) +int8 = wikitext_eval("Xenova/llama2.c-stories15M","awq-int8", 100, 1000, compile=False) +# print(f"wikitext perplexity on {10} sentences\nawq: {awq}\nint8wo: {int8}") +# simple_test() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index a337e994f5..217c0f2145 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ 
b/torchao/dtypes/affine_quantized_tensor.py @@ -26,6 +26,7 @@ PlainLayoutType, is_device, ) + from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass from torchao.utils import ( @@ -648,6 +649,66 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout_type(self) -> LayoutType: return self.layout_type + +@dataclass(frozen=True) +class AWQLayoutType(LayoutType): + equalization_scale: torch.Tensor + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + return input * self.equalization_scale + +@dataclass(frozen=True) +class AWQ_INT4_LayoutType(LayoutType): + equalization_scale: torch.Tensor + inner_k_tiles: int = 8 + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + input = input * self.equalization_scale + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + def extra_repr(self): + return f"inner_k_tiles={self.inner_k_tiles}, equilization_scale={equalization_scale}" + +@register_layout_cls(AWQ_INT4_LayoutType) +class AWQ_INT4_Layout(TensorCoreTiledAQTLayout): + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, AWQ_INT4_LayoutType) + if TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + else: + assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, layout_type) + +@register_layout_cls(AWQLayoutType) +class AWQ_AQTLayout(PlainAQTLayout): + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, AWQLayoutType) + return cls(int_data, scale, zero_point, layout_type) + ##################################################### # torch functional and aten operator implementation # ##################################################### @@ -783,7 +844,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): weight_tensor.dtype == torch.bfloat16 and len(weight_tensor.shape) == 2 and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) + (isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) or isinstance(weight_tensor.layout_type, AWQ_INT4_LayoutType)) ) @@ -793,7 +854,8 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " ) - + if isinstance(weight_tensor.layout_type, AWQ_INT4_LayoutType): + input_tensor /= weight_tensor.layout_tensor.layout_type.equalization_scale # TODO: check groupsize quantization # avoid 
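# A minimal sketch (not part of the patch series) of what AWQ_INT4_LayoutType.pre_process
# does before tinygemm packing: fold the per-input-channel equalization scale into the
# weight, then pad in_features to a multiple of 1024 and out_features to a multiple of 8,
# which torch.ops.aten._convert_weight_to_int4pack expects. The shapes in the example are
# made up for illustration.
import torch
from torchao.utils import find_multiple

def scale_and_pad_for_int4pack(weight: torch.Tensor, equalization_scale: torch.Tensor) -> torch.Tensor:
    weight = weight * equalization_scale            # fold the AWQ equalization scale into the weight
    out_features, in_features = weight.shape
    padded_in = find_multiple(in_features, 1024)    # tinygemm packing requirement
    padded_out = find_multiple(out_features, 8)
    return torch.nn.functional.pad(
        weight, (0, padded_in - in_features, 0, padded_out - out_features)
    )

# e.g. a (300, 700) weight becomes (304, 1024) before packing
w = scale_and_pad_for_int4pack(torch.randn(300, 700, dtype=torch.bfloat16),
                               torch.rand(700, dtype=torch.bfloat16) + 0.5)
assert w.shape == (304, 1024)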
circular dep, TODO: move this to a common util.py act_mat = input_tensor @@ -860,12 +922,11 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y def _linear_awq_check(input_tensor, weight_tensor, bias): - from torchao.prototype.awq.core import AWQ_AQTLayout return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) def _linear_awq_impl(input_tensor, weight_tensor, bias): - # print('awq inp, scales: ',input_tensor.shape, weight_tensor.layout_tensor.layout_type.equalization_scale.shape) - return torch.nn.functional.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) + return _linear_fp_act_int8_weight_impl(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor, bias) + def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 86989dff76..0ee2b59f66 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,7 +1,8 @@ import torch import torch.nn.functional as F from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.prototype.awq.core import AWQ_AQTLayout, AWQLayoutType, AWQObserver +from torchao.dtypes.affine_quantized_tensor import AWQ_INT4_LayoutType, AWQLayoutType +from torchao.prototype.awq.core import AWQObserver from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -23,30 +24,102 @@ def from_float(cls, float_linear, act_obs): observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias return observed_linear - -def insert_awq_observer(model, input_dtype, device): + + +def insert_awq_observer(model, quant_dtype, group_size, input_dtype, device): + assert quant_dtype in ["int4", "int8"] _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) + if quant_dtype == "int4": + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + + elif quant_dtype == "int8": + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + block_size = (1, -1) + quant_min = None + quant_max = None + def replace_with_observer(layer): - observer = AWQObserver(layer.weight, input_dtype, MappingType.ASYMMETRIC, torch.int8, device) + observer = AWQObserver( + layer.weight, + block_size, + input_dtype, + mapping_type, + target_dtype, + device, + preserve_zero = preserve_zero, + zero_point_domain = zero_point_domain, + zero_point_dtype = zero_point_dtype, + quant_min=quant_min, + quant_max = quant_max, + eps = eps) return ObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) -# converting observed linear module to linear module with quantzied weights -# with tensor subclasses -def awq_quant(observed_linear, target_dtype=torch.int8): - assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" - block_size = (1, -1) - mapping_type = MappingType.ASYMMETRIC - # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams() - layout_type = 
AWQLayoutType(equalization_scale) - def weight_quant_func(weight): - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, layout_type = layout_type, zero_point_domain = ZeroPointDomain.INT) +# variant of _get_linear_subclass_inserter that works with observed linear class +def _observed_linear_subclass_inserter(constructor): + def insert_subclass(observed_linear): + linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) + linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) + return linear + + return insert_subclass + +def awq_quant(quant_dtype = "int4", group_size = 128, scale_list =[]): + + def weight_quant_func(observed_linear): + assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" + assert quant_dtype in ["int4", "int8"] + mapping_type = MappingType.ASYMMETRIC + # weight quantization + equalization_scale = observed_linear.act_obs.calculate_qparams() + if quant_dtype == "int4": + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + layout_type = AWQ_INT4_LayoutType(equalization_scale) + + elif quant_dtype == "int8": + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + block_size = (1, observed_linear.weight.shape[1]) + layout_type = AWQLayoutType(equalization_scale) + quant_min = None + quant_max = None + + scale_list.append(equalization_scale) + return to_affine_quantized( + observed_linear.weight, + mapping_type, block_size, + target_dtype, quant_min, + quant_max, eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + layout_type=layout_type) - linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias - linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False) - return linear + return _observed_linear_subclass_inserter(weight_quant_func) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index f068c6b27a..998b8d52b7 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -7,16 +7,21 @@ LayoutType, ) from torchao.dtypes.affine_quantized_tensor import ( - PlainAQTLayout, + PlainAQTLayout, + TensorCoreTiledAQTLayout, register_layout_cls, - to_affine_quantized + to_affine_quantized, + AWQ_INT4_LayoutType, + AWQLayoutType ) +from torchao.utils import find_multiple from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) +from torchao.quantization.quant_api import int4_weight_only from torchao.quantization.observer import ( PerAxis, AffineQuantizedObserverBase, @@ -25,6 +30,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, + block_size: Tuple, input_dtype: torch.dtype, mapping_type: MappingType, target_dtype: torch.dtype, @@ -38,11 +44,10 @@ def __init__(self, preserve_zero: Optional[bool] = True, zero_point_domain = ZeroPointDomain.INT, ): - self.block_size = (1, -1) 
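# A minimal sketch (not part of the patch series) of why the equalization scale can be
# folded into the weight: the layout's pre_process multiplies weight columns by a
# per-input-channel scale s and the linear dispatch divides the activation by the same s.
# In full precision the two cancel exactly; the benefit appears once (W * s) is quantized,
# because channels with large activations then carry proportionally less quantization error.
import torch
import torch.nn.functional as F

x = torch.randn(4, 64)
w = torch.randn(32, 64)
s = torch.rand(64) + 0.5           # per-input-channel equalization scale, kept positive
y_ref = F.linear(x, w)
y_eq = F.linear(x / s, w * s)      # scale folded into the weight, divided out of the input
assert torch.allclose(y_ref, y_eq, atol=1e-4)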
super().__init__( mapping_type, target_dtype, - block_size = self.block_size, + block_size = block_size, quant_min = quant_min, quant_max = quant_max, eps = eps, @@ -54,8 +59,9 @@ def __init__(self, self.weight = weight self.scale_options = scale_search_space_size self.losses = torch.zeros(self.scale_options, dtype= input_dtype) - self.average = torch.zeros(weight.shape[-1], dtype=torch.float32).to(device) + self.average = torch.zeros(self.weight.shape[-1], dtype=input_dtype, device=device) self.counter = 0 + self.device = device def forward(self, input: torch.Tensor): self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=1).squeeze(0) / (self.counter + input.shape[0]) @@ -65,8 +71,9 @@ def forward(self, input: torch.Tensor): ratio = i *1.0 / self.scale_options scales = self.average.pow(ratio).clamp(min=1e-4) scales = scales / (scales.max() * scales.min()).sqrt() + layout = AWQLayoutType(scales) if self.zero_point_domain == ZeroPointDomain.INT else AWQ_INT4_LayoutType(scales) quantized_weight = to_affine_quantized( - self.weight.data * scales, + self.weight.data, self.mapping_type, self.block_size, self.target_dtype, @@ -77,7 +84,7 @@ def forward(self, input: torch.Tensor): zero_point_dtype = self.zero_point_dtype, preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, - layout_type = AWQLayoutType(scales) + layout_type = layout ) scaled_activation = (input / scales) out = F.linear(scaled_activation, quantized_weight) @@ -88,24 +95,3 @@ def calculate_qparams(self): scales = self.average.pow(ratio).clamp(min=1e-4) scales = scales / (scales.max() * scales.min()).sqrt() return scales.detach() - -@dataclass(frozen=True) -class AWQLayoutType(LayoutType): - equalization_scale: torch.Tensor - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input * self.equalization_scale - -@register_layout_cls(AWQLayoutType) -class AWQ_AQTLayout(PlainAQTLayout): - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - assert isinstance(layout_type, AWQLayoutType) - return cls(int_data, scale, zero_point, layout_type) - \ No newline at end of file From db302ef4b9fc6d47e9b48fcedaa43d1ffc4d1924 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 29 Aug 2024 15:42:59 -0400 Subject: [PATCH 11/69] made the calibration step much faster --- test/prototype/test_awq.py | 145 +++++++++++++++------- torchao/dtypes/affine_quantized_tensor.py | 9 +- torchao/prototype/awq/api.py | 88 +++++++++++-- torchao/prototype/awq/core.py | 82 +++++++++--- 4 files changed, 243 insertions(+), 81 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index e53091c8d1..0fcabf1b34 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -4,6 +4,13 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization import quantize_, int4_weight_only, int8_weight_only from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant +import argparse +import os +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from tqdm import tqdm +import time + def simple_test(): class ToyLinearModel(torch.nn.Module): @@ -24,7 +31,6 @@ def forward(self, x): device = ("cuda") - torch.manual_seed(34) dataset_size = 1000 dtype = torch.bfloat16 l1,l2,l3 = 512,256,128 @@ -62,27 +68,45 @@ def 
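# A minimal sketch (not part of the patch series) of the scale search the observer runs:
# keep a per-input-channel mean of |activation|, grid-search an exponent `ratio`,
# normalize the candidate scales around 1, bake each candidate into a quantized copy of
# the weight, and keep the scales whose output is closest to the float output. The int8
# symmetric quantizer below is a simplified stand-in for the to_affine_quantized call
# used in the patch.
import torch
import torch.nn.functional as F

def fake_quantize_int8(w: torch.Tensor) -> torch.Tensor:
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(w / scale).clamp(-128, 127) * scale

def search_equalization_scale(x: torch.Tensor, w: torch.Tensor, n_grid: int = 20) -> torch.Tensor:
    mean_abs = x.abs().view(-1, x.shape[-1]).mean(dim=0)       # per-input-channel statistic
    float_out = F.linear(x, w)
    best_loss, best_scales = float("inf"), None
    for i in range(n_grid):
        ratio = i / n_grid
        scales = mean_abs.pow(ratio).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # normalize around 1
        q_out = F.linear(x / scales, fake_quantize_int8(w * scales))
        loss = (float_out - q_out).pow(2).mean().item()
        if loss < best_loss:
            best_loss, best_scales = loss, scales
    return best_scales

scales = search_equalization_scale(torch.randn(8, 16, 256), torch.randn(512, 256))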
forward(self, x): print(f"Int4WO error: {int4wo_err}") -from transformers import AutoModelForCausalLM, AutoTokenizer -from lm_eval.models.huggingface import HFLM -from lm_eval.evaluator import evaluate -from lm_eval.tasks import get_task_dict -from datasets import load_dataset -from tqdm import tqdm -import time def create_batches_generator(data, batch_size): for i in range(0, len(data), batch_size): yield data[i:i + batch_size] -def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): +def get_calib_dataset(tokenizer=None, n_samples=512, block_size=512, device="cuda"): + dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + # dataset = dataset.shuffle(seed=42) + samples = [] + n_run = 0 + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run += 1 + if n_run == n_samples: + break + # now concatenate all samples and split according to block size + cat_samples = torch.cat(samples, dim=1) + n_split = cat_samples.shape[1] // block_size + print(f" * Split into {n_split} blocks") + return torch.cat([ + cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) + ], dim=0) + +def pile_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): print("Loading model ...") + torch.manual_seed(34) t0 = time.time() tokenizer = AutoTokenizer.from_pretrained(repo_id) model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") # print(model) - wikitext103 = load_dataset("wikitext", "wikitext-103-v1") - wikitext103_train = wikitext103["train"] if quant.startswith("awq"): quant_dtype = quant.split("-")[1] @@ -91,25 +115,16 @@ def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cu quant_dtype = quant.split("-")[1] group_size = 128 if quant_dtype == "int4" else -1 insert_awq_observer(model, quant_dtype, group_size, precision, device) - wikitext103_calibration = wikitext103_train.select(range(calibrate_size)) - calibration_input_ids = [tokenizer.encode(text, return_tensors="pt").to(device=device) for text in wikitext103_calibration["text"]] + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) - for example in tqdm(wikitext103_calibration): - inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=max_length) - input_ids = inputs.input_ids.to(device=device) - model(input_ids) + model(calibration_data.to(device)) print(f"time for calibration: {time.time() - t0:.02f} seconds") is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() - scales = [] - quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size, scale_list=scales), is_observed_linear) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - - # print("scale distributions:") - # for scale in scales: - # print(f"min: {scale.min().item():.02f}, max: {scale.max().item():.02f}, avg: {scale.mean().item():.02f}") # print(model) + quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") elif quant=="int8": print("running int8 
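# A minimal sketch (not part of the patch series) of how get_calib_dataset shapes raw text
# into calibration blocks: tokenize individual lines, drop empty or over-long ones,
# concatenate everything into one long id sequence, and slice it into equal block_size
# rows so the whole calibration set can be run as a single batch. `tokenizer` is assumed
# to be any HuggingFace-style tokenizer with an encode() method.
import torch

def build_calibration_blocks(texts, tokenizer, n_samples=512, block_size=512):
    pieces = []
    for text in texts:
        ids = tokenizer.encode(text.strip())
        if 0 < len(ids) <= block_size:          # skip empty and over-long lines
            pieces.append(torch.tensor([ids]))
        if len(pieces) == n_samples:
            break
    cat = torch.cat(pieces, dim=1)              # one long (1, total_len) sequence
    n_blocks = cat.shape[1] // block_size
    return cat[:, : n_blocks * block_size].view(n_blocks, block_size)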
quantization") @@ -123,29 +138,63 @@ def wikitext_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cu if compile: model = torch.compile(model) - eval_data = wikitext103["train"].select(range(calibrate_size, min(calibrate_size+eval_size,len(wikitext103["train"])))) - total_loss = 0.0 - total_tokens = 0 - print("Evaluating...") - for example in tqdm(eval_data): - inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=max_length) - input_ids = inputs.input_ids.to(device=device) + testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") + testenc = testenc.input_ids.to(model.device) + nsamples = testenc.numel() // max_length + model = model.eval() + nlls = [] + for i in tqdm(range(nsamples), desc="evaluating..."): + batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( + model.device + ) with torch.no_grad(): - outputs = model(input_ids, labels=input_ids) - - loss = 0 if torch.isnan(outputs.loss) else outputs.loss.item() - total_loss += loss * input_ids.size(1) - total_tokens += input_ids.size(1) - - ppl = torch.tensor(total_loss / total_tokens).exp().item() - print(f"Perplexity: {ppl:.5f}") + lm_logits = model(batch).logits + shift_logits = lm_logits[:, :-1, :].contiguous().float() + shift_labels = testenc[ + :, (i * max_length) : ((i + 1) * max_length) + ][:, 1:] + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + neg_log_likelihood = loss.float() * max_length + nlls.append(neg_log_likelihood) + + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) + print(f"Perplexity: {ppl.item():.5f}") return ppl - # int8 100,100: 5505.30371 - # awq int8 100,100: 5546.76807 - # bf16 100,100: 5546.76807 - -awq = wikitext_eval("Xenova/llama2.c-stories15M","awq-int4", 100, 1000, compile=False) -int8 = wikitext_eval("Xenova/llama2.c-stories15M","awq-int8", 100, 1000, compile=False) -# print(f"wikitext perplexity on {10} sentences\nawq: {awq}\nint8wo: {int8}") -# simple_test() - + # int4: 29.75957 + # real-awq-int4: 28.9590 + # awq-int4: + + +parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + +# Positional arguments +parser.add_argument("repo", type=str, help="Repository ID of the model.") +parser.add_argument("quant", type=str, help="Quantization method or file path.") + +# Optional arguments with default values +parser.add_argument("--calibrate_size", type=int, default=100, help="Calibration size. Default is 100.") +parser.add_argument("--eval_size", type=int, default=1000, help="Evaluation size. Default is 1000.") +parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") +parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") +parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. 
Default is 2048.") +parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + +args = parser.parse_args() + +# Convert precision argument to torch dtype +precision_dtype = getattr(torch, args.precision, torch.bfloat16) + +pile_eval( + repo_id=args.repo, + quant=args.quant, + calibrate_size=args.calibrate_size, + eval_size=args.eval_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile +) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 217c0f2145..781aee1095 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -653,7 +653,7 @@ def get_layout_type(self) -> LayoutType: @dataclass(frozen=True) class AWQLayoutType(LayoutType): equalization_scale: torch.Tensor - + def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input * self.equalization_scale @@ -661,6 +661,7 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: class AWQ_INT4_LayoutType(LayoutType): equalization_scale: torch.Tensor inner_k_tiles: int = 8 + def pre_process(self, input: torch.Tensor) -> torch.Tensor: input = input * self.equalization_scale orig_out_features, orig_in_features = input.shape @@ -670,9 +671,11 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: input, (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) - return input + return input + def extra_repr(self): - return f"inner_k_tiles={self.inner_k_tiles}, equilization_scale={equalization_scale}" + return f"inner_k_tiles={self.inner_k_tiles}, equilization_scale={self.equalization_scale}" + @register_layout_cls(AWQ_INT4_LayoutType) class AWQ_INT4_Layout(TensorCoreTiledAQTLayout): diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 0ee2b59f66..78ec0833c0 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,13 +1,15 @@ import torch import torch.nn.functional as F -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.affine_quantized_tensor import AWQ_INT4_LayoutType, AWQLayoutType -from torchao.prototype.awq.core import AWQObserver +from torchao.prototype.awq.core import AWQObserver, _awq_quant from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) from torchao.dtypes import to_affine_quantized +from torchao.dtypes.uintx.Uintx import to_uintx +from typing import Optional, Tuple + class ObservedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): @@ -15,8 +17,9 @@ def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module self.act_obs = act_obs def forward(self, input: torch.Tensor): - self.act_obs(input) - return F.linear(input, self.weight, self.bias) + output = F.linear(input, self.weight, self.bias) + self.act_obs(input, output) + return output @classmethod def from_float(cls, float_linear, act_obs): @@ -25,7 +28,71 @@ def from_float(cls, float_linear, act_obs): observed_linear.bias = float_linear.bias return observed_linear +class AWQ_int4(torch.nn.Module): + def __init__( + self, + int_weight: torch.Tensor, + bias: Optional[torch.Tensor], + eq_scales: torch.Tensor, + original_shape: Tuple, + scales: torch.Tensor, + zeros: torch.Tensor, + qdtype, + device=None): + + super().__init__() + self.weight = to_uintx(int_weight.to(torch.uint8), 
qdtype) + self.bias = bias + self.scales = scales + self.zeros = zeros + self.eq_scales = eq_scales + self.original_shape = original_shape + + def forward(self, input: torch.Tensor): + dq = (self.weight.get_plain() - self.zeros) * self.scales + return torch.nn.functional.linear(input / self.eq_scales, dq.view(self.original_shape), self.bias) + +def _replace_with_custom_fn_if_matches_filter( + model, + replacement_fn, + filter_fn, + cur_fqn="", + device=None, +) -> None: + """ + Recursively replaces each child module in `model` with the result of `replacement_fn(child)` + if `filter_fn(child)` returns `True`. + Args: + model (torch.nn.Module): The model containing modules to be replaced. + replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules. + filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. + cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". + device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. + + Returns: + None + """ + # print(model) + if filter_fn(model, cur_fqn[:-1]): + if device is not None: + model.to(device=device) # move to device before quantization + # print("replacing ", model) + model = replacement_fn(model) + return model + else: + for name, child in model.named_children(): + if "attn" in name: + continue + new_child = _replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device + ) + if new_child is not child: + setattr(model, name, new_child) + if device is not None: + model.to(device=device) # move parent module to device + return model + def insert_awq_observer(model, quant_dtype, group_size, input_dtype, device): assert quant_dtype in ["int4", "int8"] _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) @@ -53,7 +120,8 @@ def insert_awq_observer(model, quant_dtype, group_size, input_dtype, device): def replace_with_observer(layer): observer = AWQObserver( - layer.weight, + layer.weight, + layer.bias, block_size, input_dtype, mapping_type, @@ -71,16 +139,16 @@ def replace_with_observer(layer): # variant of _get_linear_subclass_inserter that works with observed linear class def _observed_linear_subclass_inserter(constructor): def insert_subclass(observed_linear): - linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) + linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias!=None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) + linear.bias = observed_linear.bias return linear return insert_subclass -def awq_quant(quant_dtype = "int4", group_size = 128, scale_list =[]): +def awq_quant(quant_dtype = "int4", group_size = 128): def weight_quant_func(observed_linear): - assert observed_linear.act_obs.counter > 0, "Calibrate the observer first" assert quant_dtype in ["int4", "int8"] mapping_type = MappingType.ASYMMETRIC # weight quantization @@ -96,7 +164,6 @@ def weight_quant_func(observed_linear): zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT layout_type = AWQ_INT4_LayoutType(equalization_scale) - elif quant_dtype == "int8": mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -108,8 +175,7 @@ def 
weight_quant_func(observed_linear): layout_type = AWQLayoutType(equalization_scale) quant_min = None quant_max = None - - scale_list.append(equalization_scale) + return to_affine_quantized( observed_linear.weight, mapping_type, block_size, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 998b8d52b7..9534868c5a 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from dataclasses import dataclass from typing import Optional, Tuple +from tqdm import tqdm from copy import deepcopy from torchao.dtypes.utils import ( LayoutType, @@ -22,14 +23,48 @@ ZeroPointDomain, ) from torchao.quantization.quant_api import int4_weight_only +from torchao.dtypes.uintx.Uintx import to_uintx from torchao.quantization.observer import ( PerAxis, AffineQuantizedObserverBase, ) +import pdb +def _awq_quant(w, n_bit=8, q_group_size=-1, get_scale_zp=False): + # pdb.set_trace() + org_w_shape = w.shape + if q_group_size > 0: + assert org_w_shape[-1] % q_group_size == 0 + w = w.reshape(-1, q_group_size) + assert w.dim() == 2 + + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) + # perform dequant + if not get_scale_zp: + w = (w - zeros) * scales + w = w.reshape(org_w_shape) + assert torch.isnan(w).sum() == 0 + + + + if get_scale_zp: + return w, scales, zeros + else: + return w class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, + bias: torch.Tensor, block_size: Tuple, input_dtype: torch.dtype, mapping_type: MappingType, @@ -57,22 +92,26 @@ def __init__(self, zero_point_domain = zero_point_domain, ) self.weight = weight + self.bias = bias self.scale_options = scale_search_space_size - self.losses = torch.zeros(self.scale_options, dtype= input_dtype) - self.average = torch.zeros(self.weight.shape[-1], dtype=input_dtype, device=device) - self.counter = 0 + self.scales = None self.device = device - def forward(self, input: torch.Tensor): - self.average = self.average * self.counter / (self.counter + input.shape[0]) + input.abs().sum(dim=1).squeeze(0) / (self.counter + input.shape[0]) - self.counter += input.shape[0] + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + # print(input.shape) + average = input.abs().view(-1,input.shape[-1]).mean(0) + + best_loss = float('inf') + best_ratio = -1 + scaleopts = [] + ws = [] for i in range(self.scale_options): - unquantized_result = F.linear(input, self.weight) - ratio = i *1.0 / self.scale_options - scales = self.average.pow(ratio).clamp(min=1e-4) + ratio = i *1 / self.scale_options + scales = average.pow(ratio) scales = scales / (scales.max() * scales.min()).sqrt() layout = AWQLayoutType(scales) if self.zero_point_domain == ZeroPointDomain.INT else AWQ_INT4_LayoutType(scales) - quantized_weight = to_affine_quantized( + w = to_affine_quantized( self.weight.data, self.mapping_type, self.block_size, @@ -85,13 +124,18 @@ def forward(self, input: torch.Tensor): preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, layout_type = layout - ) - scaled_activation = (input / scales) - out = F.linear(scaled_activation, quantized_weight) - self.losses[i] += 
(unquantized_result - out).pow(2).mean().item() - + ) + # w = deepcopy(self.weight) * scales + # w = _awq_quant(w, q_group_size=128, n_bit=4) / scales + # ws.append(w.mean().item()) + q_out = F.linear(input/scales, w, self.bias) + scaleopts.append(q_out.mean().item()) + loss = (output - q_out).pow(2).mean().item() + if loss < best_loss: + self.scales = scales + best_ratio = ratio + best_loss = loss + # print(f"x: {input.mean().item(): .03f} w_: {torch.tensor(ws).sum().item()} o: {torch.tensor(scaleopts).sum().item(): .05f} ratio: {self.best_ratio}") + def calculate_qparams(self): - ratio = torch.argmin(self.losses) * 1.0 / self.scale_options - scales = self.average.pow(ratio).clamp(min=1e-4) - scales = scales / (scales.max() * scales.min()).sqrt() - return scales.detach() + return self.scales.detach() From 2ec38f1d81f631193fedae3f2068a5e0265969d4 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Tue, 3 Sep 2024 09:24:11 -0400 Subject: [PATCH 12/69] merge pt1 --- test/prototype/test_awq.py | 235 +++++----------------- torchao/dtypes/affine_quantized_tensor.py | 71 ------- torchao/prototype/awq/api.py | 136 +++---------- torchao/prototype/awq/core.py | 129 ++++++------ tutorials/calibration_flow/AWQ.py | 140 +++++++++++++ 5 files changed, 291 insertions(+), 420 deletions(-) create mode 100644 tutorials/calibration_flow/AWQ.py diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 0fcabf1b34..f7272c02cf 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -12,189 +12,58 @@ import time -def simple_test(): - class ToyLinearModel(torch.nn.Module): - def __init__(self, m=512, n=256, k=128): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) - def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): - return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - return x - - - device = ("cuda") - dataset_size = 1000 - dtype = torch.bfloat16 - l1,l2,l3 = 512,256,128 - m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) - m_bf16 = deepcopy(m) - - dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) - calibration_data = dataset[:100] - bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - - - m_int4wo = deepcopy(m) - quantize_(m_int4wo, int8_weight_only()) - int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) - - # calibrate - quant_dtype = "int4" - group_size = 128 - insert_awq_observer(m, quant_dtype, group_size, dtype, device) - for example in calibration_data: - m(example.to(device)) - # print('calibrated') - - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - scales = [] - quantize_(m, awq_quant(quant_dtype = quant_dtype, scale_list=scales, group_size = group_size), is_observed_linear) - print(scales) - awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - # m = torch.compile(m, fullgraph=True) - # compare accuracy - awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size - int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size - print(f"AWQ error: {awq_err}") - print(f"Int4WO error: {int4wo_err}") - - - -def create_batches_generator(data, batch_size): - for i in range(0, len(data), 
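# A minimal sketch (not part of the patch series) of the plain group-wise asymmetric
# quantization that the _awq_quant reference helper in core.py implements: each group of
# q_group_size weights gets its own (scale, zero) pair from the group min/max, and
# dequantization is (q - zero) * scale. A compact round-trip version:
import torch

def groupwise_quant_dequant(w: torch.Tensor, n_bit: int = 4, q_group_size: int = 128) -> torch.Tensor:
    orig_shape = w.shape
    assert orig_shape[-1] % q_group_size == 0
    w = w.reshape(-1, q_group_size)                     # one row per quantization group
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp(0, max_int)
    q = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    return ((q - zeros) * scales).reshape(orig_shape)   # dequantized weight

# e.g. a (512, 256) weight quantized to 4 bits with groups of 128 along the input dim
w = torch.randn(512, 256)
print((w - groupwise_quant_dequant(w)).abs().mean())    # small reconstruction error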
batch_size): - yield data[i:i + batch_size] - -def get_calib_dataset(tokenizer=None, n_samples=512, block_size=512, device="cuda"): - dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") - # dataset = dataset.shuffle(seed=42) - samples = [] - n_run = 0 - for data in dataset: - line = data["text"] - line = line.strip() - line_encoded = tokenizer.encode(line) - if len(line_encoded) > 512: - continue - sample = torch.tensor([line_encoded]) - if sample.numel() == 0: - continue - samples.append(sample) - n_run += 1 - if n_run == n_samples: - break - # now concatenate all samples and split according to block size - cat_samples = torch.cat(samples, dim=1) - n_split = cat_samples.shape[1] // block_size - print(f" * Split into {n_split} blocks") - return torch.cat([ - cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) - ], dim=0) - -def pile_eval(repo_id, quant, calibrate_size=100, eval_size=1000, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): - print("Loading model ...") - torch.manual_seed(34) - t0 = time.time() - tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) - print(f"Time to load model: {time.time() - t0:.02f} seconds") - # print(model) - - if quant.startswith("awq"): - quant_dtype = quant.split("-")[1] - print(f"running {quant} calibration") - t0 = time.time() - quant_dtype = quant.split("-")[1] - group_size = 128 if quant_dtype == "int4" else -1 - insert_awq_observer(model, quant_dtype, group_size, precision, device) - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) - - model(calibration_data.to(device)) - - print(f"time for calibration: {time.time() - t0:.02f} seconds") - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - t0 = time.time() - # print(model) - quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - - elif quant=="int8": - print("running int8 quantization") - quantize_(model, int8_weight_only()) - # print(model) - - elif quant=="int4": - print("running int4 quantization") - quantize_(model, int4_weight_only()) - # print(model) - if compile: - model = torch.compile(model) - - testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") - testenc = testenc.input_ids.to(model.device) - nsamples = testenc.numel() // max_length - model = model.eval() - nlls = [] - for i in tqdm(range(nsamples), desc="evaluating..."): - batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( - model.device - ) - with torch.no_grad(): - lm_logits = model(batch).logits - shift_logits = lm_logits[:, :-1, :].contiguous().float() - shift_labels = testenc[ - :, (i * max_length) : ((i + 1) * max_length) - ][:, 1:] - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - neg_log_likelihood = loss.float() * max_length - nlls.append(neg_log_likelihood) - - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) - print(f"Perplexity: {ppl.item():.5f}") - return ppl - # int4: 29.75957 - # real-awq-int4: 28.9590 - # awq-int4: - - -parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=512, n=256, k=128): + 
super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) + + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x -# Positional arguments -parser.add_argument("repo", type=str, help="Repository ID of the model.") -parser.add_argument("quant", type=str, help="Quantization method or file path.") - -# Optional arguments with default values -parser.add_argument("--calibrate_size", type=int, default=100, help="Calibration size. Default is 100.") -parser.add_argument("--eval_size", type=int, default=1000, help="Evaluation size. Default is 1000.") -parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") -parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") -parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. Default is 2048.") -parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - -args = parser.parse_args() - -# Convert precision argument to torch dtype -precision_dtype = getattr(torch, args.precision, torch.bfloat16) -pile_eval( - repo_id=args.repo, - quant=args.quant, - calibrate_size=args.calibrate_size, - eval_size=args.eval_size, - device=args.device, - precision=precision_dtype, - max_length=args.max_length, - compile=args.compile -) +device = ("cuda") +dataset_size = 1000 +dtype = torch.bfloat16 +l1,l2,l3 = 512,256,128 + +m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) +m_bf16 = deepcopy(m) + +dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) +calibration_data = dataset[:100] +bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + + +m_int4wo = deepcopy(m) +quantize_(m_int4wo, int8_weight_only()) +int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) + +# calibrate +quant_dtype = torch.uint4 +group_size = 128 +insert_awq_observer(m, quant_dtype, group_size, dtype, device) +for example in calibration_data: + m(example.to(device)) +# print('calibrated') + +# quantize +is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) +scales = [] +quantize_(m, awq_quant(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) +print(scales) +awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) +# m = torch.compile(m, fullgraph=True) +# compare accuracy +awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size +int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size +print(f"AWQ error: {awq_err}") +print(f"Int4WO error: {int4wo_err}") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 781aee1095..0ad6ebf8b6 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -650,68 +650,6 @@ def get_layout_type(self) -> LayoutType: return self.layout_type -@dataclass(frozen=True) -class AWQLayoutType(LayoutType): - equalization_scale: torch.Tensor - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input * self.equalization_scale - -@dataclass(frozen=True) -class 
AWQ_INT4_LayoutType(LayoutType): - equalization_scale: torch.Tensor - inner_k_tiles: int = 8 - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - input = input * self.equalization_scale - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def extra_repr(self): - return f"inner_k_tiles={self.inner_k_tiles}, equilization_scale={self.equalization_scale}" - - -@register_layout_cls(AWQ_INT4_LayoutType) -class AWQ_INT4_Layout(TensorCoreTiledAQTLayout): - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - assert isinstance(layout_type, AWQ_INT4_LayoutType) - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) - return cls(packed_weight, scale_and_zero, False, layout_type) - -@register_layout_cls(AWQLayoutType) -class AWQ_AQTLayout(PlainAQTLayout): - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - assert isinstance(layout_type, AWQLayoutType) - return cls(int_data, scale, zero_point, layout_type) - ##################################################### # torch functional and aten operator implementation # ##################################################### @@ -857,8 +795,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " ) - if isinstance(weight_tensor.layout_type, AWQ_INT4_LayoutType): - input_tensor /= weight_tensor.layout_tensor.layout_type.equalization_scale # TODO: check groupsize quantization # avoid circular dep, TODO: move this to a common util.py act_mat = input_tensor @@ -924,12 +860,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y -def _linear_awq_check(input_tensor, weight_tensor, bias): - return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) - -def _linear_awq_impl(input_tensor, weight_tensor, bias): - return _linear_fp_act_int8_weight_impl(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor, bias) - def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), @@ -937,7 +867,6 @@ def _register_quantized_linear_dispatches(): (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_awq_check, _linear_awq_impl), ]: 
_register_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 78ec0833c0..70752a8acc 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,113 +1,34 @@ import torch import torch.nn.functional as F -from torchao.dtypes.affine_quantized_tensor import AWQ_INT4_LayoutType, AWQLayoutType -from torchao.prototype.awq.core import AWQObserver, _awq_quant +from torchao.dtypes.affine_quantized_tensor import AWQLayoutType +from torchao.prototype.awq.core import AWQObserver, ObservedLinear from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized from torchao.dtypes.uintx.Uintx import to_uintx from typing import Optional, Tuple -class ObservedLinear(torch.nn.Linear): - def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): - super().__init__(in_features, out_features, bias, device, dtype) - self.act_obs = act_obs - def forward(self, input: torch.Tensor): - output = F.linear(input, self.weight, self.bias) - self.act_obs(input, output) - return output - - @classmethod - def from_float(cls, float_linear, act_obs): - observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) - observed_linear.weight = float_linear.weight - observed_linear.bias = float_linear.bias - return observed_linear - -class AWQ_int4(torch.nn.Module): - def __init__( - self, - int_weight: torch.Tensor, - bias: Optional[torch.Tensor], - eq_scales: torch.Tensor, - original_shape: Tuple, - scales: torch.Tensor, - zeros: torch.Tensor, - qdtype, - device=None): - - super().__init__() - self.weight = to_uintx(int_weight.to(torch.uint8), qdtype) - self.bias = bias - self.scales = scales - self.zeros = zeros - self.eq_scales = eq_scales - self.original_shape = original_shape - - def forward(self, input: torch.Tensor): - dq = (self.weight.get_plain() - self.zeros) * self.scales - return torch.nn.functional.linear(input / self.eq_scales, dq.view(self.original_shape), self.bias) - -def _replace_with_custom_fn_if_matches_filter( - model, - replacement_fn, - filter_fn, - cur_fqn="", - device=None, -) -> None: - """ - Recursively replaces each child module in `model` with the result of `replacement_fn(child)` - if `filter_fn(child)` returns `True`. - - Args: - model (torch.nn.Module): The model containing modules to be replaced. - replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules. - filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. - cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". - device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. 
- - Returns: - None - """ - # print(model) - if filter_fn(model, cur_fqn[:-1]): - if device is not None: - model.to(device=device) # move to device before quantization - # print("replacing ", model) - model = replacement_fn(model) - return model - else: - for name, child in model.named_children(): - if "attn" in name: - continue - new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device - ) - if new_child is not child: - setattr(model, name, new_child) - if device is not None: - model.to(device=device) # move parent module to device - return model -def insert_awq_observer(model, quant_dtype, group_size, input_dtype, device): - assert quant_dtype in ["int4", "int8"] +def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_size: int, input_dtype: torch.dtype, device: torch.device): _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - if quant_dtype == "int4": + if quant_dtype == torch.uint4: mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) - target_dtype = torch.int32 + target_dtype = torch.uint4 quant_min = 0 quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + print("##########################\ninsert-uint4\n##########################\n") - elif quant_dtype == "int8": + elif quant_dtype == torch.int8: mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -117,6 +38,9 @@ def insert_awq_observer(model, quant_dtype, group_size, input_dtype, device): block_size = (1, -1) quant_min = None quant_max = None + print("##########################\ninsert-int8\n##########################\n") + else: + raise NotImplementedError(f"{quant_dtype} not supported. 
Use either torch.uint4 or torch.int8") def replace_with_observer(layer): observer = AWQObserver( @@ -136,7 +60,6 @@ def replace_with_observer(layer): return ObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) -# variant of _get_linear_subclass_inserter that works with observed linear class def _observed_linear_subclass_inserter(constructor): def insert_subclass(observed_linear): linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias!=None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) @@ -146,35 +69,38 @@ def insert_subclass(observed_linear): return insert_subclass -def awq_quant(quant_dtype = "int4", group_size = 128): +def awq_quant(quant_dtype = torch.uint4, group_size = 128): def weight_quant_func(observed_linear): - assert quant_dtype in ["int4", "int8"] - mapping_type = MappingType.ASYMMETRIC # weight quantization equalization_scale = observed_linear.act_obs.calculate_qparams() - if quant_dtype == "int4": + if quant_dtype == torch.uint4: mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) - target_dtype = torch.int32 + target_dtype = torch.uint8 quant_min = 0 quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - layout_type = AWQ_INT4_LayoutType(equalization_scale) - elif quant_dtype == "int8": + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + layout_type = AWQLayoutType(equalization_scale, quant_dtype) + + elif quant_dtype == torch.int8: mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT preserve_zero = True - block_size = (1, observed_linear.weight.shape[1]) - layout_type = AWQLayoutType(equalization_scale) + block_size = (1, -1) quant_min = None quant_max = None + layout_type = AWQLayoutType(equalization_scale, quant_dtype) + + else: + print(quant_dtype) + raise("AWQ supports only uint4 and int8 quantization for now") return to_affine_quantized( observed_linear.weight, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 9534868c5a..d8555fe239 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -1,65 +1,25 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + import torch import torch.nn.functional as F -from dataclasses import dataclass -from typing import Optional, Tuple -from tqdm import tqdm -from copy import deepcopy -from torchao.dtypes.utils import ( - LayoutType, -) + +from torchao.dtypes.uintx.Uintx import to_uintx from torchao.dtypes.affine_quantized_tensor import ( - PlainAQTLayout, - TensorCoreTiledAQTLayout, - register_layout_cls, to_affine_quantized, - AWQ_INT4_LayoutType, - AWQLayoutType - -) -from torchao.utils import find_multiple + LayoutType, + register_layout_cls, + PlainAQTLayout +) from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) -from torchao.quantization.quant_api import int4_weight_only -from torchao.dtypes.uintx.Uintx import to_uintx from torchao.quantization.observer import ( - PerAxis, AffineQuantizedObserverBase, ) -import pdb -def _awq_quant(w, n_bit=8, q_group_size=-1, get_scale_zp=False): - # pdb.set_trace() - org_w_shape = w.shape - if q_group_size > 0: - assert org_w_shape[-1] % 
q_group_size == 0 - w = w.reshape(-1, q_group_size) - assert w.dim() == 2 - - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**n_bit - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - # perform dequant - if not get_scale_zp: - w = (w - zeros) * scales - w = w.reshape(org_w_shape) - assert torch.isnan(w).sum() == 0 - - - if get_scale_zp: - return w, scales, zeros - else: - return w class AWQObserver(AffineQuantizedObserverBase): def __init__(self, @@ -99,23 +59,21 @@ def __init__(self, @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): - # print(input.shape) average = input.abs().view(-1,input.shape[-1]).mean(0) best_loss = float('inf') - best_ratio = -1 scaleopts = [] - ws = [] for i in range(self.scale_options): - ratio = i *1 / self.scale_options + ratio = i * 1 / self.scale_options scales = average.pow(ratio) scales = scales / (scales.max() * scales.min()).sqrt() - layout = AWQLayoutType(scales) if self.zero_point_domain == ZeroPointDomain.INT else AWQ_INT4_LayoutType(scales) + layout = AWQLayoutType(scales, self.target_dtype) + tensor_dtype = torch.int8 if self.target_dtype == torch.int8 else torch.uint8 w = to_affine_quantized( self.weight.data, self.mapping_type, self.block_size, - self.target_dtype, + tensor_dtype, quant_min = self.quant_min, quant_max = self.quant_max, eps = self.eps, @@ -123,19 +81,68 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): zero_point_dtype = self.zero_point_dtype, preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, - layout_type = layout + # layout_type = layout ) - # w = deepcopy(self.weight) * scales - # w = _awq_quant(w, q_group_size=128, n_bit=4) / scales - # ws.append(w.mean().item()) q_out = F.linear(input/scales, w, self.bias) scaleopts.append(q_out.mean().item()) loss = (output - q_out).pow(2).mean().item() if loss < best_loss: self.scales = scales - best_ratio = ratio best_loss = loss - # print(f"x: {input.mean().item(): .03f} w_: {torch.tensor(ws).sum().item()} o: {torch.tensor(scaleopts).sum().item(): .05f} ratio: {self.best_ratio}") def calculate_qparams(self): return self.scales.detach() + + +class ObservedLinear(torch.nn.Linear): + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + self.act_obs = act_obs + + def forward(self, input: torch.Tensor): + output = F.linear(input, self.weight, self.bias) + self.act_obs(input, output) + return output + + @classmethod + def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): + observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) + observed_linear.weight = float_linear.weight + observed_linear.bias = float_linear.bias + return observed_linear + + +@dataclass(frozen=True) +class AWQLayoutType(LayoutType): + equalization_scale: torch.Tensor + dtype: torch.dtype + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + return input * self.equalization_scale + + def post_process(self, input: torch.Tensor) -> torch.Tensor: + + return to_uintx(input, self.dtype) + + def 
_linear_awq_check(input_tensor, weight_tensor, bias): + return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) + + def _linear_awq_impl(input_tensor, weight_tensor, bias): + return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor, bias) + + +@register_layout_cls(AWQLayoutType) +class AWQ_AQTLayout(PlainAQTLayout): + @classmethod + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data.get_plain(), self.scale, self.zero_point + + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, AWQLayoutType) + return cls(int_data, scale, zero_point, layout_type) \ No newline at end of file diff --git a/tutorials/calibration_flow/AWQ.py b/tutorials/calibration_flow/AWQ.py new file mode 100644 index 0000000000..0ccbc5dbd7 --- /dev/null +++ b/tutorials/calibration_flow/AWQ.py @@ -0,0 +1,140 @@ +import torch +from torchao.quantization import quantize_, int4_weight_only, int8_weight_only +from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from tqdm import tqdm +import time + +# adapted from: https://github.com/mit-han-lab/llm-awq +def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): + dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + block_size=512 + samples = [] + n_run = 0 + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run += 1 + if n_run == n_samples: + break + + cat_samples = torch.cat(samples, dim=1) + n_split = cat_samples.shape[1] // block_size + print(f" * Split into {n_split} blocks") + return torch.cat([ + cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) + ], dim=0) + +def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size: int = 128, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): + print("Loading model ...") + torch.manual_seed(34) + t0 = time.time() + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + if quant.startswith("awq"): + quant_dtype = quant.split("-")[1] + quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) + print(f"running {quant_dtype} calibration") + t0 = time.time() + + insert_awq_observer(model, quant_dtype, group_size, precision, device) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) + model(calibration_data.to(device)) + print(f"time for calibration: {time.time() - t0:.02f} seconds") + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + t0 = time.time() + quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + + elif quant=="int8": + print("running int8 quantization") + quantize_(model, int8_weight_only()) + + elif quant=="int4": + print("running int4 quantization") + quantize_(model, int4_weight_only()) + + if compile: + model = torch.compile(model) + + 
testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") + testenc = testenc.input_ids.to(model.device) + nsamples = testenc.numel() // max_length + model = model.eval() + # calculate perplexity + nlls = [] + for i in tqdm(range(nsamples), desc="evaluating..."): + batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( + model.device + ) + with torch.no_grad(): + lm_logits = model(batch).logits + shift_logits = lm_logits[:, :-1, :].contiguous().float() + shift_labels = testenc[ + :, (i * max_length) : ((i + 1) * max_length) + ][:, 1:] + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + neg_log_likelihood = loss.float() * max_length + nlls.append(neg_log_likelihood) + + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) + + return ppl + + +parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + + +# Optional arguments with default values +parser.add_argument("repo", type=str, help="Repository ID of the model.") +parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "int8"]) +parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") +parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") +parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") +parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") +parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. 
Default is 2048.") +parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + +args = parser.parse_args() + +# Convert precision argument to torch dtype +precision_dtype = getattr(torch, args.precision, torch.bfloat16) + +awq = wikitext2_ppl( + repo_id=args.repo, + quant="awq-"+args.quant, + calibrate_size=args.calibration_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile +) + +# aqt = wikitext2_ppl( +# repo_id=args.repo, +# quant=args.quant, +# calibrate_size=args.calibrate_size, +# group_size= args.group_size, +# device=args.device, +# precision=precision_dtype, +# max_length=args.max_length, +# compile=args.compile +# ) +print(f"AWQ Perplexity: {awq.item():.5f}") +# print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file From 4aae94bb343c6fb8e5344ffe7421a8fcc327e29b Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Tue, 3 Sep 2024 15:47:14 -0400 Subject: [PATCH 13/69] works/created tutorial --- test/prototype/test_awq.py | 85 ++++++++++++++----------------- torchao/prototype/awq/api.py | 11 ++-- torchao/prototype/awq/core.py | 23 +++++---- tutorials/calibration_flow/AWQ.py | 29 ++++++----- 4 files changed, 71 insertions(+), 77 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index f7272c02cf..3f5e44eee2 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -1,16 +1,8 @@ from copy import deepcopy import torch -import torch.nn.functional as F -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization import quantize_, int4_weight_only, int8_weight_only from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant -import argparse -import os -from transformers import AutoModelForCausalLM, AutoTokenizer -from datasets import load_dataset -from tqdm import tqdm -import time - +import pytest class ToyLinearModel(torch.nn.Module): @@ -29,41 +21,40 @@ def forward(self, x): x = self.linear3(x) return x - -device = ("cuda") -dataset_size = 1000 -dtype = torch.bfloat16 -l1,l2,l3 = 512,256,128 - -m = ToyLinearModel(l1,l2,l3).eval().to(dtype).to(device) -m_bf16 = deepcopy(m) - -dataset = m.example_inputs(dataset_size, dtype=dtype, device=device) -calibration_data = dataset[:100] -bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - - -m_int4wo = deepcopy(m) -quantize_(m_int4wo, int8_weight_only()) -int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) - -# calibrate -quant_dtype = torch.uint4 -group_size = 128 -insert_awq_observer(m, quant_dtype, group_size, dtype, device) -for example in calibration_data: - m(example.to(device)) -# print('calibrated') - -# quantize -is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) -scales = [] -quantize_(m, awq_quant(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) -print(scales) -awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) -# m = torch.compile(m, fullgraph=True) -# compare accuracy -awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size -int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size -print(f"AWQ error: {awq_err}") -print(f"Int4WO error: {int4wo_err}") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test(): + device = ("cuda") + dataset_size = 1000 + original_dtype = 
torch.bfloat16 + l1,l2,l3 = 512,256,128 + + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m_bf16 = deepcopy(m) + + dataset = m.example_inputs(dataset_size, dtype=original_dtype, device=device) + calibration_data = dataset[:100] + bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) + + + m_int4wo = deepcopy(m) + quantize_(m_int4wo, int4_weight_only()) + int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) + + # calibrate + quant_dtype = torch.uint4 + group_size = 128 + insert_awq_observer(m, quant_dtype, group_size, original_dtype, device) + for example in calibration_data: + m(example.to(device)) + # print('calibrated') + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(m, awq_quant(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + m = torch.compile(m, fullgraph=True) + # compare accuracy + awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size + int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size + print(f"AWQ error: {awq_err}") + print(f"Int4WO error: {int4wo_err}") diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 70752a8acc..fe78379e42 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,13 +1,12 @@ import torch import torch.nn.functional as F -from torchao.dtypes.affine_quantized_tensor import AWQLayoutType -from torchao.prototype.awq.core import AWQObserver, ObservedLinear +from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AWQLayoutType from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.dtypes import to_affine_quantized +from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.Uintx import to_uintx from typing import Optional, Tuple @@ -26,7 +25,6 @@ def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_ preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT - print("##########################\ninsert-uint4\n##########################\n") elif quant_dtype == torch.int8: mapping_type = MappingType.SYMMETRIC @@ -38,7 +36,7 @@ def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_ block_size = (1, -1) quant_min = None quant_max = None - print("##########################\ninsert-int8\n##########################\n") + else: raise NotImplementedError(f"{quant_dtype} not supported. 
Use either torch.uint4 or torch.int8") @@ -99,10 +97,9 @@ def weight_quant_func(observed_linear): layout_type = AWQLayoutType(equalization_scale, quant_dtype) else: - print(quant_dtype) raise("AWQ supports only uint4 and int8 quantization for now") - return to_affine_quantized( + return to_affine_quantized_intx( observed_linear.weight, mapping_type, block_size, target_dtype, quant_min, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index d8555fe239..035e9c9a51 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -6,10 +6,11 @@ from torchao.dtypes.uintx.Uintx import to_uintx from torchao.dtypes.affine_quantized_tensor import ( - to_affine_quantized, + to_affine_quantized_intx, LayoutType, register_layout_cls, - PlainAQTLayout + PlainAQTLayout, + register_aqt_quantized_linear_dispatch ) from torchao.quantization.quant_primitives import ( @@ -69,7 +70,7 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): scales = scales / (scales.max() * scales.min()).sqrt() layout = AWQLayoutType(scales, self.target_dtype) tensor_dtype = torch.int8 if self.target_dtype == torch.int8 else torch.uint8 - w = to_affine_quantized( + w = to_affine_quantized_intx( self.weight.data, self.mapping_type, self.block_size, @@ -81,7 +82,7 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): zero_point_dtype = self.zero_point_dtype, preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, - # layout_type = layout + layout_type = layout ) q_out = F.linear(input/scales, w, self.bias) scaleopts.append(q_out.mean().item()) @@ -123,20 +124,21 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype) - + + def _quantized_linear_impl(input_tensor, weight_tensor, bias): + return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) + def _linear_awq_check(input_tensor, weight_tensor, bias): return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) - def _linear_awq_impl(input_tensor, weight_tensor, bias): - return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor, bias) - - +register_aqt_quantized_linear_dispatch(AWQLayoutType._linear_awq_check, AWQLayoutType._quantized_linear_impl) + @register_layout_cls(AWQLayoutType) class AWQ_AQTLayout(PlainAQTLayout): - @classmethod def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point + @classmethod def from_plain( cls, int_data: torch.Tensor, @@ -144,5 +146,4 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - assert isinstance(layout_type, AWQLayoutType) return cls(int_data, scale, zero_point, layout_type) \ No newline at end of file diff --git a/tutorials/calibration_flow/AWQ.py b/tutorials/calibration_flow/AWQ.py index 0ccbc5dbd7..72f8dc8eb4 100644 --- a/tutorials/calibration_flow/AWQ.py +++ b/tutorials/calibration_flow/AWQ.py @@ -38,6 +38,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size print("Loading model ...") torch.manual_seed(34) t0 = time.time() + # load any model with torch.nn.linear layers tokenizer = AutoTokenizer.from_pretrained(repo_id) model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") @@ -48,10 +49,13 @@ def 
wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size print(f"running {quant_dtype} calibration") t0 = time.time() + # insert observers to find average magnitude and calculate scales insert_awq_observer(model, quant_dtype, group_size, precision, device) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) model(calibration_data.to(device)) print(f"time for calibration: {time.time() - t0:.02f} seconds") + + # use awq_quant() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) @@ -61,13 +65,14 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size print("running int8 quantization") quantize_(model, int8_weight_only()) - elif quant=="int4": + elif quant=="uint4": print("running int4 quantization") quantize_(model, int4_weight_only()) if compile: model = torch.compile(model) + # eval on wikitext2 testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") testenc = testenc.input_ids.to(model.device) @@ -126,15 +131,15 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size compile=args.compile ) -# aqt = wikitext2_ppl( -# repo_id=args.repo, -# quant=args.quant, -# calibrate_size=args.calibrate_size, -# group_size= args.group_size, -# device=args.device, -# precision=precision_dtype, -# max_length=args.max_length, -# compile=args.compile -# ) +aqt = wikitext2_ppl( + repo_id=args.repo, + quant=args.quant, + calibrate_size=args.calibration_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile +) print(f"AWQ Perplexity: {awq.item():.5f}") -# print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file +print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file From 8e058d7ed16978b629ebdd32f9d4072c08a21dc3 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 5 Sep 2024 17:46:51 -0400 Subject: [PATCH 14/69] added docs, tests, cleaned code --- scripts/create_weight_map.py | 1 - test/prototype/test_awq.py | 53 +++++---- torchao/_models/llama/eval.py | 46 ++++---- torchao/dtypes/affine_quantized_tensor.py | 3 +- torchao/dtypes/uintx/Uintx.py | 2 +- torchao/prototype/awq/api.py | 110 +++++++++--------- torchao/prototype/awq/core.py | 54 ++++++--- .../prototype/awq/example.py | 102 ++++++++-------- torchao/prototype/awq/readme.md | 13 +++ 9 files changed, 211 insertions(+), 173 deletions(-) rename tutorials/calibration_flow/AWQ.py => torchao/prototype/awq/example.py (60%) create mode 100644 torchao/prototype/awq/readme.md diff --git a/scripts/create_weight_map.py b/scripts/create_weight_map.py index 79d611f738..697dbc6caf 100644 --- a/scripts/create_weight_map.py +++ b/scripts/create_weight_map.py @@ -1,6 +1,5 @@ """ This file produces a file named pytorch_model.bin.index.json based on the downloaded model weights. -It was primarily used to create run evals on llama2.c-stories15M model. 
""" import json import torch diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 3f5e44eee2..63ff6a98ad 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -1,8 +1,10 @@ from copy import deepcopy -import torch -from torchao.quantization import quantize_, int4_weight_only, int8_weight_only -from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant import pytest +import torch +from torchao.quantization import quantize_ +from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer_, awq_uintx +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + class ToyLinearModel(torch.nn.Module): @@ -21,40 +23,43 @@ def forward(self, x): x = self.linear3(x) return x -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test(): - device = ("cuda") - dataset_size = 1000 - original_dtype = torch.bfloat16 + +devices = ["cuda"] +# torch.uintx dtypes are introduced in 2.3 +if TORCH_VERSION_AT_LEAST_2_3: + qdtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) +else: + qdtypes = () + +idtypes = (torch.bfloat16,)#, torch.half, torch.float32) +@pytest.mark.parametrize("device", devices) +@pytest.mark.parametrize("qdtype", qdtypes) +@pytest.mark.parametrize("idtype", idtypes) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") +def test(device, qdtype, idtype): + dataset_size = 100 l1,l2,l3 = 512,256,128 + original_dtype = idtype + quant_dtype = qdtype + group_size = 128 m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) m_bf16 = deepcopy(m) dataset = m.example_inputs(dataset_size, dtype=original_dtype, device=device) - calibration_data = dataset[:100] + calibration_data = dataset[:50] bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) - - m_int4wo = deepcopy(m) - quantize_(m_int4wo, int4_weight_only()) - int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset]) - # calibrate - quant_dtype = torch.uint4 - group_size = 128 - insert_awq_observer(m, quant_dtype, group_size, original_dtype, device) + insert_awq_observer_(m, quant_dtype=quant_dtype, group_size=group_size) for example in calibration_data: m(example.to(device)) # print('calibrated') # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(m, awq_quant(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - m = torch.compile(m, fullgraph=True) - # compare accuracy - awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size - int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size - print(f"AWQ error: {awq_err}") - print(f"Int4WO error: {int4wo_err}") + + assert awq_out is not None diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 8cb41d3787..d26ab675bd 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -18,6 +18,7 @@ int8_weight_only, int8_dynamic_activation_int8_weight, fpx_weight_only, + uintx_weight_only, unwrap_tensor_subclass, ) from torchao._models._eval import TransformerEvalWrapper, InputRecorder @@ -53,8 +54,7 @@ def run_evaluation( print("Loading model ...") t0 = time.time() - model = 
_load_model(checkpoint_path, "cpu", precision).to(device) - print(model) + model = _load_model(checkpoint_path, "cpu", precision) if max_length is None: max_length = model.config.block_size @@ -69,12 +69,26 @@ def run_evaluation( quantize_(model, int8_weight_only()) if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) + if "fp6" in quantization: + quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: + if "hqq" in quantization: + quantization = quantization[:-4] + use_hqq = True + else: + use_hqq = False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model.to(device), int4_weight_only(group_size=groupsize)) - if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) + if "uintx" in quantization: + # uintx-nbits-group_size + # "uintx-2-64" + _quant_args = quantization.split("-") + nbits = int(_quant_args[1]) + _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} + dtype = _NBITS_TO_DTYPE[nbits] + group_size = int(_quant_args[2]) + quantize_(model, uintx_weight_only(dtype, group_size)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -82,7 +96,6 @@ def run_evaluation( assert "cuda" in device, "int4 gptq quantization only works on cuda" inputs = InputRecorder( tokenizer, - model, calibration_seq_length, prepare_inputs_for_model, pad_calibration_inputs, @@ -96,23 +109,6 @@ def run_evaluation( quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) - elif "awq" in quantization: - from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant - insert_awq_observer(model, precision, device) - InputRecorder( - tokenizer, - model, - calibration_seq_length, - prepare_inputs_for_model, - pad_calibration_inputs, - model.config.vocab_size, - device=device - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(model, awq_quant, is_observed_linear) else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) @@ -139,7 +135,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq, uintx--") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') 
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') @@ -161,4 +157,4 @@ def run_evaluation( args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, - ) + ) \ No newline at end of file diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index fc54cd8c1d..9f406b14e7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1009,7 +1009,8 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): _aqt_is_uint4(weight_tensor) and weight_tensor.dtype == torch.bfloat16 and len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and + isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) ) diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/Uintx.py index 067f4924fd..980ab533bd 100644 --- a/torchao/dtypes/uintx/Uintx.py +++ b/torchao/dtypes/uintx/Uintx.py @@ -32,7 +32,7 @@ _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} else: - print("uintx feature need torch 2.3+, please upgrade pytorch") + print("uintx feature requires torch 2.3+, please upgrade pytorch") class UintxTensor(torch.Tensor): diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index fe78379e42..9220e69720 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,54 +1,52 @@ import torch import torch.nn.functional as F -from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AWQLayoutType +from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AwqLayoutType from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.Uintx import to_uintx +from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH from typing import Optional, Tuple +assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" + -def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_size: int, input_dtype: torch.dtype, device: torch.device): +def insert_awq_observer_(model: torch.nn.Module, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128): + """ + Inserts AWQObserver into Linear layers of a given model. + + Args: + model: The model to be modified (in place). Ensure model is on the desired device for calibration + quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 + scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate + group_size: Quantization granularity. 
Use -1 for channel wise quantization + """ _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - if quant_dtype == torch.uint4: - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.uint4 - quant_min = 0 - quant_max = 15 - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - elif quant_dtype == torch.int8: - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - block_size = (1, -1) - quant_min = None - quant_max = None - - else: - raise NotImplementedError(f"{quant_dtype} not supported. Use either torch.uint4 or torch.int8") + # AQT config + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + quant_min = 0 + quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" def replace_with_observer(layer): + # creates observer and replaces linear layers with observed linear layers observer = AWQObserver( layer.weight, layer.bias, block_size, - input_dtype, mapping_type, - target_dtype, - device, + quant_dtype, + scale_search_space_size, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, zero_point_dtype = zero_point_dtype, @@ -59,7 +57,14 @@ def replace_with_observer(layer): _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) def _observed_linear_subclass_inserter(constructor): + """ + Replaces unquantized observed linear instances with quantized linear instances. + + Args: + constructor: the function which applies quantization to the observed linear layer + """ def insert_subclass(observed_linear): + # creates the new linear layer using constructor linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias!=None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) linear.bias = observed_linear.bias @@ -67,37 +72,30 @@ def insert_subclass(observed_linear): return insert_subclass -def awq_quant(quant_dtype = torch.uint4, group_size = 128): +def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): + """ + Quantizes linear layers when passed into quantize_() + + Args: + quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 + group_size: Quantization granularity. Use -1 for channel wise quantization + """ + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" def weight_quant_func(observed_linear): # weight quantization equalization_scale = observed_linear.act_obs.calculate_qparams() - if quant_dtype == torch.uint4: - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.uint8 - quant_min = 0 - quant_max = 15 - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - layout_type = AWQLayoutType(equalization_scale, quant_dtype) - - elif quant_dtype == torch.int8: - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - block_size = (1, -1) - quant_min = None - quant_max = None - layout_type = AWQLayoutType(equalization_scale, quant_dtype) - - else: - raise("AWQ supports only uint4 and int8 quantization for now") + # AQT config + target_dtype = torch.uint8 + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + quant_min = 0 + quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + layout_type = AwqLayoutType(equalization_scale, quant_dtype) return to_affine_quantized_intx( observed_linear.weight, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 035e9c9a51..7d4e00309b 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -27,10 +27,8 @@ def __init__(self, weight: torch.Tensor, bias: torch.Tensor, block_size: Tuple, - input_dtype: torch.dtype, mapping_type: MappingType, target_dtype: torch.dtype, - device: str, scale_search_space_size: int = 20, quant_min: Optional[int] = None, quant_max: Optional[int] = None, @@ -40,6 +38,26 @@ def __init__(self, preserve_zero: Optional[bool] = True, zero_point_domain = ZeroPointDomain.INT, ): + """ + A custom observer for Activation aware Weight Quantization (AWQ) + + Args: + weight: The weight tensor to be observed. + bias: The bias tensor to be observed. + block_size: The granularity of the quantization. + input_dtype: The data type of the input tensor. + mapping_type: Always set to asymmetric + target_dtype: The target data type of the quantized tensor + scale_search_space_size: The number of scales to search for. + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + preserve_zero: A flag to indicate whether we need zero to be exactly + representable or not. + zero_point_domain: The domain of the zero point. 
+ """ super().__init__( mapping_type, target_dtype, @@ -56,8 +74,9 @@ def __init__(self, self.bias = bias self.scale_options = scale_search_space_size self.scales = None - self.device = device - + self.device = self.weight.device + if self.bias is not None: + self.bias.to(self.device) @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): average = input.abs().view(-1,input.shape[-1]).mean(0) @@ -68,8 +87,9 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): ratio = i * 1 / self.scale_options scales = average.pow(ratio) scales = scales / (scales.max() * scales.min()).sqrt() - layout = AWQLayoutType(scales, self.target_dtype) - tensor_dtype = torch.int8 if self.target_dtype == torch.int8 else torch.uint8 + layout = AwqLayoutType(scales, self.target_dtype) + # regardless of weight dtype, we have to store as packed uint8 tensors + tensor_dtype = torch.uint8 w = to_affine_quantized_intx( self.weight.data, self.mapping_type, @@ -114,7 +134,7 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): @dataclass(frozen=True) -class AWQLayoutType(LayoutType): +class AwqLayoutType(LayoutType): equalization_scale: torch.Tensor dtype: torch.dtype @@ -122,21 +142,26 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input * self.equalization_scale def post_process(self, input: torch.Tensor) -> torch.Tensor: - - return to_uintx(input, self.dtype) + # pack weights for sub dtype bit size + if self.dtype != torch.uint8: + return to_uintx(input, self.dtype) + return input def _quantized_linear_impl(input_tensor, weight_tensor, bias): + # divide activations by awq scales return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) def _linear_awq_check(input_tensor, weight_tensor, bias): - return isinstance(weight_tensor.layout_tensor, AWQ_AQTLayout) + return isinstance(weight_tensor.layout_tensor, AwqAQTLayout) -register_aqt_quantized_linear_dispatch(AWQLayoutType._linear_awq_check, AWQLayoutType._quantized_linear_impl) +register_aqt_quantized_linear_dispatch(AwqLayoutType._linear_awq_check, AwqLayoutType._quantized_linear_impl) -@register_layout_cls(AWQLayoutType) -class AWQ_AQTLayout(PlainAQTLayout): +@register_layout_cls(AwqLayoutType) +class AwqAQTLayout(PlainAQTLayout): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data.get_plain(), self.scale, self.zero_point + # unpack if needed + w = self.int_data if self.layout_type.dtype == torch.uint8 else self.int_data.get_plain() + return w, self.scale, self.zero_point @classmethod def from_plain( @@ -146,4 +171,5 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): + assert isinstance(layout_type, AwqLayoutType) return cls(int_data, scale, zero_point, layout_type) \ No newline at end of file diff --git a/tutorials/calibration_flow/AWQ.py b/torchao/prototype/awq/example.py similarity index 60% rename from tutorials/calibration_flow/AWQ.py rename to torchao/prototype/awq/example.py index 72f8dc8eb4..cb35a6038c 100644 --- a/tutorials/calibration_flow/AWQ.py +++ b/torchao/prototype/awq/example.py @@ -1,6 +1,6 @@ import torch from torchao.quantization import quantize_, int4_weight_only, int8_weight_only -from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant +from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer_, awq_uintx import argparse from transformers import AutoModelForCausalLM, AutoTokenizer from datasets 
import load_dataset @@ -8,9 +8,9 @@ import time # adapted from: https://github.com/mit-han-lab/llm-awq -def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): +def get_calib_dataset(tokenizer=None, n_samples=128, device="cuda"): dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") - block_size=512 + block_size=1024 samples = [] n_run = 0 for data in dataset: @@ -29,10 +29,10 @@ def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // block_size - print(f" * Split into {n_split} blocks") - return torch.cat([ + # print(f" * Split into {n_split} blocks") + return [ cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) - ], dim=0) + ][0] def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size: int = 128, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): print("Loading model ...") @@ -45,12 +45,12 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size if quant.startswith("awq"): quant_dtype = quant.split("-")[1] - quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) + quant_dtype = getattr(torch, quant_dtype, torch.uint4) print(f"running {quant_dtype} calibration") t0 = time.time() # insert observers to find average magnitude and calculate scales - insert_awq_observer(model, quant_dtype, group_size, precision, device) + insert_awq_observer_(model, quant_dtype=quant_dtype, group_size=group_size) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) model(calibration_data.to(device)) print(f"time for calibration: {time.time() - t0:.02f} seconds") @@ -58,10 +58,10 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size # use awq_quant() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() - quantize_(model, awq_quant(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") - elif quant=="int8": + elif quant=="uint8": print("running int8 quantization") quantize_(model, int8_weight_only()) @@ -101,45 +101,45 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size return ppl +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + -parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") - - -# Optional arguments with default values -parser.add_argument("repo", type=str, help="Repository ID of the model.") -parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "int8"]) -parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") -parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") -parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") -parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") -parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. 
Default is 2048.") -parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - -args = parser.parse_args() - -# Convert precision argument to torch dtype -precision_dtype = getattr(torch, args.precision, torch.bfloat16) - -awq = wikitext2_ppl( - repo_id=args.repo, - quant="awq-"+args.quant, - calibrate_size=args.calibration_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - max_length=args.max_length, - compile=args.compile -) - -aqt = wikitext2_ppl( - repo_id=args.repo, - quant=args.quant, - calibrate_size=args.calibration_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - max_length=args.max_length, - compile=args.compile -) -print(f"AWQ Perplexity: {awq.item():.5f}") -print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file + # Optional arguments with default values + parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "uint8"]) + parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") + parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") + parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") + parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. Default is 2048.") + parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + + args = parser.parse_args() + + # Convert precision argument to torch dtype + precision_dtype = getattr(torch, args.precision, torch.bfloat16) + + awq = wikitext2_ppl( + repo_id=args.repo, + quant="awq-"+args.quant, + calibrate_size=args.calibration_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile + ) + + aqt = wikitext2_ppl( + repo_id=args.repo, + quant=args.quant, + calibrate_size=args.calibration_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile + ) + print(f"AWQ Perplexity: {awq.item():.5f}") + print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md new file mode 100644 index 0000000000..e6f35f5925 --- /dev/null +++ b/torchao/prototype/awq/readme.md @@ -0,0 +1,13 @@ +# AWQ Quantization +Ported from https://github.com/mit-han-lab/llm-awq + +## Benchmarks +Benchmarks are run on a machine with a single A100 GPU using the script in _models/llama which generates text in a latency optimized way (batchsize=1), evaluation was done +Using the lm_eval. 
The models used were meta-llama/Llama-2-7b-chat-hf + +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 | +| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 | +| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 | +| | awq | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 | \ No newline at end of file From dced6e55780c482277fa1999d894e05545297bf6 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 5 Sep 2024 19:45:43 -0400 Subject: [PATCH 15/69] updated benchmark --- test/prototype/test_awq.py | 1 - torchao/prototype/awq/api.py | 3 - torchao/prototype/awq/example.py | 163 ++++++++++++++++--------------- torchao/prototype/awq/readme.md | 16 ++- 4 files changed, 91 insertions(+), 92 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 63ff6a98ad..351c23e3cc 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -6,7 +6,6 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): super().__init__() diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 9220e69720..a3bf3cec0d 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -12,9 +12,6 @@ assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" - - - def insert_awq_observer_(model: torch.nn.Module, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128): """ Inserts AWQObserver into Linear layers of a given model. 
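Taken together, the pieces above are meant to be used in a three-step flow: insert observers into the Linear layers, run calibration data through the model, then quantize the observed layers. Below is a minimal sketch of that flow, assuming the signatures from torchao.prototype.awq.api as they stand at this point in the series; the toy model, random calibration tensors, and CUDA/bfloat16 choices are illustrative placeholders, not part of the patch.

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq.api import insert_awq_observer_, ObservedLinear, awq_uintx

# stand-in model: any module containing torch.nn.Linear layers works the same way
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
).to(torch.bfloat16).to("cuda")
calibration_batches = [
    torch.randn(16, 512, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

# 1) swap nn.Linear -> ObservedLinear so activation magnitudes are recorded
insert_awq_observer_(model, quant_dtype=torch.uint4, group_size=128)

# 2) run representative data through the model to populate the observers
for batch in calibration_batches:
    model(batch)

# 3) replace each observed layer with an AWQ-equalized, uint4-quantized weight
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=128), is_observed_linear)

At inference time the chosen equalization scale travels with the weight's layout, so the dispatched linear divides the activations by that scale before multiplying with the dequantized weight, which is what the _quantized_linear_impl registered above does.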
diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index cb35a6038c..ab10f6270a 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -1,16 +1,17 @@ import torch -from torchao.quantization import quantize_, int4_weight_only, int8_weight_only -from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer_, awq_uintx import argparse from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset from tqdm import tqdm import time +from torchao.prototype.awq.api import insert_awq_observer_, ObservedLinear, awq_uintx +from torchao.quantization import quantize_, int4_weight_only, int8_weight_only, Int4WeightOnlyGPTQQuantizer + # adapted from: https://github.com/mit-han-lab/llm-awq -def get_calib_dataset(tokenizer=None, n_samples=128, device="cuda"): +def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") - block_size=1024 + block_size=512 samples = [] n_run = 0 for data in dataset: @@ -29,10 +30,38 @@ def get_calib_dataset(tokenizer=None, n_samples=128, device="cuda"): cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // block_size - # print(f" * Split into {n_split} blocks") - return [ + return torch.cat([ cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) - ][0] + ], dim=0) + +def eval(model, tokenizer, max_length): + testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") + testenc = testenc.input_ids.to(model.device) + nsamples = testenc.numel() // max_length + model = model.eval() + # calculate perplexity + nlls = [] + for i in tqdm(range(nsamples), desc="evaluating..."): + batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( + model.device + ) + with torch.no_grad(): + lm_logits = model(batch).logits + shift_logits = lm_logits[:, :-1, :].contiguous().float() + shift_labels = testenc[ + :, (i * max_length) : ((i + 1) * max_length) + ][:, 1:] + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + neg_log_likelihood = loss.float() * max_length + nlls.append(neg_log_likelihood) + + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) + + return ppl def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size: int = 128, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): print("Loading model ...") @@ -45,7 +74,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size if quant.startswith("awq"): quant_dtype = quant.split("-")[1] - quant_dtype = getattr(torch, quant_dtype, torch.uint4) + quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) print(f"running {quant_dtype} calibration") t0 = time.time() @@ -60,86 +89,62 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size t0 = time.time() quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") - - elif quant=="uint8": + elif quant=="gptq": + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) + quantizer = Int4WeightOnlyGPTQQuantizer() + model = quantizer.quantize(model, calibration_data).to(device) + elif quant=="int8": print("running int8 quantization") 
quantize_(model, int8_weight_only()) elif quant=="uint4": print("running int4 quantization") - quantize_(model, int4_weight_only()) + quantize_(model, int4_weight_only(group_size=64)) if compile: model = torch.compile(model) - # eval on wikitext2 - testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") - testenc = testenc.input_ids.to(model.device) - nsamples = testenc.numel() // max_length - model = model.eval() - # calculate perplexity - nlls = [] - for i in tqdm(range(nsamples), desc="evaluating..."): - batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( - model.device - ) - with torch.no_grad(): - lm_logits = model(batch).logits - shift_logits = lm_logits[:, :-1, :].contiguous().float() - shift_labels = testenc[ - :, (i * max_length) : ((i + 1) * max_length) - ][:, 1:] - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - neg_log_likelihood = loss.float() * max_length - nlls.append(neg_log_likelihood) + return eval(model, tokenizer, max_length) - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) - - return ppl -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") - +parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + - # Optional arguments with default values - parser.add_argument("repo", type=str, help="Repository ID of the model.") - parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "uint8"]) - parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") - parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") - parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") - parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") - parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. Default is 2048.") - parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - - args = parser.parse_args() - - # Convert precision argument to torch dtype - precision_dtype = getattr(torch, args.precision, torch.bfloat16) - - awq = wikitext2_ppl( - repo_id=args.repo, - quant="awq-"+args.quant, - calibrate_size=args.calibration_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - max_length=args.max_length, - compile=args.compile - ) - - aqt = wikitext2_ppl( - repo_id=args.repo, - quant=args.quant, - calibrate_size=args.calibration_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - max_length=args.max_length, - compile=args.compile - ) - print(f"AWQ Perplexity: {awq.item():.5f}") - print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file +# Optional arguments with default values +parser.add_argument("repo", type=str, help="Repository ID of the model.") +parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "int8","gptq"]) +parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. 
Default is 100.") +parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") +parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") +parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") +parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. Default is 2048.") +parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + +args = parser.parse_args() + +# Convert precision argument to torch dtype +precision_dtype = getattr(torch, args.precision, torch.bfloat16) + +# awq = wikitext2_ppl( +# repo_id=args.repo, +# quant="awq-"+args.quant, +# calibrate_size=args.calibration_size, +# group_size= args.group_size, +# device=args.device, +# precision=precision_dtype, +# max_length=args.max_length, +# compile=args.compile +# ) + +aqt = wikitext2_ppl( + repo_id=args.repo, + quant=args.quant, + calibrate_size=args.calibration_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + max_length=args.max_length, + compile=args.compile +) +# print(f"AWQ Perplexity: {awq.item():.5f}") +print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md index e6f35f5925..ee0b819f98 100644 --- a/torchao/prototype/awq/readme.md +++ b/torchao/prototype/awq/readme.md @@ -1,13 +1,11 @@ # AWQ Quantization -Ported from https://github.com/mit-han-lab/llm-awq +Adapted from https://github.com/mit-han-lab/llm-awq ## Benchmarks -Benchmarks are run on a machine with a single A100 GPU using the script in _models/llama which generates text in a latency optimized way (batchsize=1), evaluation was done -Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf +Benchmarks are run on a machine with a single RTX 3090 GPU using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. The model used was openai-community/gpt2 with a context length of 1024. Group size of 64 was used for both int4wo and awq-uint4. 
-| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 | -| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 | -| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 | -| | awq | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 | \ No newline at end of file +| Quantization | wikitext2-perplexity | +| ------------------------ | ------------------- | +| Base (bfloat16) | 30.1904 | +| int4wo (tinygemm kernel) | 519.73108 | +| awq-uint4 | 485.54907 | \ No newline at end of file From c43b997d52454becfd2844d2d82cb062c9356f93 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 5 Sep 2024 19:52:45 -0400 Subject: [PATCH 16/69] update example --- torchao/prototype/awq/example.py | 28 +++++----------------------- torchao/prototype/awq/readme.md | 3 ++- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index ab10f6270a..1969fed182 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -89,15 +89,8 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size t0 = time.time() quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") - elif quant=="gptq": - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) - quantizer = Int4WeightOnlyGPTQQuantizer() - model = quantizer.quantize(model, calibration_data).to(device) - elif quant=="int8": - print("running int8 quantization") - quantize_(model, int8_weight_only()) - elif quant=="uint4": + elif quant=="int4": print("running int4 quantization") quantize_(model, int4_weight_only(group_size=64)) @@ -112,7 +105,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") -parser.add_argument("quant", type=str, help="Quantization method or file path.",choices=["uint4", "int8","gptq"]) +parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. 
Default is 'cuda'.") @@ -125,18 +118,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) -# awq = wikitext2_ppl( -# repo_id=args.repo, -# quant="awq-"+args.quant, -# calibrate_size=args.calibration_size, -# group_size= args.group_size, -# device=args.device, -# precision=precision_dtype, -# max_length=args.max_length, -# compile=args.compile -# ) - -aqt = wikitext2_ppl( +ppl = wikitext2_ppl( repo_id=args.repo, quant=args.quant, calibrate_size=args.calibration_size, @@ -146,5 +128,5 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size max_length=args.max_length, compile=args.compile ) -# print(f"AWQ Perplexity: {awq.item():.5f}") -print(f"Affine quantized Perplexity: {aqt.item():.5f}") \ No newline at end of file + +print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md index ee0b819f98..cd22a49ce8 100644 --- a/torchao/prototype/awq/readme.md +++ b/torchao/prototype/awq/readme.md @@ -8,4 +8,5 @@ Benchmarks are run on a machine with a single RTX 3090 GPU using the script in a | ------------------------ | ------------------- | | Base (bfloat16) | 30.1904 | | int4wo (tinygemm kernel) | 519.73108 | -| awq-uint4 | 485.54907 | \ No newline at end of file +| awq-uint4 | 485.54907 | +| awq-uint6 | 37.32335 | \ No newline at end of file From 2388091c46be6dfd1bec0548a9f3e8f1ed47510e Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 5 Sep 2024 19:52:55 -0400 Subject: [PATCH 17/69] update example --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 1969fed182..6ee454ea2c 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -84,7 +84,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size model(calibration_data.to(device)) print(f"time for calibration: {time.time() - t0:.02f} seconds") - # use awq_quant() to apply awq quantization + # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) From 8619cd5a825fbf05c2cad11bafbc35a43b8a8a74 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Fri, 6 Sep 2024 17:21:56 -0400 Subject: [PATCH 18/69] added init file --- scripts/create_weight_map.py | 36 ------------------------------- torchao/prototype/awq/__init__.py | 0 2 files changed, 36 deletions(-) delete mode 100644 scripts/create_weight_map.py create mode 100644 torchao/prototype/awq/__init__.py diff --git a/scripts/create_weight_map.py b/scripts/create_weight_map.py deleted file mode 100644 index 697dbc6caf..0000000000 --- a/scripts/create_weight_map.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -This file produces a file named pytorch_model.bin.index.json based on the downloaded model weights. 
-""" -import json -import torch -from transformers import AutoModel - -def create_weight_map(model_name): - # Load the model - model = AutoModel.from_pretrained(model_name) - - # Get the state dict - state_dict = model.state_dict() - - # Create the weight map - weight_map = {} - for key, tensor in state_dict.items(): - # In this example, we're assuming all weights are in a single file - # You may need to adjust this if your model uses sharded weights - weight_map[key] = "pytorch_model.bin" - - # Create the index dictionary - index_dict = { - "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())}, - "weight_map": weight_map - } - - # Save the index dictionary to a JSON file - with open("pytorch_model.bin.index.json", "w") as f: - json.dump(index_dict, f, indent=2) - - print("Created pytorch_model.bin.index.json") - -# Usage -model_name = "checkpoints/Xenova/llama2.c-stories15M" -create_weight_map(model_name) \ No newline at end of file diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 9027082c5052ed64386881e68a75b3e95c233b94 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Sun, 8 Sep 2024 14:18:11 -0400 Subject: [PATCH 19/69] reduce vram for calibration --- test/prototype/test_awq.py | 5 +-- torchao/prototype/awq/__init__.py | 3 ++ torchao/prototype/awq/api.py | 12 ++++--- torchao/prototype/awq/core.py | 57 +++++++++++++++++++++-------- torchao/prototype/awq/example.py | 59 ++++++++++++++++--------------- torchao/quantization/observer.py | 18 ++++++++++ 6 files changed, 105 insertions(+), 49 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 351c23e3cc..e0f57ab6e2 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -2,9 +2,10 @@ import pytest import torch from torchao.quantization import quantize_ -from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer_, awq_uintx -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 +if TORCH_VERSION_AT_LEAST_2_3: + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index e69de29bb2..ce128e1aa2 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -0,0 +1,3 @@ +from .api import insert_awq_observer_, awq_uintx +from .core import ObservedLinear +from .example import get_calib_dataset \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index a3bf3cec0d..9bf976c57c 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -8,16 +8,18 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH -from typing import Optional, Tuple +from typing import List, Optional, Tuple assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. 
Please upgrade to torch 2.3+" -def insert_awq_observer_(model: torch.nn.Module, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128): +def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, validation_sequence_len: int, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128): """ Inserts AWQObserver into Linear layers of a given model. Args: model: The model to be modified (in place). Ensure model is on the desired device for calibration + validation_sequence_len: Number of tokens in each validation example + n_validation_examples: Number of examples used to validate scale options quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate group_size: Quantization granularity. Use -1 for channel wise quantization @@ -43,6 +45,8 @@ def replace_with_observer(layer): block_size, mapping_type, quant_dtype, + n_validation_examples, + validation_sequence_len, scale_search_space_size, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, @@ -69,7 +73,7 @@ def insert_subclass(observed_linear): return insert_subclass -def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): +def awq_uintx(n_calibration_tokens:int, quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): """ Quantizes linear layers when passed into quantize_() @@ -81,7 +85,7 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" def weight_quant_func(observed_linear): # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams() + equalization_scale = observed_linear.act_obs.calculate_qparams(n_calibration_tokens) # AQT config target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 7d4e00309b..4f76a56635 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -18,7 +18,7 @@ ZeroPointDomain, ) from torchao.quantization.observer import ( - AffineQuantizedObserverBase, + AffineQuantizedObserverBase, PerGroup ) @@ -29,6 +29,8 @@ def __init__(self, block_size: Tuple, mapping_type: MappingType, target_dtype: torch.dtype, + n_validation_examples: int, + validation_sequence_len: int, scale_search_space_size: int = 20, quant_min: Optional[int] = None, quant_max: Optional[int] = None, @@ -44,10 +46,12 @@ def __init__(self, Args: weight: The weight tensor to be observed. bias: The bias tensor to be observed. - block_size: The granularity of the quantization. + block_size: The weight tensor shape after being reshaped to support per group quantization input_dtype: The data type of the input tensor. mapping_type: Always set to asymmetric target_dtype: The target data type of the quantized tensor + n_validation_examples: Number of examples used to calibrate observer + validation_sequence_len: Number of tokens in each example scale_search_space_size: The number of scales to search for. 
quant_min: The minimum quantized value quant_max: The maximum quantized value @@ -61,7 +65,7 @@ def __init__(self, super().__init__( mapping_type, target_dtype, - block_size = block_size, + PerGroup(block_size[-1]), quant_min = quant_min, quant_max = quant_max, eps = eps, @@ -70,22 +74,43 @@ def __init__(self, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, ) + self.block_size = block_size self.weight = weight self.bias = bias + self.n_validation_examples = n_validation_examples + self.validation_sequence_len = validation_sequence_len + self.inputs = [] + self.outputs = [] self.scale_options = scale_search_space_size - self.scales = None self.device = self.weight.device + self.average = torch.zeros((1,weight.shape[1]), device= self.device) if self.bias is not None: self.bias.to(self.device) @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): - average = input.abs().view(-1,input.shape[-1]).mean(0) + # import pdb + # pdb.set_trace() + # print(input.shape, input.abs().sum(1).shape, self.average.shape) + if len(self.inputs) < self.n_validation_examples: + self.inputs.append(input.to("cpu")) + self.outputs.append(output.to("cpu")) + self.average += input.abs().sum(-2) + + def calculate_qparams(self, n_calibration_tokens): + # import pdb + # pdb.set_trace() + assert self.outputs != None, "calibrate observer first by running model on exemplar data" + self.average /= (n_calibration_tokens) + for i in range(self.n_validation_examples): + self.inputs[i] = self.inputs[i].to(self.device) + self.outputs[i] = self.outputs[i].to(self.device) + best_loss = float('inf') - scaleopts = [] + best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = average.pow(ratio) + scales = self.average.pow(ratio) scales = scales / (scales.max() * scales.min()).sqrt() layout = AwqLayoutType(scales, self.target_dtype) # regardless of weight dtype, we have to store as packed uint8 tensors @@ -104,16 +129,17 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): zero_point_domain = self.zero_point_domain, layout_type = layout ) - q_out = F.linear(input/scales, w, self.bias) - scaleopts.append(q_out.mean().item()) - loss = (output - q_out).pow(2).mean().item() + loss = 0 + for i in range(self.n_validation_examples): + q_out = F.linear(self.inputs[i]/scales, w, self.bias) + loss += (self.outputs[i] - q_out).pow(2).mean().item() if loss < best_loss: - self.scales = scales + best_scales = scales best_loss = loss - - def calculate_qparams(self): - return self.scales.detach() - + for i in range(self.n_validation_examples): + self.inputs[i].to("cpu") + self.outputs[i].to("cpu") + return best_scales.detach() class ObservedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): @@ -149,6 +175,7 @@ def post_process(self, input: torch.Tensor) -> torch.Tensor: def _quantized_linear_impl(input_tensor, weight_tensor, bias): # divide activations by awq scales + # print(input_tensor.dtype, weight_tensor.layout_tensor.layout_type.equalization_scale.dtype, weight_tensor.dequantize().dtype, bias.dtype) return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) def _linear_awq_check(input_tensor, weight_tensor, bias): diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 6ee454ea2c..e88d52f772 100644 --- a/torchao/prototype/awq/example.py +++ 
b/torchao/prototype/awq/example.py @@ -5,15 +5,15 @@ from tqdm import tqdm import time from torchao.prototype.awq.api import insert_awq_observer_, ObservedLinear, awq_uintx -from torchao.quantization import quantize_, int4_weight_only, int8_weight_only, Int4WeightOnlyGPTQQuantizer +from torchao.quantization import quantize_, int4_weight_only # adapted from: https://github.com/mit-han-lab/llm-awq -def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): +def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") - block_size=512 samples = [] - n_run = 0 + n_tokens = n_samples * block_size + n_run = n_tokens for data in dataset: line = data["text"] line = line.strip() @@ -24,52 +24,49 @@ def get_calib_dataset(tokenizer=None, n_samples=512, device="cuda"): if sample.numel() == 0: continue samples.append(sample) - n_run += 1 - if n_run == n_samples: + n_run -= len(line_encoded) + if n_run <= n_samples: break cat_samples = torch.cat(samples, dim=1) - n_split = cat_samples.shape[1] // block_size - return torch.cat([ - cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) - ], dim=0) + return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] -def eval(model, tokenizer, max_length): +def wiki2_eval(model, tokenizer, sequence_length): testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") testenc = testenc.input_ids.to(model.device) - nsamples = testenc.numel() // max_length + nsamples = testenc.numel() // sequence_length model = model.eval() # calculate perplexity nlls = [] for i in tqdm(range(nsamples), desc="evaluating..."): - batch = testenc[:, (i * max_length) : ((i + 1) * max_length)].to( + batch = testenc[:, (i * sequence_length) : ((i + 1) * sequence_length)].to( model.device ) with torch.no_grad(): lm_logits = model(batch).logits shift_logits = lm_logits[:, :-1, :].contiguous().float() shift_labels = testenc[ - :, (i * max_length) : ((i + 1) * max_length) + :, (i * sequence_length) : ((i + 1) * sequence_length) ][:, 1:] loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - neg_log_likelihood = loss.float() * max_length + neg_log_likelihood = loss.float() * sequence_length nlls.append(neg_log_likelihood) - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * max_length)) + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * sequence_length)) return ppl -def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size: int = 128, device="cuda", precision=torch.bfloat16, max_length=2048, compile=False): - print("Loading model ...") +def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validation_size:int=100, group_size: int = 128, device="cuda", precision=torch.bfloat16, sequence_length=2048, compile=False): + print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() # load any model with torch.nn.linear layers tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).eval().to(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") if quant.startswith("awq"): @@ -79,15 +76,19 @@ def wikitext2_ppl(repo_id: str, quant: str, 
calibrate_size: int =100, group_size t0 = time.time() # insert observers to find average magnitude and calculate scales - insert_awq_observer_(model, quant_dtype=quant_dtype, group_size=group_size) - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibrate_size) - model(calibration_data.to(device)) + insert_awq_observer_(model,validation_size, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) + print(calibration_data[0].size(), calibration_data[0].dtype) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") print(f"time for calibration: {time.time() - t0:.02f} seconds") # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + n_calibration_tokens = calibration_size * sequence_length + quantize_(model, awq_uintx(n_calibration_tokens, quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") elif quant=="int4": @@ -97,7 +98,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size if compile: model = torch.compile(model) - return eval(model, tokenizer, max_length) + return wiki2_eval(model, tokenizer, sequence_length) parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") @@ -107,25 +108,27 @@ def wikitext2_ppl(repo_id: str, quant: str, calibrate_size: int =100, group_size parser.add_argument("repo", type=str, help="Repository ID of the model.") parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") +parser.add_argument("--validation_size", type=int, default=100, help="Validation size. Default is 100.") parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") -parser.add_argument("--max_length", type=int, default=2048, help="Maximum length for evaluation. Default is 2048.") +parser.add_argument("--sequence_length", type=int, default=512, help="Length of examples to calibrate/evaluate model on. 
Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") args = parser.parse_args() # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) - +print('hall0o') ppl = wikitext2_ppl( repo_id=args.repo, quant=args.quant, - calibrate_size=args.calibration_size, + calibration_size=args.calibration_size, + validation_size=args.validation_size, group_size= args.group_size, device=args.device, precision=precision_dtype, - max_length=args.max_length, + sequence_length=args.sequence_length, compile=args.compile ) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 08a3eacf6b..d32b387edd 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -53,6 +53,24 @@ class PerAxis(GranularityType): """ axis: int +@dataclass(frozen=True) +class PerGroup(GranularityType): + """ + Represents per-channel group granularity in quantization. + + This granularity type calcualtes different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + group_size: int # borrowed from torch.ao.quantization.observer class _PartialWrapper: From dbac7c8a0198a20d51084f4f988dca4f6396c395 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Sun, 8 Sep 2024 15:35:41 -0400 Subject: [PATCH 20/69] eval changes+ llama2 data --- torchao/_models/llama/eval.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f3c12d18a9..b0249d57ba 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -114,6 +114,36 @@ def run_evaluation( quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) + + if "awq" in quantization: + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear + quant_dtype = quantization.split("-")[1] + group_size = quantization.split("-")[2] + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + + # get calibration data + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu" + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + for batch in inputs: + model(batch.to(device)) + batch.to("cpu") + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) @@ -140,7 +170,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval 
samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq, uintx--, uintx---hqq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq, uintx--, uintx---hqq, awq-uint-") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') From aa62e5f41cfa93fbdc9b040f468d4b47a9ec3896 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Sun, 8 Sep 2024 15:36:02 -0400 Subject: [PATCH 21/69] llama2 data + eval script init changes --- torchao/prototype/awq/api.py | 4 ++-- torchao/prototype/awq/core.py | 7 +++++-- torchao/prototype/awq/example.py | 5 +---- torchao/prototype/awq/readme.md | 18 +++++++++++------- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 9bf976c57c..3c8940ee17 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -73,7 +73,7 @@ def insert_subclass(observed_linear): return insert_subclass -def awq_uintx(n_calibration_tokens:int, quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): +def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): """ Quantizes linear layers when passed into quantize_() @@ -85,7 +85,7 @@ def awq_uintx(n_calibration_tokens:int, quant_dtype: torch.dtype = torch.uint4, assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" def weight_quant_func(observed_linear): # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams(n_calibration_tokens) + equalization_scale = observed_linear.act_obs.calculate_qparams() # AQT config target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 4f76a56635..6d4da827db 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -79,6 +79,7 @@ def __init__(self, self.bias = bias self.n_validation_examples = n_validation_examples self.validation_sequence_len = validation_sequence_len + self.calibration_token_count = 0 self.inputs = [] self.outputs = [] self.scale_options = scale_search_space_size @@ -94,14 +95,16 @@ def forward(self, input: torch.Tensor, output: torch.Tensor): if len(self.inputs) < self.n_validation_examples: self.inputs.append(input.to("cpu")) self.outputs.append(output.to("cpu")) + self.calibration_token_count += input.shape[-2] self.average += input.abs().sum(-2) + - def calculate_qparams(self, n_calibration_tokens): + def calculate_qparams(self): # import pdb # pdb.set_trace() assert self.outputs != None, "calibrate observer first by running model on exemplar data" - self.average /= (n_calibration_tokens) + self.average /= (self.calibration_token_count) for i in range(self.n_validation_examples): self.inputs[i] = self.inputs[i].to(self.device) self.outputs[i] = self.outputs[i].to(self.device) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index e88d52f772..243118e2f9 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -78,7 +78,6 @@ def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validati # insert observers to find average magnitude and calculate scales insert_awq_observer_(model,validation_size, sequence_length, quant_dtype=quant_dtype, group_size=group_size) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) - print(calibration_data[0].size(), calibration_data[0].dtype) for batch in calibration_data: model(batch.to(device)) batch.to("cpu") @@ -87,8 +86,7 @@ def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validati # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) t0 = time.time() - n_calibration_tokens = calibration_size * sequence_length - quantize_(model, awq_uintx(n_calibration_tokens, quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") elif quant=="int4": @@ -119,7 +117,6 @@ def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validati # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) -print('hall0o') ppl = wikitext2_ppl( repo_id=args.repo, quant=args.quant, diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md index cd22a49ce8..502da11fe7 100644 --- a/torchao/prototype/awq/readme.md +++ b/torchao/prototype/awq/readme.md @@ -2,11 +2,15 @@ Adapted from https://github.com/mit-han-lab/llm-awq ## Benchmarks -Benchmarks are run on a machine with a single RTX 3090 GPU using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with 
newlines and dividing by context length. The model used was openai-community/gpt2 with a context length of 1024. Group size of 64 was used for both int4wo and awq-uint4. +Benchmarks are run on a machine with a single RTX 3090 GPU using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. Group size of 64 was used for both int4wo and awq-uint4. For awq quantization, c refers to number of calibration sequences and v refers to number of validation sequences. Calibration data is used to find the average magnitude of activations while validation data is used to find optimal equilization scales. Note c is always larger than v. Calibration data came from Pile dataset validation split. + +| Model | Quantization | wikitext2-perplexity | +| ------------------ | ------------------------ | ------------------- | +| GPT-2 | Base (bfloat16) | 30.1904 | +| | int4wo (tinygemm kernel) | 519.73108 | +| | awq-uint4 | 485.54907 | +| Llama-2-7b-hf | Base (bfloat16) | 5.47367 | +| | int4wo (tinygemm kernel) | 5.73546 | +| | awq-uint4-c1-v1 | 5.72359 | +| | awq-uint4-c10-v1 | 5.72350 | -| Quantization | wikitext2-perplexity | -| ------------------------ | ------------------- | -| Base (bfloat16) | 30.1904 | -| int4wo (tinygemm kernel) | 519.73108 | -| awq-uint4 | 485.54907 | -| awq-uint6 | 37.32335 | \ No newline at end of file From 7f21bfc6f3dae23decf83aec1d1953828742f529 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Sun, 8 Sep 2024 22:50:56 -0400 Subject: [PATCH 22/69] fixed qdtype bounds and example import --- torchao/prototype/awq/api.py | 13 +++++-- torchao/prototype/awq/example.py | 62 ++++++++++++++++---------------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 3c8940ee17..d597e8d401 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,14 +1,21 @@ import torch import torch.nn.functional as F -from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AwqLayoutType + from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH -from typing import List, Optional, Tuple +from torchao.prototype.awq.core import( + AWQObserver, + ObservedLinear, + AwqLayoutType +) + + assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. 
Please upgrade to torch 2.3+" @@ -91,7 +98,7 @@ def weight_quant_func(observed_linear): mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) quant_min = 0 - quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype] eps = torch.finfo(torch.float32).eps preserve_zero = True zero_point_dtype = torch.int64 diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 243118e2f9..f8a57d4b12 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -98,35 +98,35 @@ def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validati return wiki2_eval(model, tokenizer, sequence_length) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + -parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") - - -# Optional arguments with default values -parser.add_argument("repo", type=str, help="Repository ID of the model.") -parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") -parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") -parser.add_argument("--validation_size", type=int, default=100, help="Validation size. Default is 100.") -parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") -parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") -parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") -parser.add_argument("--sequence_length", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") -parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - -args = parser.parse_args() - -# Convert precision argument to torch dtype -precision_dtype = getattr(torch, args.precision, torch.bfloat16) -ppl = wikitext2_ppl( - repo_id=args.repo, - quant=args.quant, - calibration_size=args.calibration_size, - validation_size=args.validation_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - sequence_length=args.sequence_length, - compile=args.compile -) - -print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file + # Optional arguments with default values + parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") + parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") + parser.add_argument("--validation_size", type=int, default=100, help="Validation size. Default is 100.") + parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") + parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") + parser.add_argument("--sequence_length", type=int, default=512, help="Length of examples to calibrate/evaluate model on. 
Default 512") + parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + + args = parser.parse_args() + + # Convert precision argument to torch dtype + precision_dtype = getattr(torch, args.precision, torch.bfloat16) + ppl = wikitext2_ppl( + repo_id=args.repo, + quant=args.quant, + calibration_size=args.calibration_size, + validation_size=args.validation_size, + group_size= args.group_size, + device=args.device, + precision=precision_dtype, + sequence_length=args.sequence_length, + compile=args.compile + ) + + print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file From 863d503dac117f186142b58f6284da5fd5fa09cc Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Mon, 9 Sep 2024 12:38:53 -0400 Subject: [PATCH 23/69] fix tests --- torchao/prototype/awq/__init__.py | 3 +-- torchao/prototype/awq/api.py | 4 ++-- torchao/prototype/awq/example.py | 15 ++++++++++++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index ce128e1aa2..e3a2a37454 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,3 +1,2 @@ from .api import insert_awq_observer_, awq_uintx -from .core import ObservedLinear -from .example import get_calib_dataset \ No newline at end of file +from .core import ObservedLinear \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index d597e8d401..374d66e749 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -97,8 +97,8 @@ def weight_quant_func(observed_linear): target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) - quant_min = 0 - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype] + quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] eps = torch.finfo(torch.float32).eps preserve_zero = True zero_point_dtype = torch.int64 diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index f8a57d4b12..24592b0e8f 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -60,7 +60,16 @@ def wiki2_eval(model, tokenizer, sequence_length): return ppl -def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validation_size:int=100, group_size: int = 128, device="cuda", precision=torch.bfloat16, sequence_length=2048, compile=False): +def wikitext2_ppl( + repo_id: str, + quant: str, + calibration_size: int =100, + validation_size:int=100, + group_size: int = 128, + device="cuda", + precision=torch.bfloat16, + sequence_length=2048, + compile=False): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() @@ -105,8 +114,8 @@ def wikitext2_ppl(repo_id: str, quant: str, calibration_size: int =100, validati # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") - parser.add_argument("--calibration_size", type=int, default=100, help="Calibration size. Default is 100.") - parser.add_argument("--validation_size", type=int, default=100, help="Validation size. Default is 100.") + parser.add_argument("--calibration_size", type=int, default=10, help="Calibration size. 
Default is 10.") + parser.add_argument("--validation_size", type=int, default=10, help="Validation size. Default is 10.") parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") From 4ca91179721d2e87a04ac5c078094f93b8b8f578 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Mon, 9 Sep 2024 21:30:31 -0400 Subject: [PATCH 24/69] fix tests --- test/prototype/test_awq.py | 10 ++++++---- torchao/prototype/awq/api.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index e0f57ab6e2..8bb03cb514 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -43,19 +43,21 @@ def test(device, qdtype, idtype): original_dtype = idtype quant_dtype = qdtype group_size = 128 + n_calibration_examples = 10 + n_validation_examples = 10 + sequence_length = 5 m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) m_bf16 = deepcopy(m) - dataset = m.example_inputs(dataset_size, dtype=original_dtype, device=device) - calibration_data = dataset[:50] + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) + calibration_data = dataset[:n_calibration_examples] bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) # calibrate - insert_awq_observer_(m, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_(m, n_validation_examples,sequence_length, quant_dtype=quant_dtype, group_size=group_size) for example in calibration_data: m(example.to(device)) - # print('calibrated') # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 374d66e749..89b13b1e7a 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -25,8 +25,8 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val Args: model: The model to be modified (in place). Ensure model is on the desired device for calibration - validation_sequence_len: Number of tokens in each validation example n_validation_examples: Number of examples used to validate scale options + validation_sequence_len: Number of tokens in each validation example quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate group_size: Quantization granularity. 
Use -1 for channel wise quantization From 3e70a6fc7ac98316db4601314ced8b0f0b485549 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Tue, 10 Sep 2024 14:29:47 -0400 Subject: [PATCH 25/69] use rolling log liklihood for eval and calibrate awq with run_eval --- torchao/_models/llama/eval.py | 37 +++++++++++++++----------------- torchao/prototype/awq/example.py | 14 ++++++------ 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index b0249d57ba..34f2b8673e 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -89,10 +89,10 @@ def run_evaluation( else: use_hqq = False _quant_args = quantization.split("-") - nbits = int(_quant_args[0]) + nbits = int(_quant_args[1]) _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} dtype = _NBITS_TO_DTYPE[nbits] - group_size = int(_quant_args[1]) + group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) @@ -107,7 +107,7 @@ def run_evaluation( model.config.vocab_size, device="cpu" ).record_inputs( - calibration_tasks, + ["pile_hackernews"], calibration_limit, ).get_inputs() @@ -122,28 +122,25 @@ def run_evaluation( exit() from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear quant_dtype = quantization.split("-")[1] - group_size = quantization.split("-")[2] + group_size = int(quantization.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.uint8) - + model=model.to(device) # get calibration data - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - pad_calibration_inputs, - model.config.vocab_size, - device="cpu" - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) - model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) - for batch in inputs: - model(batch.to(device)) - batch.to("cpu") + with torch.no_grad(): + TransformerEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=calibration_tasks, + limit=calibration_limit, + ) is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + pass else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 24592b0e8f..a1f55cc02f 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -5,7 +5,7 @@ from tqdm import tqdm import time from torchao.prototype.awq.api import insert_awq_observer_, ObservedLinear, awq_uintx -from torchao.quantization import quantize_, int4_weight_only +from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only # adapted from: https://github.com/mit-han-lab/llm-awq @@ -35,19 +35,20 @@ def wiki2_eval(model, tokenizer, sequence_length): testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") testenc = 
testenc.input_ids.to(model.device) - nsamples = testenc.numel() // sequence_length + nsamples = 100 model = model.eval() + model(testenc[:, :sequence_length].to(model.device)) # calculate perplexity nlls = [] for i in tqdm(range(nsamples), desc="evaluating..."): - batch = testenc[:, (i * sequence_length) : ((i + 1) * sequence_length)].to( + batch = testenc[:, i : i + sequence_length].to( model.device ) with torch.no_grad(): lm_logits = model(batch).logits shift_logits = lm_logits[:, :-1, :].contiguous().float() shift_labels = testenc[ - :, (i * sequence_length) : ((i + 1) * sequence_length) + :, i : i + sequence_length ][:, 1:] loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( @@ -100,12 +101,13 @@ def wikitext2_ppl( elif quant=="int4": print("running int4 quantization") - quantize_(model, int4_weight_only(group_size=64)) + # quantize_(model, uintx_weight_only(torch.uint4, group_size=64)) + quantize_(model, int4_weight_only(group_size=group_size)) if compile: model = torch.compile(model) - return wiki2_eval(model, tokenizer, sequence_length) + return wiki2_eval(model, tokenizer, 1024) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") From 8ab016a7708a53c36a8e81568192b11eb19daad8 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 12 Sep 2024 11:18:47 -0400 Subject: [PATCH 26/69] make eval use less vram --- torchao/prototype/awq/example.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index a1f55cc02f..aa18462260 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -34,10 +34,9 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): def wiki2_eval(model, tokenizer, sequence_length): testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") - testenc = testenc.input_ids.to(model.device) + testenc = testenc.input_ids nsamples = 100 model = model.eval() - model(testenc[:, :sequence_length].to(model.device)) # calculate perplexity nlls = [] for i in tqdm(range(nsamples), desc="evaluating..."): @@ -46,10 +45,10 @@ def wiki2_eval(model, tokenizer, sequence_length): ) with torch.no_grad(): lm_logits = model(batch).logits + batch = batch.to("cpu") + lm_logits = lm_logits.to("cpu") shift_logits = lm_logits[:, :-1, :].contiguous().float() - shift_labels = testenc[ - :, i : i + sequence_length - ][:, 1:] + shift_labels = testenc[:, i : i + sequence_length][:, 1:].to("cpu") loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) @@ -116,12 +115,12 @@ def wikitext2_ppl( # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") - parser.add_argument("--calibration_size", type=int, default=10, help="Calibration size. Default is 10.") - parser.add_argument("--validation_size", type=int, default=10, help="Validation size. Default is 10.") + parser.add_argument("--calibration_samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") + parser.add_argument("--validation_size", type=int, default=1, help="Validation size. 
Default is 1.") parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") - parser.add_argument("--sequence_length", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") + parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") args = parser.parse_args() @@ -131,12 +130,12 @@ def wikitext2_ppl( ppl = wikitext2_ppl( repo_id=args.repo, quant=args.quant, - calibration_size=args.calibration_size, + calibration_size=args.calibration_samples, validation_size=args.validation_size, group_size= args.group_size, device=args.device, precision=precision_dtype, - sequence_length=args.sequence_length, + sequence_length=args.seq_len, compile=args.compile ) From 27c062d27722ab8432a5df1f3542315ced60e944 Mon Sep 17 00:00:00 2001 From: Vayuda <120random.things@gmail.com> Date: Thu, 12 Sep 2024 13:54:47 -0400 Subject: [PATCH 27/69] updated uintx import --- torchao/prototype/awq/api.py | 2 +- torchao/prototype/awq/core.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 89b13b1e7a..20f450dc21 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -8,7 +8,7 @@ ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH +from torchao.dtypes.uintx.uintx import _DTYPE_TO_BIT_WIDTH from torchao.prototype.awq.core import( AWQObserver, ObservedLinear, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 6d4da827db..291f1459ac 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F -from torchao.dtypes.uintx.Uintx import to_uintx +from torchao.dtypes.uintx.uintx import to_uintx from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, LayoutType, From 9d52c93838788d80d1a229de3ebce8db3a8abfb6 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sat, 21 Sep 2024 18:16:51 -0700 Subject: [PATCH 28/69] add awq to generate --- torchao/_models/llama/eval.py | 1 - torchao/_models/llama/generate.py | 26 ++++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 7475763f29..3c71012455 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -145,7 +145,6 @@ def run_evaluation( ) is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) - pass else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 5fb905dbf9..99938881a3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -229,6 +229,32 @@ def main( quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) if 
"fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) + if quantization.startswith("awq"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model=model.to(device) + # get calibration data + insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + with torch.no_grad(): + TransformerEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=calibration_tasks, + limit=calibration_limit, + ) + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. "uintx-2-64" if "hqq" in quantization: From 310138e9ecff6fc77ed25cb1cf7fef75b2681f95 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sat, 21 Sep 2024 18:32:13 -0700 Subject: [PATCH 29/69] add calibration params to cli --- torchao/_models/llama/generate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 99938881a3..c4de79cee3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -161,6 +161,8 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + calibration_limit: int = 10, + calibration_seq_length: int = 256, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, @@ -246,7 +248,7 @@ def main( TransformerEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=calibration_seq_length, + max_seq_length=args.calibration_seq_length, input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( @@ -444,6 +446,8 @@ def callback(x): +'autoquant-int4, uintx--, uintx---hqq, sparse-marlin' ) ) + parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples") + parser.add_argument("--calibration_sequence_length", type=int, default=256, help="Sequence length for calibration") parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -459,5 +463,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.temperature, 
args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_sequence_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) From a1b2bd0079291e6303d22a9807efe2a56d781859 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sat, 21 Sep 2024 18:34:52 -0700 Subject: [PATCH 30/69] fix name --- torchao/_models/llama/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index c4de79cee3..350856dc18 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -248,7 +248,7 @@ def main( TransformerEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.calibration_seq_length, + max_seq_length=calibration_seq_length, input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( @@ -447,7 +447,7 @@ def callback(x): ) ) parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples") - parser.add_argument("--calibration_sequence_length", type=int, default=256, help="Sequence length for calibration") + parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration") parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -463,5 +463,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_sequence_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) From 9379686582a2f764b909734cb21256ad05fbf5ef Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 24 Sep 2024 21:44:34 -0700 Subject: [PATCH 31/69] pass linear properly --- scripts/create_weight_map.py | 43 ++++++++++++++++++++++++++++++++++++ torchao/prototype/awq/api.py | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 scripts/create_weight_map.py diff --git a/scripts/create_weight_map.py b/scripts/create_weight_map.py new file mode 100644 index 0000000000..334465a9a2 --- /dev/null +++ b/scripts/create_weight_map.py @@ -0,0 +1,43 @@ +import json +import torch +from transformers import AutoModel +from pathlib import Path +def create_weight_map(checkpoint_dir: Path): + """ + This function, create_weight_map, generates a mapping of a model's weights to a file (pytorch_model.bin) + and saves this mapping, along with the model's total size, 
to a JSON file (pytorch_model.bin.index.json). + The model is loaded from a pre-trained model specified by model_name. + This weight map is used by the HF conversion script (convert_hf_checkpoint.py). + """ + # Load the model + model_name = checkpoint_dir.parent.name +"/"+ checkpoint_dir.name + print(model_name) + model = AutoModel.from_pretrained(model_name) + # Get the state dict + state_dict = model.state_dict() + # Create the weight map + weight_map = {} + for key, tensor in state_dict.items(): + # In this example, we're assuming all weights are in a single file + # You may need to adjust this if your model uses sharded weights + weight_map[key] = "pytorch_model.bin" + # Create the index dictionary + index_dict = { + "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())}, + "weight_map": weight_map + } + # Save the index dictionary to a JSON file + with open(f"{checkpoint_dir}/pytorch_model.bin.index.json", "w") as f: + json.dump(index_dict, f, indent=2) + print("Created pytorch_model.bin.index.json") + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Create weight map for hf model') + parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/Xenova/llama2.c-stories15M")) + + + args = parser.parse_args() + create_weight_map( + args.checkpoint_dir + ) \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 20f450dc21..fb7b3fd34c 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -106,7 +106,7 @@ def weight_quant_func(observed_linear): layout_type = AwqLayoutType(equalization_scale, quant_dtype) return to_affine_quantized_intx( - observed_linear.weight, + observed_linear, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, From 8d173dfba77e1edbca0cbfcdaa97f3b673223ce4 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 24 Sep 2024 21:58:50 -0700 Subject: [PATCH 32/69] recast W*eq_scale to original dtype --- torchao/prototype/awq/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 291f1459ac..49293c45b9 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -168,7 +168,7 @@ class AwqLayoutType(LayoutType): dtype: torch.dtype def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input * self.equalization_scale + return (input * self.equalization_scale).to(input.dtype) def post_process(self, input: torch.Tensor) -> torch.Tensor: # pack weights for sub dtype bit size From 6aab8f88855812aa36a367276977a3a07a00f01e Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 24 Sep 2024 21:59:35 -0700 Subject: [PATCH 33/69] revert bad change --- torchao/prototype/awq/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index fb7b3fd34c..20f450dc21 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -106,7 +106,7 @@ def weight_quant_func(observed_linear): layout_type = AwqLayoutType(equalization_scale, quant_dtype) return to_affine_quantized_intx( - observed_linear, + observed_linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, From 4e976113123c410ab446d4e85eb1fc0a9aa8e5a4 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 24 Sep 2024 22:08:51 -0700 Subject: 
[PATCH 34/69] make scales same type as model --- torchao/prototype/awq/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 49293c45b9..72008b3a8a 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -113,7 +113,7 @@ def calculate_qparams(self): best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = self.average.pow(ratio) + scales = self.average.pow(ratio).to(self.inputs[0].dtype) scales = scales / (scales.max() * scales.min()).sqrt() layout = AwqLayoutType(scales, self.target_dtype) # regardless of weight dtype, we have to store as packed uint8 tensors From 77db01f438b409ff24d980037bf44b9873a5febf Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 25 Sep 2024 13:14:01 -0700 Subject: [PATCH 35/69] compatible with compile --- test/prototype/test_awq.py | 10 +++- torchao/prototype/awq/api.py | 4 +- torchao/prototype/awq/core.py | 104 +++++++++++++++++++++++++++------- 3 files changed, 95 insertions(+), 23 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 8bb03cb514..1d70c9b2e1 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -30,8 +30,13 @@ def forward(self, x): qdtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) else: qdtypes = () - -idtypes = (torch.bfloat16,)#, torch.half, torch.float32) + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + +idtypes = (torch.half,)#, torch.half, torch.float32) @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.parametrize("idtype", idtypes) @@ -62,6 +67,7 @@ def test(device, qdtype, idtype): # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + m = torch.compile(m, fullgraph=True) awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) assert awq_out is not None diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 20f450dc21..7e46d1920b 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -103,10 +103,10 @@ def weight_quant_func(observed_linear): preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT - layout_type = AwqLayoutType(equalization_scale, quant_dtype) + layout_type = AwqLayoutType(quant_dtype, equalization_scale) return to_affine_quantized_intx( - observed_linear.weight, + observed_linear.weight * equalization_scale, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 72008b3a8a..4b2de194a8 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -3,13 +3,13 @@ import torch import torch.nn.functional as F - +from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.uintx.uintx import to_uintx from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, LayoutType, register_layout_cls, - PlainAQTLayout, + AQTLayout, register_aqt_quantized_linear_dispatch ) @@ -113,13 +113,13 @@ def calculate_qparams(self): best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = 
self.average.pow(ratio).to(self.inputs[0].dtype) + scales = self.average.pow(ratio).to(self.weight.dtype) scales = scales / (scales.max() * scales.min()).sqrt() - layout = AwqLayoutType(scales, self.target_dtype) + layout = AwqLayoutType(self.target_dtype, scales) # regardless of weight dtype, we have to store as packed uint8 tensors tensor_dtype = torch.uint8 w = to_affine_quantized_intx( - self.weight.data, + self.weight*scales, self.mapping_type, self.block_size, tensor_dtype, @@ -164,11 +164,8 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): @dataclass(frozen=True) class AwqLayoutType(LayoutType): - equalization_scale: torch.Tensor dtype: torch.dtype - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return (input * self.equalization_scale).to(input.dtype) + equalization_scale: torch.Tensor def post_process(self, input: torch.Tensor) -> torch.Tensor: # pack weights for sub dtype bit size @@ -176,23 +173,80 @@ def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype) return input - def _quantized_linear_impl(input_tensor, weight_tensor, bias): - # divide activations by awq scales - # print(input_tensor.dtype, weight_tensor.layout_tensor.layout_type.equalization_scale.dtype, weight_tensor.dequantize().dtype, bias.dtype) - return F.linear(input_tensor / weight_tensor.layout_tensor.layout_type.equalization_scale, weight_tensor.dequantize(), bias) - - def _linear_awq_check(input_tensor, weight_tensor, bias): - return isinstance(weight_tensor.layout_tensor, AwqAQTLayout) +def _quantized_linear_impl(input_tensor, weight_tensor, bias): + # divide activations by awq scales + return F.linear(input_tensor / weight_tensor.layout_tensor.equalization_scale, weight_tensor.dequantize(), bias) -register_aqt_quantized_linear_dispatch(AwqLayoutType._linear_awq_check, AwqLayoutType._quantized_linear_impl) +def _linear_awq_check(input_tensor, weight_tensor, bias): + return isinstance(weight_tensor.layout_tensor, AwqAQTLayout) + +register_aqt_quantized_linear_dispatch(_linear_awq_check, _quantized_linear_impl) @register_layout_cls(AwqLayoutType) -class AwqAQTLayout(PlainAQTLayout): +class AwqAQTLayout(AQTLayout): + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + equalization_scale: torch.Tensor, + layout_type: LayoutType, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + equalization_scale: torch.Tensor, + layout_type: LayoutType, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.equalization_scale = equalization_scale + self.layout_type = layout_type + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # unpack if needed w = self.int_data if self.layout_type.dtype == torch.uint8 else self.int_data.get_plain() return w, self.scale, self.zero_point + def __tensor_flatten__(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return ["int_data", "scale", "zero_point", "equalization_scale"], [self.layout_type] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, 
outer_size, outer_stride + ): + int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] + equalization_scale = tensor_data_dict["equalization_scale"] + layout_type, = tensor_attributes + return cls(int_data, scale, zero_point, equalization_scale, layout_type) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"AwqAQTLayout dispatch: attempting to run {func}, this is not supported" + ) + @classmethod def from_plain( cls, @@ -202,4 +256,16 @@ def from_plain( layout_type: LayoutType, ): assert isinstance(layout_type, AwqLayoutType) - return cls(int_data, scale, zero_point, layout_type) \ No newline at end of file + return cls(int_data, scale, zero_point, layout_type.equalization_scale, layout_type) + + def get_layout_type(self) -> LayoutType: + return self.layout_type + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + self.zero_point = fn(self.zero_point) + self.equalization_scale = fn(self.equalization_scale) + return self + +to_awq = AwqAQTLayout.from_plain \ No newline at end of file From 588e81ef804565fec719ad6e5393cf0c8eafe7c0 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 25 Sep 2024 14:51:57 -0700 Subject: [PATCH 36/69] cast eq scale to bf16 --- torchao/prototype/awq/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 7e46d1920b..c9b55ff695 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -92,7 +92,7 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" def weight_quant_func(observed_linear): # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams() + equalization_scale = observed_linear.act_obs.calculate_qparams().to(observed_linear.weight.dtype) # AQT config target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC From bfa5797a16b84080e59a51df82e66f4fd269a26e Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 25 Sep 2024 15:18:50 -0700 Subject: [PATCH 37/69] switch calibration dataset --- torchao/_models/llama/eval.py | 14 ++++---------- torchao/_models/llama/generate.py | 15 +++++---------- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 3c71012455..e5530aa587 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -133,16 +133,10 @@ def run_evaluation( # get calibration data insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) with torch.no_grad(): - TransformerEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=calibration_seq_length, - input_prep_func=prepare_inputs_for_model, - device=device, - ).run_eval( - tasks=calibration_tasks, - limit=calibration_limit, - ) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) else: diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 350856dc18..064ca707ac 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -245,16 +245,11 @@ def main( # get calibration data insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) with torch.no_grad(): - TransformerEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=calibration_seq_length, - input_prep_func=prepare_inputs_for_model, - device=device, - ).run_eval( - tasks=calibration_tasks, - limit=calibration_limit, - ) + with torch.no_grad(): + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if "uintx" in quantization: From 41a621b8cfc93f6c923589f47668d6e501571a34 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 25 Sep 2024 15:19:27 -0700 Subject: [PATCH 38/69] remove extra line --- torchao/_models/llama/generate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 064ca707ac..4738fa27b7 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -245,7 +245,6 @@ def main( # get calibration data insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) with torch.no_grad(): - with torch.no_grad(): calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) for batch in 
calibration_data: model(batch.to(device)) From 20767d5ab4db5211622987f922e9fb707c6482e5 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 25 Sep 2024 15:26:06 -0700 Subject: [PATCH 39/69] add import --- torchao/_models/llama/eval.py | 1 + torchao/_models/llama/generate.py | 1 + 2 files changed, 2 insertions(+) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index e5530aa587..11ee9e6209 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -122,6 +122,7 @@ def run_evaluation( if "awq" in quantization: from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + from torchao.prototype.awq.example import get_calib_dataset if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 4738fa27b7..3f9527c151 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -234,6 +234,7 @@ def main( if quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + from torchao.prototype.awq.example import get_calib_dataset if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() From ae32c7cdf39735c7d895d783d6871505ab30c69e Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 18:32:46 -0700 Subject: [PATCH 40/69] added save/load scales --- test/prototype/test_awq.py | 21 ++++++++++--- torchao/prototype/awq/__init__.py | 2 +- torchao/prototype/awq/api.py | 51 ++++++++++++++++++++++++++++--- torchao/prototype/awq/core.py | 1 + 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 1d70c9b2e1..b871669c0f 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -1,11 +1,12 @@ from copy import deepcopy +import os import pytest import torch from torchao.quantization import quantize_ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 if TORCH_VERSION_AT_LEAST_2_3: - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear, save_equalization_scales, load_equalization_scales_and_quantize_ class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -22,12 +23,11 @@ def forward(self, x): x = self.linear2(x) x = self.linear3(x) return x - devices = ["cuda"] # torch.uintx dtypes are introduced in 2.3 if TORCH_VERSION_AT_LEAST_2_3: - qdtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) + qdtypes = (torch.uint4,)#torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) else: qdtypes = () @@ -53,6 +53,7 @@ def test(device, qdtype, idtype): sequence_length = 5 m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m_save_load = deepcopy(m) m_bf16 = deepcopy(m) dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) @@ -60,14 +61,26 @@ def test(device, qdtype, idtype): bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) # calibrate - insert_awq_observer_(m, n_validation_examples,sequence_length, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + 
insert_awq_observer_(m_save_load, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + for example in calibration_data: m(example.to(device)) + m_save_load(example.to(device)) + equalization_scale_path = "equalization_scales.pt" + save_equalization_scales(m_save_load, equalization_scale_path) + load_equalization_scales_and_quantize_(m_save_load, equalization_scale_path) # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) m = torch.compile(m, fullgraph=True) + m_save_load = torch.compile(m_save_load, fullgraph=True) awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + awq_save_load_out = torch.cat([m_save_load(i.squeeze(0)) for i in dataset]) assert awq_out is not None + assert awq_save_load_out is not None + assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) + # remove equalization scale file + os.remove(equalization_scale_path) diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index e3a2a37454..6be48a9cb8 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,2 +1,2 @@ -from .api import insert_awq_observer_, awq_uintx +from .api import insert_awq_observer_, awq_uintx, save_equalization_scales, load_equalization_scales_and_quantize_ from .core import ObservedLinear \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index c9b55ff695..6b3b41947d 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F - +import json from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -14,7 +14,7 @@ ObservedLinear, AwqLayoutType ) - +from typing import Dict, Optional assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. 
Please upgrade to torch 2.3+" @@ -79,6 +79,46 @@ def insert_subclass(observed_linear): return linear return insert_subclass + +def save_equalization_scales(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]: + result = {} + + def recurse(module: torch.nn.Module, name: str = ''): + for child_name, child in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + + # Apply the analysis function to this layer + if isinstance(child, ObservedLinear): + result[full_name] = child.act_obs.calculate_qparams() + + # Recurse into child modules + recurse(child, full_name) + + recurse(model) + + torch.save(result, save_path) + +def load_equalization_scales_and_quantize_(model: torch.nn.Module, equalization_scale_path: str, quant_dtype: torch.dtype = torch.uint4, group_size: int = 128, device=None) -> torch.nn.Module: + equalization_scales = torch.load(equalization_scale_path) + + def recurse(module: torch.nn.Module, name: str = ''): + if isinstance(module, ObservedLinear): + module.equalization_scale = equalization_scales[name] + if device is not None: + module.to(device=device) # move to device before quantization + module = awq_uintx(quant_dtype, group_size)(module) + else: + for child_name, child in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + new_child = recurse(child, full_name) + if new_child is not child: + setattr(model, full_name, new_child) + if device is not None: + module.to(device=device) + return module + + recurse(model) + def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): """ @@ -92,13 +132,16 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" def weight_quant_func(observed_linear): # weight quantization - equalization_scale = observed_linear.act_obs.calculate_qparams().to(observed_linear.weight.dtype) # AQT config + if observed_linear.equalization_scale is None: + equalization_scale = observed_linear.act_obs.calculate_qparams() + else: + equalization_scale = observed_linear.equalization_scale target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] eps = torch.finfo(torch.float32).eps preserve_zero = True zero_point_dtype = torch.int64 diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 4b2de194a8..6e65157e8d 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -148,6 +148,7 @@ class ObservedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): super().__init__(in_features, out_features, bias, device, dtype) self.act_obs = act_obs + self.equalization_scale = None def forward(self, input: torch.Tensor): output = F.linear(input, self.weight, self.bias) From e2160aefb8d799bc6ee60a62b3163f96391e74e5 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:01:40 -0700 Subject: [PATCH 41/69] add save/store workflow to example --- torchao/prototype/awq/example.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index aa18462260..3f5cfaf29b 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -4,7 +4,7 @@ from datasets import load_dataset from tqdm import tqdm import time -from torchao.prototype.awq.api import insert_awq_observer_, ObservedLinear, awq_uintx +from torchao.prototype.awq import insert_awq_observer_, ObservedLinear, awq_uintx, save_equalization_scales, load_equalization_scales_and_quantize_ from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only @@ -92,10 +92,17 @@ def wikitext2_ppl( batch.to("cpu") print(f"time for calibration: {time.time() - t0:.02f} seconds") - # use awq_uintx() to apply awq quantization - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - t0 = time.time() - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + print(f"running {quant_dtype} quantization") + if scale_store_path is not None: + print(f"Saving equalization scales to {scale_store_path}") + save_equalization_scales(scale_store_path) + load_equalization_scales_and_quantize_(model, quant_dtype=quant_dtype, group_size=group_size) + else: + # use awq_uintx() to apply awq quantization + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + t0 = time.time() + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") elif quant=="int4": @@ -122,6 +129,7 @@ def wikitext2_ppl( parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. 
Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + parser.add_argument("scale_store_path", type=str, default= None, help="Path to store the scale values.") args = parser.parse_args() From 9704c386b6f60f98bba4426e8f5f06becc8773ef Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:03:28 -0700 Subject: [PATCH 42/69] add arg to fn --- torchao/prototype/awq/example.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 3f5cfaf29b..6af3af5b11 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -144,7 +144,8 @@ def wikitext2_ppl( device=args.device, precision=precision_dtype, sequence_length=args.seq_len, - compile=args.compile + compile=args.compile, + scale_store_path=args.scale_store_path ) print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file From 04ef51db7ff2217eccf60a86b7bae82f2421c166 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:04:48 -0700 Subject: [PATCH 43/69] fix cli arg --- torchao/prototype/awq/example.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 6af3af5b11..1b854041d7 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -69,7 +69,8 @@ def wikitext2_ppl( device="cuda", precision=torch.bfloat16, sequence_length=2048, - compile=False): + compile=False + scale_store_path=None): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() @@ -129,7 +130,7 @@ def wikitext2_ppl( parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - parser.add_argument("scale_store_path", type=str, default= None, help="Path to store the scale values.") + parser.add_argument("scale_store_path", type=str, default=None, help="Path to store the scale values.") args = parser.parse_args() From 17e2fbca3a04634911b1a5d636ea56ff691611e2 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:04:56 -0700 Subject: [PATCH 44/69] fix cli arg --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 1b854041d7..dc1e2ba4ac 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -130,7 +130,7 @@ def wikitext2_ppl( parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. 
Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - parser.add_argument("scale_store_path", type=str, default=None, help="Path to store the scale values.") + parser.add_argument("--scale_store_path", type=str, default=None, help="Path to store the scale values.") args = parser.parse_args() From 71f9e27dc41226184a84d3cb1a3721507289886b Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:05:15 -0700 Subject: [PATCH 45/69] add comma --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index dc1e2ba4ac..4b2208ba58 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -69,7 +69,7 @@ def wikitext2_ppl( device="cuda", precision=torch.bfloat16, sequence_length=2048, - compile=False + compile=False, scale_store_path=None): print(f"Loading model on {device}...") torch.manual_seed(34) From 1716b0c6c8cc298a879f7470619208fbb6a942d6 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:05:51 -0700 Subject: [PATCH 46/69] add model to fn call --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 4b2208ba58..132fe3a713 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -96,7 +96,7 @@ def wikitext2_ppl( print(f"running {quant_dtype} quantization") if scale_store_path is not None: print(f"Saving equalization scales to {scale_store_path}") - save_equalization_scales(scale_store_path) + save_equalization_scales(model, scale_store_path) load_equalization_scales_and_quantize_(model, quant_dtype=quant_dtype, group_size=group_size) else: # use awq_uintx() to apply awq quantization From 436fb9f43692618c3ccf26b621e57b308ff3bb29 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 27 Sep 2024 19:06:47 -0700 Subject: [PATCH 47/69] fix example --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 132fe3a713..a2f842d012 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -97,7 +97,7 @@ def wikitext2_ppl( if scale_store_path is not None: print(f"Saving equalization scales to {scale_store_path}") save_equalization_scales(model, scale_store_path) - load_equalization_scales_and_quantize_(model, quant_dtype=quant_dtype, group_size=group_size) + load_equalization_scales_and_quantize_(model, scale_store_path, quant_dtype=quant_dtype, group_size=group_size) else: # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) From 84b407ca25df9f3b655f5519c46e5f1a5b96050b Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 1 Oct 2024 15:04:31 -0700 Subject: [PATCH 48/69] refactored awq impl --- test/dtypes/test_uintx.py | 2 +- test/prototype/test_awq.py | 21 +++-- torchao/dtypes/uintx/__init__.py | 1 + torchao/prototype/awq/__init__.py | 5 +- torchao/prototype/awq/api.py | 60 ++---------- torchao/prototype/awq/core.py | 126 +------------------------ torchao/prototype/awq/example.py | 26 +++-- torchao/prototype/awq/layout.py | 151 ++++++++++++++++++++++++++++++ 8 files changed, 193 insertions(+), 199 deletions(-) create mode 
100644 torchao/prototype/awq/layout.py diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 2d689a0c09..bb754135db 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -4,7 +4,7 @@ import torch -from torchao.dtypes.uintx.uintx import to_uintx +from torchao.dtypes.uintx import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index b871669c0f..22d7a811aa 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -6,7 +6,7 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 if TORCH_VERSION_AT_LEAST_2_3: - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear, save_equalization_scales, load_equalization_scales_and_quantize_ + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -53,7 +53,6 @@ def test(device, qdtype, idtype): sequence_length = 5 m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) - m_save_load = deepcopy(m) m_bf16 = deepcopy(m) dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) @@ -62,25 +61,27 @@ def test(device, qdtype, idtype): # calibrate insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) - insert_awq_observer_(m_save_load, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) for example in calibration_data: m(example.to(device)) - m_save_load(example.to(device)) - equalization_scale_path = "equalization_scales.pt" - save_equalization_scales(m_save_load, equalization_scale_path) - load_equalization_scales_and_quantize_(m_save_load, equalization_scale_path) + # quantize is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + + model_save_path = "awq_model.pt" + torch.save(m, model_save_path) + loaded_model = torch.load(model_save_path) + os.remove(model_save_path) + m = torch.compile(m, fullgraph=True) - m_save_load = torch.compile(m_save_load, fullgraph=True) + loaded_model = torch.compile(loaded_model, fullgraph=True) awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - awq_save_load_out = torch.cat([m_save_load(i.squeeze(0)) for i in dataset]) + awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset]) assert awq_out is not None assert awq_save_load_out is not None assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) # remove equalization scale file - os.remove(equalization_scale_path) + diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index e69de29bb2..ad9166079c 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -0,0 +1 @@ +from .uintx import UintxTensor, UintxLayoutType, UintxAQTLayout, to_uintx, _DTYPE_TO_BIT_WIDTH \ No newline at end of file diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 6be48a9cb8..1d26e0ef21 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,2 +1,3 @@ -from .api import insert_awq_observer_, awq_uintx, save_equalization_scales, load_equalization_scales_and_quantize_ -from .core import ObservedLinear \ No newline at end of file +from .api 
import insert_awq_observer_, awq_uintx +from .core import ObservedLinear +from .layout import to_weight_tensor_with_equalization_scales \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 6b3b41947d..0645e8040d 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,20 +1,21 @@ +from typing import Dict, Optional + import torch import torch.nn.functional as F -import json + from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.uintx import _DTYPE_TO_BIT_WIDTH -from torchao.prototype.awq.core import( +from .core import( AWQObserver, ObservedLinear, - AwqLayoutType ) -from typing import Dict, Optional +from .layout import to_weight_tensor_with_equalization_scales assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" @@ -80,45 +81,6 @@ def insert_subclass(observed_linear): return insert_subclass -def save_equalization_scales(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]: - result = {} - - def recurse(module: torch.nn.Module, name: str = ''): - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - - # Apply the analysis function to this layer - if isinstance(child, ObservedLinear): - result[full_name] = child.act_obs.calculate_qparams() - - # Recurse into child modules - recurse(child, full_name) - - recurse(model) - - torch.save(result, save_path) - -def load_equalization_scales_and_quantize_(model: torch.nn.Module, equalization_scale_path: str, quant_dtype: torch.dtype = torch.uint4, group_size: int = 128, device=None) -> torch.nn.Module: - equalization_scales = torch.load(equalization_scale_path) - - def recurse(module: torch.nn.Module, name: str = ''): - if isinstance(module, ObservedLinear): - module.equalization_scale = equalization_scales[name] - if device is not None: - module.to(device=device) # move to device before quantization - module = awq_uintx(quant_dtype, group_size)(module) - else: - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - new_child = recurse(child, full_name) - if new_child is not child: - setattr(model, full_name, new_child) - if device is not None: - module.to(device=device) - return module - - recurse(model) - def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): """ @@ -133,10 +95,7 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): def weight_quant_func(observed_linear): # weight quantization # AQT config - if observed_linear.equalization_scale is None: - equalization_scale = observed_linear.act_obs.calculate_qparams() - else: - equalization_scale = observed_linear.equalization_scale + equalization_scale = observed_linear.act_obs.calculate_qparams() target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -146,9 +105,9 @@ def weight_quant_func(observed_linear): preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT - layout_type = AwqLayoutType(quant_dtype, equalization_scale) + layout_type = UintxLayoutType(quant_dtype) - return to_affine_quantized_intx( + qw = 
to_affine_quantized_intx( observed_linear.weight * equalization_scale, mapping_type, block_size, target_dtype, quant_min, @@ -157,6 +116,7 @@ def weight_quant_func(observed_linear): preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) + return to_weight_tensor_with_equalization_scales(qw, equalization_scale) return _observed_linear_subclass_inserter(weight_quant_func) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 6e65157e8d..a03747c5ec 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -3,16 +3,10 @@ import torch import torch.nn.functional as F -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.uintx.uintx import to_uintx -from torchao.dtypes.affine_quantized_tensor import ( - to_affine_quantized_intx, - LayoutType, - register_layout_cls, - AQTLayout, - register_aqt_quantized_linear_dispatch -) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType +from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -115,7 +109,7 @@ def calculate_qparams(self): ratio = i * 1 / self.scale_options scales = self.average.pow(ratio).to(self.weight.dtype) scales = scales / (scales.max() * scales.min()).sqrt() - layout = AwqLayoutType(self.target_dtype, scales) + layout = UintxLayoutType(self.target_dtype) # regardless of weight dtype, we have to store as packed uint8 tensors tensor_dtype = torch.uint8 w = to_affine_quantized_intx( @@ -148,7 +142,6 @@ class ObservedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): super().__init__(in_features, out_features, bias, device, dtype) self.act_obs = act_obs - self.equalization_scale = None def forward(self, input: torch.Tensor): output = F.linear(input, self.weight, self.bias) @@ -160,113 +153,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias - return observed_linear - - -@dataclass(frozen=True) -class AwqLayoutType(LayoutType): - dtype: torch.dtype - equalization_scale: torch.Tensor - - def post_process(self, input: torch.Tensor) -> torch.Tensor: - # pack weights for sub dtype bit size - if self.dtype != torch.uint8: - return to_uintx(input, self.dtype) - return input - -def _quantized_linear_impl(input_tensor, weight_tensor, bias): - # divide activations by awq scales - return F.linear(input_tensor / weight_tensor.layout_tensor.equalization_scale, weight_tensor.dequantize(), bias) - -def _linear_awq_check(input_tensor, weight_tensor, bias): - return isinstance(weight_tensor.layout_tensor, AwqAQTLayout) - -register_aqt_quantized_linear_dispatch(_linear_awq_check, _quantized_linear_impl) - -@register_layout_cls(AwqLayoutType) -class AwqAQTLayout(AQTLayout): - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - equalization_scale: torch.Tensor, - layout_type: LayoutType, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - 
kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - equalization_scale: torch.Tensor, - layout_type: LayoutType, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self.equalization_scale = equalization_scale - self.layout_type = layout_type - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # unpack if needed - w = self.int_data if self.layout_type.dtype == torch.uint8 else self.int_data.get_plain() - return w, self.scale, self.zero_point - - def __tensor_flatten__(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return ["int_data", "scale", "zero_point", "equalization_scale"], [self.layout_type] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - equalization_scale = tensor_data_dict["equalization_scale"] - layout_type, = tensor_attributes - return cls(int_data, scale, zero_point, equalization_scale, layout_type) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is torch.ops.aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"AwqAQTLayout dispatch: attempting to run {func}, this is not supported" - ) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - assert isinstance(layout_type, AwqLayoutType) - return cls(int_data, scale, zero_point, layout_type.equalization_scale, layout_type) - - def get_layout_type(self) -> LayoutType: - return self.layout_type - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = fn(self.scale) - self.zero_point = fn(self.zero_point) - self.equalization_scale = fn(self.equalization_scale) - return self - -to_awq = AwqAQTLayout.from_plain \ No newline at end of file + return observed_linear \ No newline at end of file diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index a2f842d012..e46846bbce 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -4,7 +4,7 @@ from datasets import load_dataset from tqdm import tqdm import time -from torchao.prototype.awq import insert_awq_observer_, ObservedLinear, awq_uintx, save_equalization_scales, load_equalization_scales_and_quantize_ +from torchao.prototype.awq import insert_awq_observer_, ObservedLinear, awq_uintx, from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only @@ -70,7 +70,7 @@ def wikitext2_ppl( precision=torch.bfloat16, sequence_length=2048, compile=False, - scale_store_path=None): + model_save_path=None): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() @@ -94,21 +94,17 @@ def wikitext2_ppl( print(f"time for calibration: {time.time() - t0:.02f} seconds") print(f"running {quant_dtype} quantization") - if scale_store_path is not None: - print(f"Saving equalization scales to {scale_store_path}") - save_equalization_scales(model, scale_store_path) - 
load_equalization_scales_and_quantize_(model, scale_store_path, quant_dtype=quant_dtype, group_size=group_size) - else: - # use awq_uintx() to apply awq quantization - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - t0 = time.time() - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + # use awq_uintx() to apply awq quantization + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + t0 = time.time() + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") - + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) elif quant=="int4": print("running int4 quantization") - # quantize_(model, uintx_weight_only(torch.uint4, group_size=64)) quantize_(model, int4_weight_only(group_size=group_size)) if compile: @@ -130,7 +126,7 @@ def wikitext2_ppl( parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") - parser.add_argument("--scale_store_path", type=str, default=None, help="Path to store the scale values.") + parser.add_argument("--model_save_path", type=str, default=None, help="Path to store the scale values.") args = parser.parse_args() @@ -146,7 +142,7 @@ def wikitext2_ppl( precision=precision_dtype, sequence_length=args.seq_len, compile=args.compile, - scale_store_path=args.scale_store_path + scale_store_path=args.model_save_path ) print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/layout.py b/torchao/prototype/awq/layout.py new file mode 100644 index 0000000000..d4a1f20cf1 --- /dev/null +++ b/torchao/prototype/awq/layout.py @@ -0,0 +1,151 @@ +import torch +from typing import Callable, Optional, Dict, Any +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.utils import ( + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_5, +) +from torchao.dtypes import AffineQuantizedTensor + + +aten = torch.ops.aten + + +class WeightTensorWithEqualizationScales(TorchAOBaseTensor): + """ + Tensor subclass that wraps a quantized weight tensor and provides the equalization scales which are applied to activations. + + Args: + quantized_weight_tensor (torch.Tensor): The weight tensor to be wrapped. + scale (torch.Tensor): The scale tensor for activation quantization. + zero_point (Optional[torch.Tensor]): The zero point tensor for activation quantization. Default is None. + equalization_scale (torch.Tensor): The equalization scale tensor. 
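+
+        Note: at dispatch time the equalization scale is applied to the activations rather
+        than to the stored weight, i.e. the linear op computes
+        torch.nn.functional.linear(input / equalization_scale, quantized_weight_tensor.dequantize(), bias).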
+ """ + + quantized_weight_tensor: AffineQuantizedTensor + equalization_scale: torch.Tensor + + def __new__( + cls, + quantized_weight_tensor: torch.Tensor, + equalization_scale: torch.Tensor + ): + kwargs = {} + dtype = quantized_weight_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + kwargs["device"] = quantized_weight_tensor.device + shape = quantized_weight_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + quantized_weight_tensor: torch.Tensor, + equalization_scale: torch.Tensor + ): + self.quantized_weight_tensor = quantized_weight_tensor + self.equalization_scale = equalization_scale + + def __repr__(self): + return f"LinearActivationQuantizedTensor({self.quantized_weight_tensor}, eq_scale={self.equalization_scale})" + + def __tensor_flatten__(self): + tensor_data = ["quantized_weight_tensor", "equalization_scale"] + return tensor_data, [] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + quantized_weight_tensor = tensor_data_dict["quantized_weight_tensor"] + equalization_scale = tensor_data_dict["equalization_scale"] + return cls( + quantized_weight_tensor, + equalization_scale, + ) + + @staticmethod + def _quantized_linear_op( + input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor + ): + return torch.nn.functional.linear( + input_tensor / weight_tensor.equalization_scale, weight_tensor.quantized_weight_tensor.dequantize(), bias + ) + + @classmethod + def from_quantized( + cls, + quantized_weight_tensor: AffineQuantizedTensor, + equalization_scale: torch.Tensor + ): + return cls(quantized_weight_tensor, equalization_scale) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.quantized_weight_tensor), + fn(self.equalization_scale), + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.quantized_weight_tensor.to(device), + self.equalization_scale.to(device), + ) + + +implements = WeightTensorWithEqualizationScales.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(weight_tensor, WeightTensorWithEqualizationScales): + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + + raise NotImplementedError( + "LinearActivationQuantizedTensor: No specialized dispatch found for linear op" + ) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + +to_weight_tensor_with_equalization_scales = WeightTensorWithEqualizationScales.from_quantized +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with LinearActivationQuantizedTensor weights to 
be loaded with `weights_only=True` + torch.serialization.add_safe_globals( + [WeightTensorWithEqualizationScales] + ) \ No newline at end of file From 1216f97f1a0da31476bc770e3c7c1973bc57c61a Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 1 Oct 2024 17:48:53 -0700 Subject: [PATCH 49/69] edits +update usage --- scripts/hf_eval.py | 38 +++++++++--------- test/prototype/test_awq.py | 4 +- torchao/_models/_eval.py | 9 ----- torchao/_models/llama/eval.py | 22 ----------- torchao/_models/llama/generate.py | 19 +++++---- torchao/dtypes/affine_quantized_tensor.py | 3 +- torchao/prototype/awq/__init__.py | 2 +- torchao/prototype/awq/api.py | 47 +++++++++++++---------- torchao/prototype/awq/core.py | 14 +++---- torchao/prototype/awq/example.py | 4 +- torchao/prototype/awq/layout.py | 4 +- 11 files changed, 73 insertions(+), 93 deletions(-) diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index 43bdb10fca..2a971d23b3 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -65,25 +65,25 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars elif quantization == "autoquant": model = autoquant(model.to(device=device)) elif quantization == "awq": - from datasets import load_dataset - from tqdm import tqdm - from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant - - insert_awq_observer(model, precision, device) - wikitext103 = load_dataset("wikitext", "wikitext-103-v1") - wikitext103_train = wikitext103["train"] - wikitext103_calibration = wikitext103_train.select(range(1)) - calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]] - model.to(device) - print("running awq calibration") - for i, ids in tqdm(enumerate(calibration_input_ids)): - if ids.shape[-1] == 0: - continue - model(ids.to(device)) - - - is_observed_linear = lambda m, fqn: isinstance(model, ObservedLinear) - quantize_(model, awq_quant, is_observed_linear) + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: + print("AWQ quantization requires torch2.3+") + exit() + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + quant_dtype = torch.uint4 + group_size = 64 + calibration_limit = 10 + calibration_seq_length = 1024 + model=model.to(device) + insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + with torch.no_grad(): + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if quantization != "autoquant" and compile: model = torch.compile(model, mode= "max-autotune", fullgraph=True) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 22d7a811aa..475d3e2f50 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -6,7 +6,7 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 if TORCH_VERSION_AT_LEAST_2_3: - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, 
n=256, k=128): @@ -67,7 +67,7 @@ def test(device, qdtype, idtype): # quantize - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) model_save_path = "awq_model.pt" diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index cd2452635b..d2f8088d95 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -287,15 +287,6 @@ def _model_call(self, inps): return torch.randn( (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device ) - - def _model_call(self, inps): - input = self.input_prep_func(inps.to(self._device)) - - max_seq_length = min(max(inps.size()), self.max_length) - with torch.device(self._device): - self.model_.setup_caches(self.batch_size, max_seq_length) - logits = self.model_(*input) - return logits # pad or truncate to the right size if T >= self.calibration_seq_length: diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 11ee9e6209..a392df0f40 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -119,27 +119,6 @@ def run_evaluation( quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) - - if "awq" in quantization: - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - from torchao.prototype.awq.example import get_calib_dataset - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear - quant_dtype = quantization.split("-")[1] - group_size = int(quantization.split("-")[2]) - quant_dtype = getattr(torch, quant_dtype, torch.uint8) - model=model.to(device) - # get calibration data - insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) - with torch.no_grad(): - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) @@ -231,7 +210,6 @@ def run_evaluation( "int4wo--gptq, autoquant, autoquant-int4, int4wo--hqq, " "uintx--, uintx---hqq, sparse-marlin, " "autoround---------" - "awq-uint-" ), ) parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 3f9527c151..1a4d046196 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -238,18 +238,23 @@ def main( if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, ObservedLinear + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear quant_dtype = quantization.split("-")[1] group_size = int(quantization.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.uint8) model=model.to(device) # get calibration data - insert_awq_observer_(model,calibration_limit, calibration_seq_length, 
quant_dtype=quant_dtype, group_size=group_size) - with torch.no_grad(): - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") + insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=['wikitext'], + limit=calibration_limit, + ) is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if "uintx" in quantization: diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 60ec35d6e9..3dc632cd0e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -28,14 +28,12 @@ is_device, get_out_shape, ) - from torchao.float8.inference import ( preprocess_data, Float8MMConfig, addmm_float8_unwrapped_inference, _is_rowwise_scaled ) - from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass from torchao.utils import ( @@ -1463,6 +1461,7 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " ) + # TODO: check groupsize quantization # avoid circular dep, TODO: move this to a common util.py act_mat = input_tensor diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 1d26e0ef21..6ba1ffe694 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,3 +1,3 @@ from .api import insert_awq_observer_, awq_uintx -from .core import ObservedLinear +from .core import AWQObservedLinear from .layout import to_weight_tensor_with_equalization_scales \ No newline at end of file diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 0645e8040d..bc7cb74dbe 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Callable import torch import torch.nn.functional as F @@ -8,12 +8,13 @@ ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) +from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import to_affine_quantized_intx from .core import( AWQObserver, - ObservedLinear, + AWQObservedLinear, ) from .layout import to_weight_tensor_with_equalization_scales @@ -36,7 +37,7 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val # AQT config mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + quantization_granularity = PerGroup(group_size) quant_min = 0 quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 eps = torch.finfo(torch.float32).eps @@ -46,11 +47,11 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" def replace_with_observer(layer): - # creates observer and replaces linear layers with observed linear layers + # creates observer and replaces linear layers with AWQObservedLinear layers observer = AWQObserver( layer.weight, layer.bias, - block_size, + quantization_granularity, mapping_type, quant_dtype, n_validation_examples, @@ -62,15 +63,15 @@ def replace_with_observer(layer): quant_min=quant_min, quant_max = quant_max, eps = eps) - return ObservedLinear.from_float(layer, observer) + return AWQObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) def _observed_linear_subclass_inserter(constructor): """ - Replaces unquantized observed linear instances with quantized linear instances. + Replaces unquantized AWQObservedLinear instances with quantized linear instances. Args: - constructor: the function which applies quantization to the observed linear layer + constructor: the function which applies quantization to the AWQObservedLinear layer """ def insert_subclass(observed_linear): # creates the new linear layer using constructor @@ -82,13 +83,16 @@ def insert_subclass(observed_linear): return insert_subclass -def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 128): +def awq_uintx(quant_dtype: torch.dtype = torch.uint4, + group_size: int = 128, + weight_quant_fn: Optional[Callable[[torch.Tensor], torch.Tensor]]= None): """ Quantizes linear layers when passed into quantize_() Args: quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 group_size: Quantization granularity. Use -1 for channel wise quantization + weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used """ assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" @@ -98,7 +102,7 @@ def weight_quant_func(observed_linear): equalization_scale = observed_linear.act_obs.calculate_qparams() target_dtype = torch.uint8 mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + quantization_granularity = PerGroup(group_size) quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] eps = torch.finfo(torch.float32).eps @@ -106,16 +110,19 @@ def weight_quant_func(observed_linear): zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT layout_type = UintxLayoutType(quant_dtype) - - qw = to_affine_quantized_intx( - observed_linear.weight * equalization_scale, - mapping_type, block_size, - target_dtype, quant_min, - quant_max, eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - layout_type=layout_type) + if weight_quant_fn is not None: + qw = weight_quant_fn(observed_linear.weight * equalization_scale) + else: + # usage according to original paper + qw = to_affine_quantized_intx( + observed_linear.weight * equalization_scale, + mapping_type, (1, quantization_granularity.group_size), + target_dtype, quant_min, + quant_max, eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + layout_type=layout_type) return to_weight_tensor_with_equalization_scales(qw, equalization_scale) return _observed_linear_subclass_inserter(weight_quant_func) diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index a03747c5ec..77810a2e4a 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -12,7 +12,7 @@ ZeroPointDomain, ) from torchao.quantization.observer import ( - AffineQuantizedObserverBase, PerGroup + AffineQuantizedObserverBase, GranularityType ) @@ -20,7 +20,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, bias: torch.Tensor, - block_size: Tuple, + quantization_granularity: GranularityType, mapping_type: MappingType, target_dtype: torch.dtype, n_validation_examples: int, @@ -40,7 +40,7 @@ def __init__(self, Args: weight: The weight tensor to be observed. bias: The bias tensor to be observed. - block_size: The weight tensor shape after being reshaped to support per group quantization + quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point input_dtype: The data type of the input tensor. 
mapping_type: Always set to asymmetric target_dtype: The target data type of the quantized tensor @@ -59,7 +59,7 @@ def __init__(self, super().__init__( mapping_type, target_dtype, - PerGroup(block_size[-1]), + quantization_granularity, quant_min = quant_min, quant_max = quant_max, eps = eps, @@ -68,7 +68,7 @@ def __init__(self, preserve_zero = preserve_zero, zero_point_domain = zero_point_domain, ) - self.block_size = block_size + self.quantization_granularity = quantization_granularity self.weight = weight self.bias = bias self.n_validation_examples = n_validation_examples @@ -115,7 +115,7 @@ def calculate_qparams(self): w = to_affine_quantized_intx( self.weight*scales, self.mapping_type, - self.block_size, + (1, self.quantization_granularity.group_size), tensor_dtype, quant_min = self.quant_min, quant_max = self.quant_max, @@ -138,7 +138,7 @@ def calculate_qparams(self): self.outputs[i].to("cpu") return best_scales.detach() -class ObservedLinear(torch.nn.Linear): +class AWQObservedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): super().__init__(in_features, out_features, bias, device, dtype) self.act_obs = act_obs diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index e46846bbce..f68bd26194 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -4,7 +4,7 @@ from datasets import load_dataset from tqdm import tqdm import time -from torchao.prototype.awq import insert_awq_observer_, ObservedLinear, awq_uintx, +from torchao.prototype.awq import insert_awq_observer_, AWQObservedLinear, awq_uintx, from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only @@ -95,7 +95,7 @@ def wikitext2_ppl( print(f"running {quant_dtype} quantization") # use awq_uintx() to apply awq quantization - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) t0 = time.time() quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) diff --git a/torchao/prototype/awq/layout.py b/torchao/prototype/awq/layout.py index d4a1f20cf1..90bb8c8c30 100644 --- a/torchao/prototype/awq/layout.py +++ b/torchao/prototype/awq/layout.py @@ -22,7 +22,7 @@ class WeightTensorWithEqualizationScales(TorchAOBaseTensor): equalization_scale (torch.Tensor): The equalization scale tensor. 
""" - quantized_weight_tensor: AffineQuantizedTensor + quantized_weight_tensor: TorchAOBaseTensor equalization_scale: torch.Tensor def __new__( @@ -69,7 +69,7 @@ def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): return torch.nn.functional.linear( - input_tensor / weight_tensor.equalization_scale, weight_tensor.quantized_weight_tensor.dequantize(), bias + input_tensor / weight_tensor.equalization_scale, weight_tensor.quantized_weight_tensor, bias ) @classmethod From 39306601c9ee1e59dcef02c12bb57fe0d519c984 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 1 Oct 2024 19:56:15 -0700 Subject: [PATCH 50/69] perplexity evals added --- torchao/_models/llama/eval.py | 2 +- torchao/prototype/awq/api.py | 27 +++++++++++++++------------ torchao/prototype/awq/example.py | 30 ++++++++++++++++++++++++++---- torchao/prototype/awq/readme.md | 28 +++++++++++++++++++--------- 4 files changed, 61 insertions(+), 26 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index a392df0f40..cc80d40db2 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -112,7 +112,7 @@ def run_evaluation( model.config.vocab_size, device="cpu" ).record_inputs( - ["pile_hackernews"], + calibration_tasks, calibration_limit, ).get_inputs() diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index bc7cb74dbe..1e29c25a23 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -84,7 +84,7 @@ def insert_subclass(observed_linear): def awq_uintx(quant_dtype: torch.dtype = torch.uint4, - group_size: int = 128, + group_size: int = 64, weight_quant_fn: Optional[Callable[[torch.Tensor], torch.Tensor]]= None): """ Quantizes linear layers when passed into quantize_() @@ -95,25 +95,28 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used """ - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" + def weight_quant_func(observed_linear): # weight quantization # AQT config equalization_scale = observed_linear.act_obs.calculate_qparams() - target_dtype = torch.uint8 - mapping_type = MappingType.ASYMMETRIC - quantization_granularity = PerGroup(group_size) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - layout_type = UintxLayoutType(quant_dtype) if weight_quant_fn is not None: qw = weight_quant_fn(observed_linear.weight * equalization_scale) else: # usage according to original paper + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" + + target_dtype = torch.uint8 + mapping_type = MappingType.ASYMMETRIC + quantization_granularity = PerGroup(group_size) + quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + layout_type = UintxLayoutType(quant_dtype) + qw = to_affine_quantized_intx( observed_linear.weight * equalization_scale, mapping_type, (1, quantization_granularity.group_size), diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index f68bd26194..c70b1f5d17 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -4,7 +4,7 @@ from datasets import load_dataset from tqdm import tqdm import time -from torchao.prototype.awq import insert_awq_observer_, AWQObservedLinear, awq_uintx, +from torchao.prototype.awq import insert_awq_observer_, AWQObservedLinear, awq_uintx from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only @@ -86,6 +86,7 @@ def wikitext2_ppl( t0 = time.time() # insert observers to find average magnitude and calculate scales + insert_awq_observer_(model,validation_size, sequence_length, quant_dtype=quant_dtype, group_size=group_size) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) for batch in calibration_data: @@ -96,8 +97,27 @@ def wikitext2_ppl( print(f"running {quant_dtype} quantization") # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + + from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, + ) + from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType t0 = time.time() - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + def hqqint4(weight): + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: @@ -106,7 +126,9 @@ def wikitext2_ppl( elif quant=="int4": print("running int4 quantization") quantize_(model, int4_weight_only(group_size=group_size)) - + elif quant=="hqq": + print("running int4-hqq quantization") + quantize_(model,int4_weight_only(group_size=group_size, use_hqq=True)) if compile: model = torch.compile(model) @@ -142,7 +164,7 @@ def wikitext2_ppl( precision=precision_dtype, sequence_length=args.seq_len, compile=args.compile, - scale_store_path=args.model_save_path + model_save_path=args.model_save_path ) print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md index 502da11fe7..9bf5e1bfe3 100644 --- a/torchao/prototype/awq/readme.md +++ 
b/torchao/prototype/awq/readme.md @@ -4,13 +4,23 @@ Adapted from https://github.com/mit-han-lab/llm-awq ## Benchmarks Benchmarks are run on a machine with a single RTX 3090 GPU using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. Group size of 64 was used for both int4wo and awq-uint4. For awq quantization, c refers to number of calibration sequences and v refers to number of validation sequences. Calibration data is used to find the average magnitude of activations while validation data is used to find optimal equilization scales. Note c is always larger than v. Calibration data came from Pile dataset validation split. -| Model | Quantization | wikitext2-perplexity | -| ------------------ | ------------------------ | ------------------- | -| GPT-2 | Base (bfloat16) | 30.1904 | -| | int4wo (tinygemm kernel) | 519.73108 | -| | awq-uint4 | 485.54907 | -| Llama-2-7b-hf | Base (bfloat16) | 5.47367 | -| | int4wo (tinygemm kernel) | 5.73546 | -| | awq-uint4-c1-v1 | 5.72359 | -| | awq-uint4-c10-v1 | 5.72350 | +| Model | Quantization | Perplexity | +|--------------------|--------------|------------| +| Llama-2-7b-chat-hf | bfloat16 | 5.0309 | +| | awq-uint4 | 5.2388 | +| | int4 | 5.28 | +| | awq-hqq | 5.204 | +| | hqq | 5.3419 | +| Llama-3-8b | bfloat16 | 4.6269 | +| | awq-uint4 | 4.968 | +| | int4 | 5.04325 | +| | awq-hqq | 4.8525 | +| | hqq | 5.1277 | +| Llama-3.1-8b | bfloat16 | 4.69732 | +| | awq-uint4 | 4.98163 | +| | int4 | 5.04091 | +| | awq-hqq | 4.90632 | +| | hqq | 5.14375 | + + From 3e5710c8787687eec5867bd9c712538b4c52cf2d Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Tue, 1 Oct 2024 23:56:56 -0700 Subject: [PATCH 51/69] updated readme with benchmarks --- test/prototype/test_awq.py | 4 ++-- torchao/prototype/awq/example.py | 30 ++++++++++++++++-------------- torchao/prototype/awq/readme.md | 16 ++++++++-------- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 475d3e2f50..b76b788ec9 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -27,7 +27,7 @@ def forward(self, x): devices = ["cuda"] # torch.uintx dtypes are introduced in 2.3 if TORCH_VERSION_AT_LEAST_2_3: - qdtypes = (torch.uint4,)#torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) + qdtypes = (torch.uint3, torch.uint4, torch.uint8) else: qdtypes = () @@ -36,7 +36,7 @@ def run_before_and_after_tests(): yield torch._dynamo.reset() # reset cache between tests -idtypes = (torch.half,)#, torch.half, torch.float32) +idtypes = (torch.half, torch.bfloat16)#, torch.half, torch.float32) @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.parametrize("idtype", idtypes) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index c70b1f5d17..b55c4e9c8e 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -105,20 +105,22 @@ def wikitext2_ppl( ) from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType t0 = time.time() - def hqqint4(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - - return 
to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) - + if "hqq" in quant: + def hqqint4(weight): + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) + else: + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: print(f"Saving model to {model_save_path}") diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md index 9bf5e1bfe3..cba3023577 100644 --- a/torchao/prototype/awq/readme.md +++ b/torchao/prototype/awq/readme.md @@ -2,15 +2,15 @@ Adapted from https://github.com/mit-han-lab/llm-awq ## Benchmarks -Benchmarks are run on a machine with a single RTX 3090 GPU using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. Group size of 64 was used for both int4wo and awq-uint4. For awq quantization, c refers to number of calibration sequences and v refers to number of validation sequences. Calibration data is used to find the average magnitude of activations while validation data is used to find optimal equilization scales. Note c is always larger than v. Calibration data came from Pile dataset validation split. +Evaluation perplexity numbers were calculated using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the torchao/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. 
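For reference, the end-to-end flow these benchmarks exercise looks roughly like the sketch below. It mirrors the calls added in scripts/hf_eval.py and torchao/prototype/awq/example.py; the model id, device, and calibration settings are illustrative placeholders, not values fixed by this patch.

```python
# Sketch of the AWQ quantization flow (placeholder model id, device, and calibration settings).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
from torchao.prototype.awq.example import get_calib_dataset

device = "cuda"
model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

quant_dtype = torch.uint4
group_size = 64
calibration_limit = 10         # number of calibration sequences
calibration_seq_length = 1024  # tokens per calibration sequence

# 1) swap nn.Linear for AWQObservedLinear and record activation magnitudes on calibration data
insert_awq_observer_(model, calibration_limit, calibration_seq_length,
                     quant_dtype=quant_dtype, group_size=group_size)
with torch.no_grad():
    for batch in get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit,
                                   block_size=calibration_seq_length):
        model(batch.to(device))

# 2) quantize only the observed linear layers; the learned equalization scales are stored
#    alongside the quantized weight and applied to the activations at inference time
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
```

The quantized model can then be saved with torch.save and reloaded, as exercised in test/prototype/test_awq.py.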
-| Model | Quantization | Perplexity | -|--------------------|--------------|------------| -| Llama-2-7b-chat-hf | bfloat16 | 5.0309 | -| | awq-uint4 | 5.2388 | -| | int4 | 5.28 | -| | awq-hqq | 5.204 | -| | hqq | 5.3419 | +| Model | Quantization | Perplexity | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) | +|--------------------|--------------|------------|------------|---------------------|---------------|-----------------| +| Llama-2-7b-chat-hf | bfloat16 | 5.0309 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | awq-uint4 | 5.2388 | 43.59 | 194.93 | 7.31 | 4.47 | +| | int4 | 5.28 | 201.14 | 751.42 | 4.87 | 3.74 | +| | awq-hqq | 5.204 | 196.6 | 761.2 | 5.05 | 3.87 | +| | hqq | 5.3419 | 209.19 | 804.32 | 4.89 | 3.84 | | Llama-3-8b | bfloat16 | 4.6269 | | | awq-uint4 | 4.968 | | | int4 | 5.04325 | From da6a70d1ef34a7701d91e25a08a300b9324f975f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 2 Oct 2024 00:01:02 -0700 Subject: [PATCH 52/69] add awq-hqq to generate --- torchao/_models/llama/generate.py | 25 +++++++++++++++++++++++-- torchao/prototype/awq/example.py | 13 +++++++------ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 1a4d046196..cc2f6fe212 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -255,8 +255,29 @@ def main( tasks=['wikitext'], limit=calibration_limit, ) - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + if "hqq" in quant: + from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, + ) + from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType + def hqqint4(weight): + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) + else: + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. 
"uintx-2-64" if "hqq" in quantization: diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index b55c4e9c8e..5122296035 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -98,14 +98,15 @@ def wikitext2_ppl( # use awq_uintx() to apply awq quantization is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, - _DTYPE_TO_QVALUE_BOUNDS, - ) - from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType + t0 = time.time() if "hqq" in quant: + from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, + ) + from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType def hqqint4(weight): mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) From 7314d993bc861c829b893e091cdb11917ce4f3d1 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 2 Oct 2024 00:11:32 -0700 Subject: [PATCH 53/69] better citation --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 5122296035..de27bee836 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -8,7 +8,7 @@ from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only -# adapted from: https://github.com/mit-han-lab/llm-awq +# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") samples = [] From 68d95923c631e8a8e94a7f2d22381be2c2045634 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 2 Oct 2024 22:45:34 -0700 Subject: [PATCH 54/69] nits --- scripts/hf_eval.py | 2 +- test/prototype/test_awq.py | 64 ++++-- torchao/_models/llama/generate.py | 2 +- torchao/prototype/awq/README.md | 29 +++ torchao/prototype/awq/api.py | 8 +- torchao/prototype/awq/example.py | 183 +++++++++++++----- torchao/prototype/awq/readme.md | 26 --- .../quantization/linear_activation_scale.py | 3 +- 8 files changed, 221 insertions(+), 96 deletions(-) create mode 100644 torchao/prototype/awq/README.md delete mode 100644 torchao/prototype/awq/readme.md diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index 2a971d23b3..1d4935f1b9 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -81,7 +81,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) for batch in calibration_data: model(batch.to(device)) - batch.to("cpu") + del batch is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index b76b788ec9..d7aa27c011 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -24,10 +24,10 @@ def forward(self, x): x = self.linear3(x) return x -devices = ["cuda"] +devices = ["cpu", "cuda"] # torch.uintx dtypes are introduced in 2.3 if TORCH_VERSION_AT_LEAST_2_3: - qdtypes = (torch.uint3, torch.uint4, torch.uint8) + qdtypes = (torch.uint4, torch.uint7) 
else: qdtypes = () @@ -36,13 +36,12 @@ def run_before_and_after_tests(): yield torch._dynamo.reset() # reset cache between tests -idtypes = (torch.half, torch.bfloat16)#, torch.half, torch.float32) +idtypes = (torch.half, torch.bfloat16) @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.parametrize("idtype", idtypes) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") -def test(device, qdtype, idtype): +def test_awq_loading(device, qdtype, idtype): dataset_size = 100 l1,l2,l3 = 512,256,128 original_dtype = idtype @@ -53,11 +52,8 @@ def test(device, qdtype, idtype): sequence_length = 5 m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) - m_bf16 = deepcopy(m) - dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) calibration_data = dataset[:n_calibration_examples] - bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0) # calibrate insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) @@ -70,18 +66,62 @@ def test(device, qdtype, idtype): is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) - model_save_path = "awq_model.pt" + model_save_path = "awq_model.pth" torch.save(m, model_save_path) loaded_model = torch.load(model_save_path) os.remove(model_save_path) - m = torch.compile(m, fullgraph=True) - loaded_model = torch.compile(loaded_model, fullgraph=True) + if device == "cuda": + m = torch.compile(m, fullgraph=True) + loaded_model = torch.compile(loaded_model, fullgraph=True) + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset]) assert awq_out is not None assert awq_save_load_out is not None assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) - # remove equalization scale file + +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_save_weights_only(): + dataset_size = 100 + l1,l2,l3 = 512,256,128 + original_dtype = torch.half + quant_dtype = torch.uint4 + device = "cuda" + group_size = 128 + n_calibration_examples = 10 + n_validation_examples = 10 + sequence_length = 5 + + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m2 = deepcopy(m) + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) + calibration_data = dataset[:n_calibration_examples] + + # calibrate + insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + + for example in calibration_data: + m(example.to(device)) + + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + + model_save_path = "awq_model.pth" + torch.save(m.state_dict(), model_save_path) + m2.load_state_dict(torch.load(model_save_path), assign=True) # load weights only.torch.load(model_save_path) + os.remove(model_save_path) + m = torch.compile(m, fullgraph=True) + m2 = torch.compile(m2, fullgraph=True) + + awq_out = torch.cat([m(i.squeeze(0)) for i in 
dataset]) + awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset]) + + assert awq_out is not None + assert awq_save_load_out is not None + assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) \ No newline at end of file diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index a78baa80f5..3e3fde7040 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -256,7 +256,7 @@ def main( limit=calibration_limit, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - if "hqq" in quant: + if "hqq" in quant and quant_dtype == torch.uint4: from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, diff --git a/torchao/prototype/awq/README.md b/torchao/prototype/awq/README.md new file mode 100644 index 0000000000..e7b7f782f7 --- /dev/null +++ b/torchao/prototype/awq/README.md @@ -0,0 +1,29 @@ +# AWQ Quantization +Adapted from https://github.com/mit-han-lab/llm-awq + +## Benchmarks +Evaluation perplexity numbers were calculated using the script in awq/example.py Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the torchao/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. + +| Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) | +|--------------------|--------------|------------|---------------------|---------------|-----------------| +| Llama-2-7b-chat-hf | bfloat16 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | awq-hqq-int4 | 196.6 | 761.2 | 5.05 | 3.87 | +| | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 | +| | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 | +| | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 | + + + +The following tests were performed using LM eval and groupsize = 128 +| Model | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge | +| Llama-3-8B-Instruct| bfloat16 | 10.936 | 0.540 | 0.783 | 0.567 | +| | awq-hqq-int4 | 11.383 | 0.522 | 0.772 | 0.543 | +| | awq-uint4 | 11.409 | 0.519 | 0.756 | 0.577 | +| | int4wo-hqq | 11.905 | 0.528 | 0.757 | 0.563 | +| | int4wo-128 | 12.380 | 0.502 | 0.753 | 0.548 | + + + + + + diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 1e29c25a23..b8b393a290 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -8,6 +8,7 @@ ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) +from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType @@ -16,7 +17,6 @@ AWQObserver, AWQObservedLinear, ) -from .layout import to_weight_tensor_with_equalization_scales assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" @@ -34,7 +34,7 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val group_size: Quantization granularity. Use -1 for channel wise quantization """ _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" # AQT config mapping_type = MappingType.ASYMMETRIC quantization_granularity = PerGroup(group_size) @@ -44,7 +44,7 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" + def replace_with_observer(layer): # creates observer and replaces linear layers with AWQObservedLinear layers @@ -126,7 +126,7 @@ def weight_quant_func(observed_linear): preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) - return to_weight_tensor_with_equalization_scales(qw, equalization_scale) + return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) return _observed_linear_subclass_inserter(weight_quant_func) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index de27bee836..4be6a75860 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -31,46 +31,125 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): cat_samples = torch.cat(samples, dim=1) return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] -def wiki2_eval(model, tokenizer, sequence_length): - testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt") - testenc = testenc.input_ids - nsamples = 100 - model = model.eval() - # calculate perplexity - nlls = [] - for i in tqdm(range(nsamples), desc="evaluating..."): - batch = testenc[:, i : i + sequence_length].to( - model.device - ) - with torch.no_grad(): - lm_logits = model(batch).logits - batch = batch.to("cpu") - lm_logits = lm_logits.to("cpu") - shift_logits = lm_logits[:, :-1, :].contiguous().float() - shift_labels = testenc[:, i : i + sequence_length][:, 1:].to("cpu") - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - neg_log_likelihood = loss.float() * sequence_length - nlls.append(neg_log_likelihood) - - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * sequence_length)) +# from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py +def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): + model.eval() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + + encodings['input_ids'] = encodings['input_ids'].to('cuda') + + lls, t = [], [] + for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): + begin_loc = max(i + stride - sequence_length, 0) + end_loc = min(i + stride, encodings['input_ids'].size(1)) + trg_len = end_loc - i + input_ids = encodings['input_ids'][:,begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:,:-trg_len] = -100 #ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + torch.cuda.synchronize() + t2 = time.time() + t.append((t2-t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) + pred_time = 
sum(t)/len(t) + if(verbose): + print('perplexity', ppl) + print('time', str(pred_time) + ' sec') + + return {'perplexity':ppl, 'prediction_time':pred_time} + +# from Hicham Badri (@mobicham) +def QA(model, tokenizer): + import numpy as np + import copy + import lm_eval + model.eval(); + model.config.use_cache = False + try: + lm_eval.tasks.initialize_tasks() + except: + pass + model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) + eval_batch_size = 1 #8 + + results = {} + ############################################ + for task in [("truthfulqa_mc2", 0)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + + for task in [("winogrande", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + + for task in [("arc_challenge", 25)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + + # ############################################ + for task in [("hellaswag", 10)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + + for task in [("gsm8k", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + # ############################################ + + results_1 = copy.deepcopy(results) - return ppl + #MMLU + results_mmlu = {} + for task in [("mmlu", 5)]: + tag, fewshot = task + results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results_mmlu[tag]) + + mmlu_list = 
"hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" + mmlu_list = [l.replace('hendrycksTest-','') for l in mmlu_list.split(',')] + results_mmlu = results_mmlu['mmlu'] + + k = [] + for r in results_mmlu: + if np.any([(l in r) for l in mmlu_list]): + k.append(results_mmlu[r]['acc,none']) + + assert len(k)==57 + print('MMLU avg acc', np.mean(k)) + + results['mmlu'] = np.mean(k) + return results + def wikitext2_ppl( repo_id: str, - quant: str, - calibration_size: int =100, - validation_size:int=100, - group_size: int = 128, - device="cuda", - precision=torch.bfloat16, - sequence_length=2048, - compile=False, - model_save_path=None): + quant: str, + benchmark: str, + calibration_size: int, + validation_size:int, + group_size: int, + device: str, + precision:torch.dtype, + sequence_length: int, + compile: bool, + model_save_path: str): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() @@ -78,35 +157,29 @@ def wikitext2_ppl( tokenizer = AutoTokenizer.from_pretrained(repo_id) model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).eval().to(device) print(f"Time to load model: {time.time() - t0:.02f} seconds") - if quant.startswith("awq"): quant_dtype = quant.split("-")[1] quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) print(f"running {quant_dtype} calibration") t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_awq_observer_(model,validation_size, sequence_length, quant_dtype=quant_dtype, group_size=group_size) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) for batch in calibration_data: model(batch.to(device)) batch.to("cpu") print(f"time for calibration: {time.time() - 
t0:.02f} seconds") - - print(f"running {quant_dtype} quantization") - # use awq_uintx() to apply awq quantization - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - - t0 = time.time() + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) if "hqq" in quant: + print(f"running awq-hqq quantization") from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType + # example of using a different quantization function def hqqint4(weight): mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -119,8 +192,12 @@ def hqqint4(weight): zero_point_domain = ZeroPointDomain.FLOAT return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) + t0 = time.time() quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) else: + print(f"running {quant_dtype} quantization") + t0 = time.time() + # use awq_uintx() to apply awq quantization quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: @@ -134,8 +211,12 @@ def hqqint4(weight): quantize_(model,int4_weight_only(group_size=group_size, use_hqq=True)) if compile: model = torch.compile(model) - - return wiki2_eval(model, tokenizer, 1024) + if benchmark == "QA": + return QA(model, tokenizer) + elif benchmark == "PPL": + return wiki2_eval(model, tokenizer, sequence_length) + else: + print("Invalid benchmark specified. Choose either PPL or QA") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") @@ -144,12 +225,13 @@ def hqqint4(weight): # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") + parser.add_argument("--benchmark", type=str, help="Task to benchmark model on. Either PPL or QA", default="QA") parser.add_argument("--calibration_samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") parser.add_argument("--validation_size", type=int, default=1, help="Validation size. Default is 1.") - parser.add_argument("--group_size", type=int, default=128, help="Group size to use for weights. Default is '128'") + parser.add_argument("--group_size", type=int, default=64, help="Group size to use for weights. Default is 64") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") - parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate/evaluate model on. Default 512") + parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. 
Default is 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") parser.add_argument("--model_save_path", type=str, default=None, help="Path to store the scale values.") @@ -160,6 +242,7 @@ def hqqint4(weight): ppl = wikitext2_ppl( repo_id=args.repo, quant=args.quant, + benchmark = args.benchmark, calibration_size=args.calibration_samples, validation_size=args.validation_size, group_size= args.group_size, @@ -170,4 +253,4 @@ def hqqint4(weight): model_save_path=args.model_save_path ) - print(f"{args.quant} Perplexity: {ppl.item():.5f}") \ No newline at end of file + print(f"{args.quant} Perplexity: {ppl.items():.5f}") \ No newline at end of file diff --git a/torchao/prototype/awq/readme.md b/torchao/prototype/awq/readme.md deleted file mode 100644 index cba3023577..0000000000 --- a/torchao/prototype/awq/readme.md +++ /dev/null @@ -1,26 +0,0 @@ -# AWQ Quantization -Adapted from https://github.com/mit-han-lab/llm-awq - -## Benchmarks -Evaluation perplexity numbers were calculated using the script in awq/example.py which calculates perplexity by concatenating wikitex2 test examples with newlines and dividing by context length. Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the torchao/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. - -| Model | Quantization | Perplexity | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) | -|--------------------|--------------|------------|------------|---------------------|---------------|-----------------| -| Llama-2-7b-chat-hf | bfloat16 | 5.0309 | 107.38 | 1418.93 | 13.88 | 13.21 | -| | awq-uint4 | 5.2388 | 43.59 | 194.93 | 7.31 | 4.47 | -| | int4 | 5.28 | 201.14 | 751.42 | 4.87 | 3.74 | -| | awq-hqq | 5.204 | 196.6 | 761.2 | 5.05 | 3.87 | -| | hqq | 5.3419 | 209.19 | 804.32 | 4.89 | 3.84 | -| Llama-3-8b | bfloat16 | 4.6269 | -| | awq-uint4 | 4.968 | -| | int4 | 5.04325 | -| | awq-hqq | 4.8525 | -| | hqq | 5.1277 | -| Llama-3.1-8b | bfloat16 | 4.69732 | -| | awq-uint4 | 4.98163 | -| | int4 | 5.04091 | -| | awq-hqq | 4.90632 | -| | hqq | 5.14375 | - - - diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index f1230ef75b..eee918147c 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -145,8 +145,7 @@ def _(func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) - - + @implements(aten.t.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( From bc7526e41dbf610ab21614b3343de26bdda5e710 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 2 Oct 2024 22:48:21 -0700 Subject: [PATCH 55/69] remove layout.py --- torchao/prototype/awq/__init__.py | 3 +- torchao/prototype/awq/layout.py | 151 ------------------------------ 2 files changed, 1 insertion(+), 153 deletions(-) delete mode 100644 torchao/prototype/awq/layout.py diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 6ba1ffe694..ca9381d575 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,3 +1,2 @@ from .api import insert_awq_observer_, 
awq_uintx -from .core import AWQObservedLinear -from .layout import to_weight_tensor_with_equalization_scales \ No newline at end of file +from .core import AWQObservedLinear \ No newline at end of file diff --git a/torchao/prototype/awq/layout.py b/torchao/prototype/awq/layout.py deleted file mode 100644 index 90bb8c8c30..0000000000 --- a/torchao/prototype/awq/layout.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -from typing import Callable, Optional, Dict, Any -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TorchAOBaseTensor, - TORCH_VERSION_AT_LEAST_2_5, -) -from torchao.dtypes import AffineQuantizedTensor - - -aten = torch.ops.aten - - -class WeightTensorWithEqualizationScales(TorchAOBaseTensor): - """ - Tensor subclass that wraps a quantized weight tensor and provides the equalization scales which are applied to activations. - - Args: - quantized_weight_tensor (torch.Tensor): The weight tensor to be wrapped. - scale (torch.Tensor): The scale tensor for activation quantization. - zero_point (Optional[torch.Tensor]): The zero point tensor for activation quantization. Default is None. - equalization_scale (torch.Tensor): The equalization scale tensor. - """ - - quantized_weight_tensor: TorchAOBaseTensor - equalization_scale: torch.Tensor - - def __new__( - cls, - quantized_weight_tensor: torch.Tensor, - equalization_scale: torch.Tensor - ): - kwargs = {} - dtype = quantized_weight_tensor.dtype - kwargs["dtype"] = dtype - kwargs["requires_grad"] = False - kwargs["device"] = quantized_weight_tensor.device - shape = quantized_weight_tensor.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - quantized_weight_tensor: torch.Tensor, - equalization_scale: torch.Tensor - ): - self.quantized_weight_tensor = quantized_weight_tensor - self.equalization_scale = equalization_scale - - def __repr__(self): - return f"LinearActivationQuantizedTensor({self.quantized_weight_tensor}, eq_scale={self.equalization_scale})" - - def __tensor_flatten__(self): - tensor_data = ["quantized_weight_tensor", "equalization_scale"] - return tensor_data, [] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - quantized_weight_tensor = tensor_data_dict["quantized_weight_tensor"] - equalization_scale = tensor_data_dict["equalization_scale"] - return cls( - quantized_weight_tensor, - equalization_scale, - ) - - @staticmethod - def _quantized_linear_op( - input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor - ): - return torch.nn.functional.linear( - input_tensor / weight_tensor.equalization_scale, weight_tensor.quantized_weight_tensor, bias - ) - - @classmethod - def from_quantized( - cls, - quantized_weight_tensor: AffineQuantizedTensor, - equalization_scale: torch.Tensor - ): - return cls(quantized_weight_tensor, equalization_scale) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.quantized_weight_tensor), - fn(self.equalization_scale), - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.quantized_weight_tensor.to(device), - self.equalization_scale.to(device), - ) - - -implements = WeightTensorWithEqualizationScales.implements - - -@implements(torch.nn.functional.linear) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if 
len(args) > 2 else None, - ) - if isinstance(weight_tensor, WeightTensorWithEqualizationScales): - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - - raise NotImplementedError( - "LinearActivationQuantizedTensor: No specialized dispatch found for linear op" - ) - - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - - -@implements(aten.t.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.t) - ) - -to_weight_tensor_with_equalization_scales = WeightTensorWithEqualizationScales.from_quantized -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithEqualizationScales] - ) \ No newline at end of file From 1fdf068a66b82210d78a68d0b725444fde1c6810 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 21:53:57 -0700 Subject: [PATCH 56/69] quantization func changes --- torchao/_models/llama/generate.py | 24 +---- torchao/prototype/awq/api.py | 56 ++++++----- torchao/prototype/awq/example.py | 150 +++++++++++++----------------- 3 files changed, 98 insertions(+), 132 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 3e3fde7040..7e4708ba5d 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -256,28 +256,8 @@ def main( limit=calibration_limit, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - if "hqq" in quant and quant_dtype == torch.uint4: - from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, - _DTYPE_TO_QVALUE_BOUNDS, - ) - from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType - def hqqint4(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) - else: - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + use_hqq = "hqq" in quantization + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. 
"uintx-2-64" if "hqq" in quantization: diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index b8b393a290..fe17a996dd 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,5 +1,3 @@ -from typing import Dict, Optional, Callable - import torch import torch.nn.functional as F @@ -12,7 +10,10 @@ from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType -from torchao.dtypes import to_affine_quantized_intx +from torchao.dtypes import( + to_affine_quantized_intx, + TensorCoreLayoutType +) from .core import( AWQObserver, AWQObservedLinear, @@ -85,7 +86,7 @@ def insert_subclass(observed_linear): def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: int = 64, - weight_quant_fn: Optional[Callable[[torch.Tensor], torch.Tensor]]= None): + use_hqq: bool = False,): """ Quantizes linear layers when passed into quantize_() @@ -94,38 +95,43 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4, group_size: Quantization granularity. Use -1 for channel wise quantization weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used """ - + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" def weight_quant_func(observed_linear): - # weight quantization - # AQT config equalization_scale = observed_linear.act_obs.calculate_qparams() - if weight_quant_fn is not None: - qw = weight_quant_fn(observed_linear.weight * equalization_scale) + # AQT config + if quant_dtype == torch.uint4: + target_dtype = torch.int32 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + layout_type = TensorCoreLayoutType(inner_k_tiles=8) else: - # usage according to original paper - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" - target_dtype = torch.uint8 - mapping_type = MappingType.ASYMMETRIC - quantization_granularity = PerGroup(group_size) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] eps = torch.finfo(torch.float32).eps preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT layout_type = UintxLayoutType(quant_dtype) + + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] + qw = to_affine_quantized_intx( + observed_linear.weight * equalization_scale, + mapping_type, + block_size, + target_dtype, quant_min, + quant_max, eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + layout_type=layout_type, + use_hqq=use_hqq + ) - qw = to_affine_quantized_intx( - observed_linear.weight * equalization_scale, - mapping_type, (1, quantization_granularity.group_size), - target_dtype, quant_min, - quant_max, eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - layout_type=layout_type) return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) return _observed_linear_subclass_inserter(weight_quant_func) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 4be6a75860..a9b7a963f8 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -70,8 +70,8 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): return {'perplexity':ppl, 'prediction_time':pred_time} -# from Hicham Badri (@mobicham) -def QA(model, tokenizer): +# adapted from Hicham Badri (@mobicham) +def benchmark(model, tokenizer, max_length, tasks=None): import numpy as np import copy import lm_eval @@ -83,65 +83,70 @@ def QA(model, tokenizer): pass model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) eval_batch_size = 1 #8 - + if tasks is None: + tasks = ["PPL","truthfulqa_mc2", "winogrande", "arc_challenge", "hellaswag", "gsm8k", "mmlu"] results = {} + if "PPL" in tasks: + results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True) ############################################ - for task in [("truthfulqa_mc2", 0)]: - tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results[tag]) - - for task in [("winogrande", 5)]: - tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results[tag]) - - for task in [("arc_challenge", 25)]: - tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results[tag]) + if "truthfulqa_mc2" in tasks: + for task in [("truthfulqa_mc2", 0)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if "winogrande" in tasks: + for task in [("winogrande", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if "arc_challenge" in tasks: 
+ for task in [("arc_challenge", 25)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) # ############################################ - for task in [("hellaswag", 10)]: - tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results[tag]) - - for task in [("gsm8k", 5)]: - tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results[tag]) + if "hellaswag" in tasks: + for task in [("hellaswag", 10)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if "gsm8k" in tasks: + for task in [("gsm8k", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) # ############################################ results_1 = copy.deepcopy(results) - - #MMLU - results_mmlu = {} - for task in [("mmlu", 5)]: - tag, fewshot = task - results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] - print(tag, results_mmlu[tag]) - - mmlu_list = "hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" - mmlu_list = [l.replace('hendrycksTest-','') for l in mmlu_list.split(',')] - results_mmlu = results_mmlu['mmlu'] - - k = [] - for r in results_mmlu: - if np.any([(l in 
r) for l in mmlu_list]): - k.append(results_mmlu[r]['acc,none']) - - assert len(k)==57 - print('MMLU avg acc', np.mean(k)) - - results['mmlu'] = np.mean(k) + if "mmlu" in tasks: + #MMLU + results_mmlu = {} + for task in [("mmlu", 5)]: + tag, fewshot = task + results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results_mmlu[tag]) + + mmlu_list = "hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" + mmlu_list = [l.replace('hendrycksTest-','') for l in mmlu_list.split(',')] + results_mmlu = results_mmlu['mmlu'] + + k = [] + for r in results_mmlu: + if np.any([(l in r) for l in mmlu_list]): + k.append(results_mmlu[r]['acc,none']) + + assert len(k)==57 + print('MMLU avg acc', np.mean(k)) + + results['mmlu'] = np.mean(k) return results def wikitext2_ppl( repo_id: str, quant: str, - benchmark: str, + tasks: list[str], calibration_size: int, validation_size:int, group_size: int, @@ -159,6 +164,7 @@ def wikitext2_ppl( print(f"Time to load model: {time.time() - t0:.02f} seconds") if quant.startswith("awq"): quant_dtype = quant.split("-")[1] + group_size = int(quant.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) print(f"running {quant_dtype} calibration") t0 = time.time() @@ -171,50 +177,25 @@ def wikitext2_ppl( print(f"time for calibration: {time.time() - t0:.02f} seconds") is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - if "hqq" in quant: - print(f"running awq-hqq quantization") - from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, - _DTYPE_TO_QVALUE_BOUNDS, - ) - from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType - # example of using a 
different quantization function - def hqqint4(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True) - t0 = time.time() - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, weight_quant_fn=hqqint4), is_observed_linear) - else: - print(f"running {quant_dtype} quantization") - t0 = time.time() - # use awq_uintx() to apply awq quantization - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) + use_hqq = "hqq" in quant + print(f"running {quant_dtype} quantization") + t0 = time.time() + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: print(f"Saving model to {model_save_path}") torch.save(model, model_save_path) - elif quant=="int4": - print("running int4 quantization") + elif quant.startswith("int4wo"): + group_size = int(quant.split("-")[1]) + print(f"running int4 weight only quantization with group size {group_size}") quantize_(model, int4_weight_only(group_size=group_size)) elif quant=="hqq": - print("running int4-hqq quantization") + print("running int4-hqq weight only quantization") quantize_(model,int4_weight_only(group_size=group_size, use_hqq=True)) if compile: model = torch.compile(model) - if benchmark == "QA": - return QA(model, tokenizer) - elif benchmark == "PPL": - return wiki2_eval(model, tokenizer, sequence_length) + + results = benchmark(model, tokenizer, sequence_length, tasks=tasks) else: print("Invalid benchmark specified. Choose either PPL or QA") @@ -225,10 +206,9 @@ def hqqint4(weight): # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") - parser.add_argument("--benchmark", type=str, help="Task to benchmark model on. Either PPL or QA", default="QA") + parser.add_argument("--tasks", type=list[str], help="Task to benchmark model on. Either PPL or QA", default=["PPL"]) parser.add_argument("--calibration_samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") parser.add_argument("--validation_size", type=int, default=1, help="Validation size. Default is 1.") - parser.add_argument("--group_size", type=int, default=64, help="Group size to use for weights. Default is 64") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. 
Default is 512") From 85ea32cede2ae45c3c7758e522d9a5639a6604c4 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:00:27 -0700 Subject: [PATCH 57/69] typo --- torchao/prototype/awq/example.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index a9b7a963f8..f36dbaea9f 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -196,8 +196,6 @@ def wikitext2_ppl( model = torch.compile(model) results = benchmark(model, tokenizer, sequence_length, tasks=tasks) - else: - print("Invalid benchmark specified. Choose either PPL or QA") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") From 0a12f9618b2fb3b9af140a7576735cf3b55a9835 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:03:45 -0700 Subject: [PATCH 58/69] fix import --- torchao/prototype/awq/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index fe17a996dd..e3a8827e2a 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -12,7 +12,7 @@ from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import( to_affine_quantized_intx, - TensorCoreLayoutType + TensorCoreTiledLayoutType, ) from .core import( AWQObserver, @@ -106,7 +106,7 @@ def weight_quant_func(observed_linear): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - layout_type = TensorCoreLayoutType(inner_k_tiles=8) + layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) else: target_dtype = torch.uint8 eps = torch.finfo(torch.float32).eps From 1b5a57b1f993300fe3d8a3ac8c0ffac8a3f7af34 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:05:53 -0700 Subject: [PATCH 59/69] fix fn params --- torchao/prototype/awq/example.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index f36dbaea9f..8efe187918 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -218,17 +218,17 @@ def wikitext2_ppl( # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) ppl = wikitext2_ppl( - repo_id=args.repo, - quant=args.quant, - benchmark = args.benchmark, - calibration_size=args.calibration_samples, - validation_size=args.validation_size, - group_size= args.group_size, - device=args.device, - precision=precision_dtype, - sequence_length=args.seq_len, - compile=args.compile, - model_save_path=args.model_save_path + args.repo, + args.quant, + args.tasks, + args.calibration_samples, + args.validation_size, + args.group_size, + args.device, + args.precision_dtype, + args.seq_len, + args.compile, + args.model_save_path ) print(f"{args.quant} Perplexity: {ppl.items():.5f}") \ No newline at end of file From c193afb9dd02360c9579ff42301e4a24fb599484 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:09:38 -0700 Subject: [PATCH 60/69] edit --- torchao/prototype/awq/example.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 8efe187918..e50ae8a68e 100644 --- a/torchao/prototype/awq/example.py +++ 
b/torchao/prototype/awq/example.py @@ -149,7 +149,6 @@ def wikitext2_ppl( tasks: list[str], calibration_size: int, validation_size:int, - group_size: int, device: str, precision:torch.dtype, sequence_length: int, @@ -187,11 +186,9 @@ def wikitext2_ppl( torch.save(model, model_save_path) elif quant.startswith("int4wo"): group_size = int(quant.split("-")[1]) - print(f"running int4 weight only quantization with group size {group_size}") - quantize_(model, int4_weight_only(group_size=group_size)) - elif quant=="hqq": - print("running int4-hqq weight only quantization") - quantize_(model,int4_weight_only(group_size=group_size, use_hqq=True)) + use_hqq = "hqq" in quant + print(f"running {quant} quantization with group size {group_size}") + quantize_(model, int4_weight_only(group_size=group_size, use_hqq)) if compile: model = torch.compile(model) @@ -203,7 +200,7 @@ def wikitext2_ppl( # Optional arguments with default values parser.add_argument("repo", type=str, help="Repository ID of the model.") - parser.add_argument("quant", type=str, help="Quantization method. Options are either int4 or awq-uintx where x is [1..8]") + parser.add_argument("quant", type=str, help="Quantization method. Options are either awq-uint<x>-<group_size> for x = [1..8], int4wo-<group_size>, or int4wo-<group_size>-hqq.") parser.add_argument("--tasks", type=list[str], help="Task to benchmark model on. Either PPL or QA", default=["PPL"]) parser.add_argument("--calibration_samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") parser.add_argument("--validation_size", type=int, default=1, help="Validation size. Default is 1.") @@ -223,7 +220,6 @@ def wikitext2_ppl( args.tasks, args.calibration_samples, args.validation_size, - args.group_size, args.device, args.precision_dtype, args.seq_len, From 4e60dfd7dd2416eb3b7b4627e831174726a7372c Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:09:58 -0700 Subject: [PATCH 61/69] edit --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index e50ae8a68e..6f94772db2 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -188,7 +188,7 @@ def wikitext2_ppl( group_size = int(quant.split("-")[1]) use_hqq = "hqq" in quant print(f"running {quant} quantization with group size {group_size}") - quantize_(model, int4_weight_only(group_size=group_size, use_hqq)) + quantize_(model, int4_weight_only(group_size=group_size, use_hqq= use_hqq)) if compile: model = torch.compile(model) From 93dcb79cd4b20b57ff05fd0960d2d76a3404cef8 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:10:28 -0700 Subject: [PATCH 62/69] rename --- torchao/prototype/awq/example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 6f94772db2..80b8e368fe 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -221,7 +221,7 @@ def wikitext2_ppl( args.calibration_samples, args.validation_size, args.device, - args.precision_dtype, + args.precision, args.seq_len, args.compile, args.model_save_path From d2ed1f2f37bacd482f2303f78bba6a10910bd6d1 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 4 Oct 2024 22:26:00 -0700 Subject: [PATCH 63/69] fix indentation --- torchao/prototype/awq/example.py | 74 ++++++++++++++---------------- 1 file changed, 37 
insertions(+), 37 deletions(-) diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 80b8e368fe..8b2eb06758 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -33,42 +33,42 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): # from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): - model.eval() - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "right" - tokenizer.add_eos_token = False - - dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') - - encodings['input_ids'] = encodings['input_ids'].to('cuda') - - lls, t = [], [] - for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): - begin_loc = max(i + stride - sequence_length, 0) - end_loc = min(i + stride, encodings['input_ids'].size(1)) - trg_len = end_loc - i - input_ids = encodings['input_ids'][:,begin_loc:end_loc] - target_ids = input_ids.clone() - target_ids[:,:-trg_len] = -100 #ignore context - - t1 = time.time() - with torch.no_grad(): - log_likelihood = model(input_ids, labels=target_ids).loss * trg_len - torch.cuda.synchronize() - t2 = time.time() - t.append((t2-t1)) - lls.append(log_likelihood) - - del input_ids, target_ids - - ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) - pred_time = sum(t)/len(t) - if(verbose): - print('perplexity', ppl) - print('time', str(pred_time) + ' sec') - - return {'perplexity':ppl, 'prediction_time':pred_time} + model.eval() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + + encodings['input_ids'] = encodings['input_ids'].to('cuda') + + lls, t = [], [] + for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): + begin_loc = max(i + stride - sequence_length, 0) + end_loc = min(i + stride, encodings['input_ids'].size(1)) + trg_len = end_loc - i + input_ids = encodings['input_ids'][:,begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:,:-trg_len] = -100 #ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + torch.cuda.synchronize() + t2 = time.time() + t.append((t2-t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) + pred_time = sum(t)/len(t) + if(verbose): + print('perplexity', ppl) + print('time', str(pred_time) + ' sec') + + return {'perplexity':ppl, 'prediction_time':pred_time} # adapted from Hicham Badri (@mobicham) def benchmark(model, tokenizer, max_length, tasks=None): @@ -227,4 +227,4 @@ def wikitext2_ppl( args.model_save_path ) - print(f"{args.quant} Perplexity: {ppl.items():.5f}") \ No newline at end of file + print(f"{args.quant} Results: {ppl}") \ No newline at end of file From 2650702001c64494a00a929cb873a06ba3cf69fb Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sat, 5 Oct 2024 09:19:18 -0700 Subject: [PATCH 64/69] added bf16 gaurd for uint4 tinygemm quant --- test/prototype/test_awq.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_awq.py 
b/test/prototype/test_awq.py index d7aa27c011..f56492d7a7 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -50,7 +50,8 @@ def test_awq_loading(device, qdtype, idtype): n_calibration_examples = 10 n_validation_examples = 10 sequence_length = 5 - + if quant_dtype == torch.uint4 and idtype != torch.bfloat16: + pytest.skip("uint4 is uses tinygemm kernel which is only supported for bfloat16 inputs") m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) calibration_data = dataset[:n_calibration_examples] @@ -71,7 +72,7 @@ def test_awq_loading(device, qdtype, idtype): loaded_model = torch.load(model_save_path) os.remove(model_save_path) - if device == "cuda": + if torch.cuda.is_available(): m = torch.compile(m, fullgraph=True) loaded_model = torch.compile(loaded_model, fullgraph=True) @@ -87,7 +88,7 @@ def test_awq_loading(device, qdtype, idtype): def test_save_weights_only(): dataset_size = 100 l1,l2,l3 = 512,256,128 - original_dtype = torch.half + original_dtype = torch.bfloat16 quant_dtype = torch.uint4 device = "cuda" group_size = 128 From 5e254695a3afe30f5ea4ce1086e26e8b9cff3113 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sun, 6 Oct 2024 13:29:45 -0700 Subject: [PATCH 65/69] remove bad tests --- test/prototype/test_awq.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index f56492d7a7..891c638c62 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -36,22 +36,20 @@ def run_before_and_after_tests(): yield torch._dynamo.reset() # reset cache between tests -idtypes = (torch.half, torch.bfloat16) @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("qdtype", qdtypes) -@pytest.mark.parametrize("idtype", idtypes) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") def test_awq_loading(device, qdtype, idtype): dataset_size = 100 l1,l2,l3 = 512,256,128 - original_dtype = idtype + original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs quant_dtype = qdtype group_size = 128 n_calibration_examples = 10 n_validation_examples = 10 sequence_length = 5 - if quant_dtype == torch.uint4 and idtype != torch.bfloat16: - pytest.skip("uint4 is uses tinygemm kernel which is only supported for bfloat16 inputs") + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) calibration_data = dataset[:n_calibration_examples] From 3aa279f9d38889a4545234aae3d4769c292a5692 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sun, 6 Oct 2024 15:34:39 -0700 Subject: [PATCH 66/69] remove arg --- test/prototype/test_awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 891c638c62..893016ca92 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -40,7 +40,7 @@ def run_before_and_after_tests(): @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") -def test_awq_loading(device, qdtype, 
idtype): +def test_awq_loading(device, qdtype): dataset_size = 100 l1,l2,l3 = 512,256,128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs From c08fdb1b06bcf1c70d545c6d9e7d45a247ca8371 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sun, 6 Oct 2024 18:26:36 -0700 Subject: [PATCH 67/69] one last guard.. --- test/prototype/test_awq.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 893016ca92..063378dbfe 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -41,6 +41,9 @@ def run_before_and_after_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") def test_awq_loading(device, qdtype): + if qdtype == torch.uint4 and device == "cpu": + pytest.skip("uint4 not supported on cpu") + dataset_size = 100 l1,l2,l3 = 512,256,128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs From b967ebd7d2c72a55f510e1a7ac9a25fc87ce1fc8 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 7 Oct 2024 00:39:47 -0700 Subject: [PATCH 68/69] require nightly --- test/prototype/test_awq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 063378dbfe..e2f8b690e3 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -4,7 +4,7 @@ import torch from torchao.quantization import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 if TORCH_VERSION_AT_LEAST_2_3: from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear @@ -39,7 +39,7 @@ def run_before_and_after_tests(): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5,reason="requires nightly pytorch") def test_awq_loading(device, qdtype): if qdtype == torch.uint4 and device == "cpu": pytest.skip("uint4 not supported on cpu") From e7e329b3a312f3bda9e14bc47eeda2e8da04f63e Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 7 Oct 2024 11:09:07 -0700 Subject: [PATCH 69/69] require nightly on everything --- test/prototype/test_awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index e2f8b690e3..eccf8db8f6 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -84,7 +84,7 @@ def test_awq_loading(device, qdtype): assert awq_save_load_out is not None assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3,reason="torch.uint(2-7) requires torch2.3+") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5,reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_save_weights_only(): dataset_size = 100
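
For reference, the end-to-end flow that this patch series sets up (and that test_awq.py and example.py exercise) is: insert the AWQ observers, push calibration batches through the model, then quantize the observed linear layers. The sketch below is illustrative only — the toy model and the random calibration batches are placeholders rather than anything from the patches, and the torch.uint4 path assumes a CUDA device and a recent PyTorch (2.3+/nightly), as the tests above note.

```python
# Minimal AWQ usage sketch (illustrative; model and calibration data are stand-ins).
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

device = "cuda"
sequence_length = 5
n_validation_examples = 10

# Stand-in model: any stack of nn.Linear layers works for the sketch.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256), torch.nn.ReLU(), torch.nn.Linear(256, 128)
).eval().to(torch.bfloat16).to(device)

# 1) Wrap nn.Linear modules with AWQObservedLinear so activation statistics are recorded.
insert_awq_observer_(model, n_validation_examples, sequence_length,
                     quant_dtype=torch.uint4, group_size=64)

# 2) Calibration: plain forward passes populate the observers (random data as a stand-in).
for _ in range(10):
    model(torch.randn(sequence_length, 512, dtype=torch.bfloat16, device=device))

# 3) Quantize only the observed linears; the learned equalization scales are applied
#    to activations via the linear-activation-scale metadata on the weight tensor.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)

out = model(torch.randn(sequence_length, 512, dtype=torch.bfloat16, device=device))
```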