[WIP] Activation Aware Weight Quantization (AWQ) #743

Merged: 83 commits, Oct 7, 2024
The diffs below show changes from 19 of the 83 commits.

Commits
d3db3db
init
vayuda Aug 15, 2024
1c29347
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Aug 15, 2024
ed864a2
fixed implementation
vayuda Aug 16, 2024
90be216
Merge branch 'pytorch:main' into awq
vayuda Aug 20, 2024
0e690f9
reduced vmem req
vayuda Aug 21, 2024
05ae692
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Aug 21, 2024
4519792
eval on LLMs
vayuda Aug 24, 2024
fca7895
Merge branch 'pytorch:main' into awq
vayuda Aug 24, 2024
0096a83
eval on llm
vayuda Aug 24, 2024
20c2529
Merge branch 'awq'
vayuda Aug 24, 2024
7614d51
convert list to tensor
vayuda Aug 24, 2024
33a28dd
restructuring
vayuda Aug 24, 2024
7d389b5
revert unecessary hf_eval changes
vayuda Aug 24, 2024
5913b98
added wikitext eval test
vayuda Aug 25, 2024
7d045f9
added tinygemm integration
vayuda Aug 27, 2024
db302ef
made the calibration step much faster
vayuda Aug 29, 2024
2ec38f1
merge pt1
vayuda Sep 3, 2024
b400711
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 3, 2024
4aae94b
works/created tutorial
vayuda Sep 3, 2024
8e058d7
added docs, tests, cleaned code
vayuda Sep 5, 2024
dced6e5
updated benchmark
vayuda Sep 5, 2024
c43b997
update example
vayuda Sep 5, 2024
2388091
update example
vayuda Sep 5, 2024
8619cd5
added init file
vayuda Sep 6, 2024
5378ac9
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 6, 2024
9027082
reduce vram for calibration
vayuda Sep 8, 2024
dbac7c8
eval changes+ llama2 data
vayuda Sep 8, 2024
aa62e5f
llama2 data + eval script init changes
vayuda Sep 8, 2024
a4b006f
Merge branch 'main' into awq
vayuda Sep 9, 2024
7f21bfc
fixed qdtype bounds and example import
vayuda Sep 9, 2024
fb7fb11
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Sep 9, 2024
863d503
fix tests
vayuda Sep 9, 2024
4ca9117
fix tests
vayuda Sep 10, 2024
3e70a6f
use rolling log liklihood for eval and calibrate awq with run_eval
vayuda Sep 10, 2024
c37b396
Merge branch 'pytorch:main' into awq
vayuda Sep 10, 2024
389ea77
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 12, 2024
8ab016a
make eval use less vram
vayuda Sep 12, 2024
27c062d
updated uintx import
vayuda Sep 12, 2024
59b4174
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 21, 2024
9d52c93
add awq to generate
vayuda Sep 22, 2024
310138e
add calibration params to cli
vayuda Sep 22, 2024
a1b2bd0
fix name
vayuda Sep 22, 2024
9379686
pass linear properly
vayuda Sep 25, 2024
8d173df
recast W*eq_scale to original dtype
vayuda Sep 25, 2024
6aab8f8
revert bad change
vayuda Sep 25, 2024
4e97611
make scales same type as model
vayuda Sep 25, 2024
77db01f
compatible with compile
vayuda Sep 25, 2024
588e81e
cast eq scale to bf16
vayuda Sep 25, 2024
bfa5797
switch calibration dataset
vayuda Sep 25, 2024
41a621b
remove extra line
vayuda Sep 25, 2024
20767d5
add import
vayuda Sep 25, 2024
ae32c7c
added save/load scales
vayuda Sep 28, 2024
e2160ae
add save/store workflow to example
vayuda Sep 28, 2024
9704c38
add arg to fn
vayuda Sep 28, 2024
04ef51d
fix cli arg
vayuda Sep 28, 2024
17e2fbc
fix cli arg
vayuda Sep 28, 2024
71f9e27
add comma
vayuda Sep 28, 2024
1716b0c
add model to fn call
vayuda Sep 28, 2024
436fb9f
fix example
vayuda Sep 28, 2024
84b407c
refactored awq impl
vayuda Oct 1, 2024
49fbbe2
Merge branch 'pytorch:main' into awq
vayuda Oct 1, 2024
1216f97
edits +update usage
vayuda Oct 2, 2024
3930660
perplexity evals added
vayuda Oct 2, 2024
3e5710c
updated readme with benchmarks
vayuda Oct 2, 2024
da6a70d
add awq-hqq to generate
vayuda Oct 2, 2024
7314d99
better citation
vayuda Oct 2, 2024
dc0c507
Merge branch 'pytorch:main' into awq
vayuda Oct 3, 2024
68d9592
nits
vayuda Oct 3, 2024
bc7526e
remove layout.py
vayuda Oct 3, 2024
1fdf068
quantization func changes
vayuda Oct 5, 2024
85ea32c
typo
vayuda Oct 5, 2024
0a12f96
fix import
vayuda Oct 5, 2024
1b5a57b
fix fn params
vayuda Oct 5, 2024
c193afb
edit
vayuda Oct 5, 2024
4e60dfd
edit
vayuda Oct 5, 2024
93dcb79
rename
vayuda Oct 5, 2024
d2ed1f2
fix indentation
vayuda Oct 5, 2024
2650702
added bf16 gaurd for uint4 tinygemm quant
vayuda Oct 5, 2024
5e25469
remove bad tests
vayuda Oct 6, 2024
3aa279f
remove arg
vayuda Oct 6, 2024
c08fdb1
one last guard..
vayuda Oct 7, 2024
b967ebd
require nightly
vayuda Oct 7, 2024
e7e329b
require nightly on everything
vayuda Oct 7, 2024
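
For orientation, here is a minimal end-to-end usage sketch of the prototype API added in this PR, following the flow in test/prototype/test_awq.py (the toy model and calibration tensors are placeholders, not part of the PR):

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant

# Toy stand-ins for a real model and calibration set.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False)).to(torch.bfloat16).to("cuda").eval()
calibration_data = [torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") for _ in range(10)]

# 1. Swap nn.Linear modules for ObservedLinear so activation statistics are recorded.
insert_awq_observer(model, torch.uint4, 128, torch.bfloat16, "cuda")

# 2. Calibrate: plain forward passes feed each layer's AWQObserver.
for example in calibration_data:
    model(example)

# 3. Quantize only the observed linears, using the collected equalization scales.
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(model, awq_quant(quant_dtype=torch.uint4, group_size=128), is_observed_linear)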
37 changes: 37 additions & 0 deletions scripts/create_weight_map.py
@@ -0,0 +1,37 @@
"""
This file produces a file named pytorch_model.bin.index.json based on the downloaded model weights.
It was primarily used to run evals on the llama2.c-stories15M model.
"""
import json
import torch
from transformers import AutoModel

def create_weight_map(model_name):
# Load the model
model = AutoModel.from_pretrained(model_name)

# Get the state dict
state_dict = model.state_dict()

# Create the weight map
weight_map = {}
for key, tensor in state_dict.items():
# In this example, we're assuming all weights are in a single file
# You may need to adjust this if your model uses sharded weights
weight_map[key] = "pytorch_model.bin"

# Create the index dictionary
index_dict = {
"metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())},
"weight_map": weight_map
}

# Save the index dictionary to a JSON file
with open("pytorch_model.bin.index.json", "w") as f:
json.dump(index_dict, f, indent=2)

print("Created pytorch_model.bin.index.json")

# Usage
model_name = "checkpoints/Xenova/llama2.c-stories15M"
create_weight_map(model_name)
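
For reference, the index file written by this script has roughly the shape below; the weight_map keys are illustrative state-dict names and the total_size value is made up:

{
  "metadata": {"total_size": 58720256},
  "weight_map": {
    "embed_tokens.weight": "pytorch_model.bin",
    "layers.0.self_attn.q_proj.weight": "pytorch_model.bin",
    "layers.0.mlp.up_proj.weight": "pytorch_model.bin"
  }
}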
26 changes: 23 additions & 3 deletions scripts/hf_eval.py
@@ -48,7 +48,7 @@ def format_value(value):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device)
Contributor: qq: does device=device not work

Collaborator (author): yea surprisingly it doesn't

Contributor: can you open an issue with error messages


if quantization == "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -64,9 +64,29 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars
quantize_(model, fpx_weight_only(3, 2))
elif quantization == "autoquant":
model = autoquant(model.to(device=device))
elif quantization == "awq":
from datasets import load_dataset
from tqdm import tqdm
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant

insert_awq_observer(model, precision, device)
wikitext103 = load_dataset("wikitext", "wikitext-103-v1")
wikitext103_train = wikitext103["train"]
wikitext103_calibration = wikitext103_train.select(range(1))
calibration_input_ids = [tokenizer.encode(text, return_tensors="pt") for text in wikitext103_calibration["text"]]
model.to(device)
print("running awq calibration")
for i, ids in tqdm(enumerate(calibration_input_ids)):
if ids.shape[-1] == 0:
continue
model(ids.to(device))


is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(model, awq_quant, is_observed_linear)

if quantization != "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model = torch.compile(model, mode= "max-autotune", fullgraph=True)

if sparsity == "semi_sparse":
def all_linear(mod, name):
@@ -114,7 +134,7 @@ def all_linear(mod, name):
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "awq", "None"], help='Which quantization technique to apply')
parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
60 changes: 60 additions & 0 deletions test/prototype/test_awq.py
@@ -0,0 +1,60 @@
from copy import deepcopy
import torch
from torchao.quantization import quantize_, int4_weight_only, int8_weight_only
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant
import pytest


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
self.linear3 = torch.nn.Linear(k, 1, bias=False)

def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test():
device = ("cuda")
dataset_size = 1000
original_dtype = torch.bfloat16
l1,l2,l3 = 512,256,128

m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device)
m_bf16 = deepcopy(m)

dataset = m.example_inputs(dataset_size, dtype=original_dtype, device=device)
calibration_data = dataset[:100]
bf16_out = torch.cat([m_bf16(i.squeeze(0)) for i in dataset], dim=0)


m_int4wo = deepcopy(m)
quantize_(m_int4wo, int4_weight_only())
int4wo_out = torch.cat([m_int4wo(i.squeeze(0)) for i in dataset])

# calibrate
quant_dtype = torch.uint4
group_size = 128
insert_awq_observer(m, quant_dtype, group_size, original_dtype, device)
for example in calibration_data:
m(example.to(device))
# print('calibrated')

# quantize
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(m, awq_quant(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear)
awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
m = torch.compile(m, fullgraph=True)
# compare accuracy
awq_err = torch.sum(torch.abs(awq_out - bf16_out)).sum().item() / dataset_size
int4wo_err = torch.sum(torch.abs(int4wo_out - bf16_out)).sum().item() / dataset_size
print(f"AWQ error: {awq_err}")
print(f"Int4WO error: {int4wo_err}")
9 changes: 9 additions & 0 deletions torchao/_models/_eval.py
@@ -157,6 +157,15 @@ def _model_call(self, inps):
return torch.randn(
(1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
)

def _model_call(self, inps):
input = self.input_prep_func(inps.to(self._device))

max_seq_length = min(max(inps.size()), self.max_length)
with torch.device(self._device):
self.model_.setup_caches(self.batch_size, max_seq_length)
logits = self.model_(*input)
return logits

# pad or truncate to the right size
if T >= self.calibration_seq_length:
21 changes: 20 additions & 1 deletion torchao/_models/llama/eval.py
@@ -53,7 +53,8 @@ def run_evaluation(

print("Loading model ...")
t0 = time.time()
model = _load_model(checkpoint_path, "cpu", precision)
model = _load_model(checkpoint_path, "cpu", precision).to(device)
print(model)

if max_length is None:
max_length = model.config.block_size
@@ -81,6 +82,7 @@
assert "cuda" in device, "int4 gptq quantization only works on cuda"
inputs = InputRecorder(
tokenizer,
model,
calibration_seq_length,
prepare_inputs_for_model,
pad_calibration_inputs,
@@ -94,6 +96,23 @@
quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs).to(device)
elif "awq" in quantization:
from torchao.prototype.awq.api import ObservedLinear, insert_awq_observer, awq_quant
insert_awq_observer(model, precision, device)
InputRecorder(
tokenizer,
model,
calibration_seq_length,
prepare_inputs_for_model,
pad_calibration_inputs,
model.config.vocab_size,
device=device
).record_inputs(
calibration_tasks,
calibration_limit,
).get_inputs()
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(model, awq_quant, is_observed_linear)
else:
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)
5 changes: 2 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -31,6 +31,7 @@
is_device,
get_out_shape,
)

from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from dataclasses import dataclass
from torchao.utils import (
@@ -1008,8 +1009,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
_aqt_is_uint4(weight_tensor) and
weight_tensor.dtype == torch.bfloat16 and
len(weight_tensor.shape) == 2 and
weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and
isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType)
weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT
)


@@ -1019,7 +1019,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
f"need input_tensor shape: {input_tensor.shape} final"
f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
)

# TODO: check groupsize quantization
# avoid circular dep, TODO: move this to a common util.py
act_mat = input_tensor
114 changes: 114 additions & 0 deletions torchao/prototype/awq/api.py
@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F
from torchao.prototype.awq.core import AWQObserver, ObservedLinear, AWQLayoutType
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.dtypes import to_affine_quantized_intx
from torchao.dtypes.uintx.Uintx import to_uintx
from typing import Optional, Tuple




def insert_awq_observer(model: torch.nn.Module, quant_dtype: torch.dtype, group_size: int, input_dtype: torch.dtype, device: torch.device):
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
if quant_dtype == torch.uint4:
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.uint4
quant_min = 0
quant_max = 15
eps = torch.finfo(torch.float32).eps
preserve_zero = True
zero_point_dtype = torch.int64
zero_point_domain = ZeroPointDomain.INT

elif quant_dtype == torch.int8:
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
block_size = (1, -1)
quant_min = None
quant_max = None

else:
raise NotImplementedError(f"{quant_dtype} not supported. Use either torch.uint4 or torch.int8")

def replace_with_observer(layer):
observer = AWQObserver(
layer.weight,
layer.bias,
block_size,
input_dtype,
mapping_type,
target_dtype,
device,
preserve_zero = preserve_zero,
zero_point_domain = zero_point_domain,
zero_point_dtype = zero_point_dtype,
quant_min=quant_min,
quant_max = quant_max,
eps = eps)
return ObservedLinear.from_float(layer, observer)
_replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)

def _observed_linear_subclass_inserter(constructor):
def insert_subclass(observed_linear):
linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias is not None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False)
linear.bias = observed_linear.bias
return linear

return insert_subclass

def awq_quant(quant_dtype = torch.uint4, group_size = 128):

def weight_quant_func(observed_linear):
# weight quantization
equalization_scale = observed_linear.act_obs.calculate_qparams()
if quant_dtype == torch.uint4:
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.uint8
quant_min = 0
quant_max = 15
eps = torch.finfo(torch.float32).eps
preserve_zero = True
zero_point_dtype = torch.int64
zero_point_domain = ZeroPointDomain.INT
layout_type = AWQLayoutType(equalization_scale, quant_dtype)

elif quant_dtype == torch.int8:
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
block_size = (1, -1)
quant_min = None
quant_max = None
layout_type = AWQLayoutType(equalization_scale, quant_dtype)

else:
raise NotImplementedError("AWQ supports only uint4 and int8 quantization for now")

return to_affine_quantized_intx(
observed_linear.weight,
mapping_type, block_size,
target_dtype, quant_min,
quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain,
layout_type=layout_type)

return _observed_linear_subclass_inserter(weight_quant_func)

