[WIP] Activation Aware Weight Quantization (AWQ) #743

Merged · +807 −9

Commits (83, all by vayuda):
d3db3db  init
1c29347  Merge branch 'awq' of https://github.com/vayuda/ao into awq
ed864a2  fixed implementation
90be216  Merge branch 'pytorch:main' into awq
0e690f9  reduced vmem req
05ae692  Merge branch 'awq' of https://github.com/vayuda/ao into awq
4519792  eval on LLMs
fca7895  Merge branch 'pytorch:main' into awq
0096a83  eval on llm
20c2529  Merge branch 'awq'
7614d51  convert list to tensor
33a28dd  restructuring
7d389b5  revert unecessary hf_eval changes
5913b98  added wikitext eval test
7d045f9  added tinygemm integration
db302ef  made the calibration step much faster
2ec38f1  merge pt1
b400711  Merge remote-tracking branch 'upstream/main' into awq
4aae94b  works/created tutorial
8e058d7  added docs, tests, cleaned code
dced6e5  updated benchmark
c43b997  update example
2388091  update example
8619cd5  added init file
5378ac9  Merge remote-tracking branch 'upstream/main' into awq
9027082  reduce vram for calibration
dbac7c8  eval changes+ llama2 data
aa62e5f  llama2 data + eval script init changes
a4b006f  Merge branch 'main' into awq
7f21bfc  fixed qdtype bounds and example import
fb7fb11  Merge branch 'awq' of https://github.com/vayuda/ao into awq
863d503  fix tests
4ca9117  fix tests
3e70a6f  use rolling log liklihood for eval and calibrate awq with run_eval
c37b396  Merge branch 'pytorch:main' into awq
389ea77  Merge remote-tracking branch 'upstream/main' into awq
8ab016a  make eval use less vram
27c062d  updated uintx import
59b4174  Merge remote-tracking branch 'upstream/main' into awq
9d52c93  add awq to generate
310138e  add calibration params to cli
a1b2bd0  fix name
9379686  pass linear properly
8d173df  recast W*eq_scale to original dtype
6aab8f8  revert bad change
4e97611  make scales same type as model
77db01f  compatible with compile
588e81e  cast eq scale to bf16
bfa5797  switch calibration dataset
41a621b  remove extra line
20767d5  add import
ae32c7c  added save/load scales
e2160ae  add save/store workflow to example
9704c38  add arg to fn
04ef51d  fix cli arg
17e2fbc  fix cli arg
71f9e27  add comma
1716b0c  add model to fn call
436fb9f  fix example
84b407c  refactored awq impl
49fbbe2  Merge branch 'pytorch:main' into awq
1216f97  edits +update usage
3930660  perplexity evals added
3e5710c  updated readme with benchmarks
da6a70d  add awq-hqq to generate
7314d99  better citation
dc0c507  Merge branch 'pytorch:main' into awq
68d9592  nits
bc7526e  remove layout.py
1fdf068  quantization func changes
85ea32c  typo
0a12f96  fix import
1b5a57b  fix fn params
c193afb  edit
4e60dfd  edit
93dcb79  rename
d2ed1f2  fix indentation
2650702  added bf16 gaurd for uint4 tinygemm quant
5e25469  remove bad tests
3aa279f  remove arg
c08fdb1  one last guard..
b967ebd  require nightly
e7e329b  require nightly on everything
@@ -0,0 +1,37 @@
""" | ||
This file produces a file named pytorch_model.bin.index.json based on the downloaded model weights. | ||
It was primarily used to create run evals on llama2.c-stories15M model. | ||
""" | ||
import json | ||
import torch | ||
from transformers import AutoModel | ||
|
||
def create_weight_map(model_name): | ||
# Load the model | ||
model = AutoModel.from_pretrained(model_name) | ||
|
||
# Get the state dict | ||
state_dict = model.state_dict() | ||
|
||
# Create the weight map | ||
weight_map = {} | ||
for key, tensor in state_dict.items(): | ||
# In this example, we're assuming all weights are in a single file | ||
# You may need to adjust this if your model uses sharded weights | ||
weight_map[key] = "pytorch_model.bin" | ||
|
||
# Create the index dictionary | ||
index_dict = { | ||
"metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())}, | ||
"weight_map": weight_map | ||
} | ||
|
||
# Save the index dictionary to a JSON file | ||
with open("pytorch_model.bin.index.json", "w") as f: | ||
json.dump(index_dict, f, indent=2) | ||
|
||
print("Created pytorch_model.bin.index.json") | ||
|
||
# Usage | ||
model_name = "checkpoints/Xenova/llama2.c-stories15M" | ||
create_weight_map(model_name) |
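For reference, the index file this script writes has roughly the following shape (the key names and total_size below are illustrative, not taken from an actual checkpoint):

    {
      "metadata": {"total_size": 60816384},
      "weight_map": {
        "embed_tokens.weight": "pytorch_model.bin",
        "layers.0.attention.wq.weight": "pytorch_model.bin"
      }
    }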
@@ -0,0 +1,60 @@
from copy import deepcopy

import pytest
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for _ in range(batch_size)]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test():
    device = "cuda"
    dataset_size = 1000
    original_dtype = torch.bfloat16
    l1, l2, l3 = 512, 256, 128

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m_bf16 = deepcopy(m)

    dataset = m.example_inputs(dataset_size, dtype=original_dtype, device=device)
    calibration_data = dataset[:100]
    bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0)

    # baseline: plain int4 weight-only quantization, no calibration
    m_int4wo = deepcopy(m)
    quantize_(m_int4wo, int4_weight_only())
    int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset])

    # calibrate: observers record activation statistics on the calibration batches
    quant_dtype = torch.uint4
    group_size = 128
    insert_awq_observer(m, quant_dtype, group_size, original_dtype, device)
    for example in calibration_data:
        m(example.to(device))

    # quantize only the observed linears
    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
    quantize_(m, awq_quant(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    m = torch.compile(m, fullgraph=True)

    # compare accuracy: mean absolute error against the bf16 baseline
    awq_err = torch.abs(awq_out - bf16_out).sum().item() / dataset_size
    int4wo_err = torch.abs(int4wo_out - bf16_out).sum().item() / dataset_size

    print(f"AWQ error: {awq_err}")
    print(f"Int4WO error: {int4wo_err}")
@@ -0,0 +1,114 @@
import torch
from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AWQLayoutType
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.dtypes import to_affine_quantized_intx


# Replace every nn.Linear in the model with an ObservedLinear that records
# activation statistics during calibration.
def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_size: int, input_dtype: torch.dtype, device: torch.device):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    if quant_dtype == torch.uint4:
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.uint4
        quant_min = 0
        quant_max = 15
        eps = torch.finfo(torch.float32).eps
        preserve_zero = True
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
    elif quant_dtype == torch.int8:
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
        preserve_zero = True
        block_size = (1, -1)
        quant_min = None
        quant_max = None
    else:
        raise NotImplementedError(f"{quant_dtype} not supported. Use either torch.uint4 or torch.int8")

    def replace_with_observer(layer):
        observer = AWQObserver(
            layer.weight,
            layer.bias,
            block_size,
            input_dtype,
            mapping_type,
            target_dtype,
            device,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            zero_point_dtype=zero_point_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=eps)
        return ObservedLinear.from_float(layer, observer)

    _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)


# Rebuild a plain nn.Linear whose weight is produced by `constructor`
# from the observed linear; used as the quantize_ replacement hook.
def _observed_linear_subclass_inserter(constructor):
    def insert_subclass(observed_linear):
        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias is not None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
        linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False)
        linear.bias = observed_linear.bias
        return linear

    return insert_subclass


# Returns a quantize_-compatible constructor that applies the calibrated
# equalization scale to the weight via AWQLayoutType.
def awq_quant(quant_dtype=torch.uint4, group_size=128):
    def weight_quant_func(observed_linear):
        # derive the per-channel equalization scale from the observer's activation statistics
        equalization_scale = observed_linear.act_obs.calculate_qparams()
        if quant_dtype == torch.uint4:
            mapping_type = MappingType.ASYMMETRIC
            block_size = (1, group_size)
            target_dtype = torch.uint8
            quant_min = 0
            quant_max = 15
            eps = torch.finfo(torch.float32).eps
            preserve_zero = True
            zero_point_dtype = torch.int64
            zero_point_domain = ZeroPointDomain.INT
            layout_type = AWQLayoutType(equalization_scale, quant_dtype)
        elif quant_dtype == torch.int8:
            mapping_type = MappingType.SYMMETRIC
            target_dtype = torch.int8
            eps = torch.finfo(torch.float32).eps
            zero_point_dtype = torch.int64
            zero_point_domain = ZeroPointDomain.INT
            preserve_zero = True
            block_size = (1, -1)
            quant_min = None
            quant_max = None
            layout_type = AWQLayoutType(equalization_scale, quant_dtype)
        else:
            raise NotImplementedError("AWQ supports only uint4 and int8 quantization for now")

        return to_affine_quantized_intx(
            observed_linear.weight,
            mapping_type, block_size,
            target_dtype, quant_min,
            quant_max, eps,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            layout_type=layout_type)

    return _observed_linear_subclass_inserter(weight_quant_func)
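Putting the two entry points together, the intended calibrate-then-quantize flow mirrors the test above. A minimal sketch, assuming a bf16 model on CUDA and some iterable calibration_batches (both placeholders, not part of this diff):

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant

    # 1. Wrap every nn.Linear in the model with an AWQ activation observer.
    insert_awq_observer(model, torch.uint4, 128, torch.bfloat16, "cuda")

    # 2. Run calibration batches so the observers can record activation statistics.
    for batch in calibration_batches:
        model(batch)

    # 3. Swap only the observed linears for AWQ-quantized versions.
    quantize_(
        model,
        awq_quant(quant_dtype=torch.uint4, group_size=128),
        lambda m, fqn: isinstance(m, ObservedLinear),
    )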
Review discussion:

> qq: does device=device not work?

> yea surprisingly it doesn't

> can you open an issue with error messages?