
Refactor the API for quant method argument for quantize function #400


Merged (1 commit, Jun 21, 2024)
4 changes: 2 additions & 2 deletions README.md
@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
Quantizing your models is a one-liner that should work on any model with an `nn.Linear`, including your favorite HuggingFace model. You can find more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)

```python
-from torchao.quantization.quant_api import quantize
-m = quantize(m, "int4wo")
+from torchao.quantization.quant_api import quantize, int4_weight_only
+m = quantize(m, int4_weight_only())
```

Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
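For context, here is a minimal sketch of the calling convention this hunk documents: the quant method is now a configuration callable (e.g. `int4_weight_only(group_size=...)`) rather than a string such as `"int4wo"`. The toy model, dtype, and `group_size=32` below are illustrative choices, not part of the diff.

```python
import torch
from torchao.quantization.quant_api import quantize, int4_weight_only

# Illustrative toy model; int4 weight-only quantization targets nn.Linear
# weights and expects bfloat16 weights on CUDA (tinygemm kernel).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval().to(torch.bfloat16).cuda()

# Pass a configuration callable instead of the old "int4wo" string.
model = quantize(model, int4_weight_only(group_size=32))

# Quantized inference.
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)
```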
6 changes: 3 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -2,7 +2,7 @@
TestCase,
run_tests,
)
-from torchao.quantization.quant_api import int4wo
+from torchao.quantization.quant_api import int4_weight_only
import torch
import unittest

@@ -12,8 +12,8 @@ class TestAffineQuantized(TestCase):
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
shape = t.shape
-apply_int4wo_quant = int4wo(groupsize=32)
-aqt = apply_int4wo_quant(t)
+apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+aqt = apply_int4_weight_only_quant(t)
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

21 changes: 12 additions & 9 deletions test/integration/test_integration.py
@@ -20,9 +20,9 @@
DynamicallyPerAxisQuantizedLinear,
)
from torchao.quantization.quant_api import (
-int4wo,
-int8wo,
-int8da_int8w,
+int4_weight_only,
+int8_weight_only,
+int8_dynamic_activation_int8_weight,
quantize,
_replace_with_custom_fn_if_matches_filter,
)
@@ -98,21 +98,21 @@

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
-quantize(mod, int8wo())
+quantize(mod, int8_weight_only())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
-quantize(mod, int8da_int8w())
+quantize(mod, int8_dynamic_activation_int8_weight())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
-quantize(mod, int4wo())
+quantize(mod, int4_weight_only())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)
@@ -832,7 +832,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):

def api(mod):
if TORCH_VERSION_AFTER_2_4:
-quantize(mod, int4wo(**kwargs))
+kwargs_copy = kwargs.copy()
+kwargs_copy["group_size"] = groupsize
+del kwargs_copy["groupsize"]
+quantize(mod, int4_weight_only(**kwargs_copy))
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -853,7 +856,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
-quantize(m, int8da_int8w())
+quantize(m, int8_dynamic_activation_int8_weight())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
@@ -1463,7 +1466,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
api(model)
size2 = torchao.utils.get_model_size_in_bytes(model)
self.assertTrue(size2 < size)




51 changes: 18 additions & 33 deletions test/quantization/test_quant_api.py
@@ -36,10 +36,10 @@
_replace_with_custom_fn_if_matches_filter,
Quantizer,
TwoStepQuantizer,
-int8da_int4w,
-int4wo,
-int8wo,
-int8da_int8w,
+int8_dynamic_activation_int4_weight,
+int4_weight_only,
+int8_weight_only,
+int8_dynamic_activation_int8_weight,
)
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

class TorchCompileDynamicQuantizer(Quantizer):
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-quantize(model, int8da_int8w())
+quantize(model, int8_dynamic_activation_int8_weight())
return model

class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
-m = quantize(m, int8da_int8w())
+m = quantize(m, int8_dynamic_activation_int8_weight())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
)
m = ToyLinearModel().eval().cpu()
def api(model):
-model = quantize(model, int8wo())
+model = quantize(model, int8_weight_only())
unwrap_tensor_subclass(model)

api(m)
@@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
-def test_gptq_quantizer_int4wo(self):
+def test_gptq_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models._eval import InputRecorder, TransformerEvalWrapper
torchao._models.llama.model.use_index_put_for_kv_cache = True
@@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
-def test_quantizer_int4wo(self):
+def test_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao._models._eval import TransformerEvalWrapper
precision = torch.bfloat16
@@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
-groupsize = 32
+group_size = 32
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
-m = quantize(m, int8da_int4w(groupsize=groupsize))
+m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
m_copy = quantizer.quantize(m_copy)
assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
@@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

-groupsize = 32
-m = quantize(m, int4wo(groupsize=groupsize))
+group_size = 32
+m = quantize(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
-_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
@@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

-m = quantize(m, int8wo())
+m = quantize(m, int8_weight_only())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-m = quantize(m, int8da_int8w())
+m = quantize(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -602,29 +602,14 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

-    def test_register_apply_tensor_subclass(self):
-        from torchao import register_apply_tensor_subclass
-        def apply_my_dtype(weight):
-            return weight * 2
-
-        m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
-        with self.assertRaisesRegex(ValueError, "not supported"):
-            quantize(m, "my_dtype")
-
-        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-        # make sure it runs
-        quantize(m, "my_dtype")
-        m(*example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_save_load(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = quantize(m, "int8_weight_only")
m = quantize(m, int8_weight_only())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
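This hunk removes `test_register_apply_tensor_subclass` along with the string-registration mechanism. Under the refactored API the tests pass callables such as `int4_weight_only()` straight into `quantize`, so a custom weight transform would presumably be passed the same way. A hedged sketch of that pattern follows; `apply_my_dtype` is the illustrative transform from the removed test, and the exact contract for custom callables (weight tensor in, replacement tensor out) is an assumption inferred from how `int4_weight_only(group_size=32)` is applied to a plain tensor in `test_tensor_core_layout_transpose` above.

```python
import torch
from torchao.quantization.quant_api import quantize

# Illustrative transform: takes an nn.Linear weight tensor and returns the
# tensor to install in its place (a real use would return a quantized
# tensor subclass such as AffineQuantizedTensor).
def apply_my_dtype(weight: torch.Tensor) -> torch.Tensor:
    return weight * 2

m = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()

# Before: register_apply_tensor_subclass("my_dtype", apply_my_dtype); quantize(m, "my_dtype")
# After: pass the callable directly, the same way int4_weight_only() et al. are passed.
m = quantize(m, apply_my_dtype)
m(torch.randn(2, 16))
```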
2 changes: 0 additions & 2 deletions torchao/__init__.py
@@ -31,13 +31,11 @@
from torchao.quantization import (
autoquant,
quantize,
-register_apply_tensor_subclass,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"register_apply_tensor_subclass",
]
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
@@ -1,7 +1,7 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
-from .aqt import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized

__all__ = [
"NF4Tensor",
File renamed without changes.
22 changes: 10 additions & 12 deletions torchao/quantization/README.md
@@ -80,7 +80,7 @@ from torch._inductor.runtime.runtime_utils import do_bench_gpu
import copy
from torchao.quantization.quant_api import (
quantize,
-int4wo,
+int4_weight_only,
)

class ToyLinearModel(torch.nn.Module):
@@ -104,8 +104,8 @@ example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
-groupsize = 32
-m = quantize(m, int4wo(groupsize=groupsize))
+group_size = 32
+m = quantize(m, int4_weight_only(group_size=group_size))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
@@ -152,7 +152,7 @@ for n, m in model.named_modules():
The model/tensor subclass should also be compatible with AOTI and torch.export; currently we support
`torch.export.export` and `torch.aot_compile` with the following workaround:
```
-from torchao.quantization.utils import unwrap_tensor_subclass
+from torchao.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)


@@ -169,11 +169,10 @@ torch._export.aot_compile(m_unwrapped, example_inputs)
```python
# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True
-from torchao.quantization import quant_api

# for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic")
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+quantize(model, int8_dynamic_activation_int8_weight())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -184,9 +183,8 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.4+
-from torchao.quantization.quant_api import quantize
-from torchao.quantization.quant_api import int8wo
-quantize(model, "int8_weight_only")
+from torchao.quantization import quantize, int8_weight_only
+quantize(model, int8_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -200,8 +198,8 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is

```python
# for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int4_weight_only")
+from torchao.quantization import quantize, int4_weight_only
+quantize(model, int4_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
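Since this README hunk also moves `unwrap_tensor_subclass` from `torchao.quantization.utils` to `torchao.utils`, here is a short sketch of the export workaround it describes, using the refactored imports. The toy model, bfloat16 dtype, and input shape are illustrative assumptions; the tests in this PR use a similar setup for the int8 weight-only path.

```python
import torch
from torchao.quantization import quantize, int8_weight_only
from torchao.utils import unwrap_tensor_subclass  # new import location in this PR

# Illustrative model and inputs.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval().to(torch.bfloat16)
example_inputs = (torch.randn(1, 64, dtype=torch.bfloat16),)

model = quantize(model, int8_weight_only())

# Workaround from the README: unwrap tensor subclasses before export / AOT compile.
model_unwrapped = unwrap_tensor_subclass(model)
exported = torch.export.export(model_unwrapped, example_inputs)
so_path = torch._export.aot_compile(model_unwrapped, example_inputs)
```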
5 changes: 4 additions & 1 deletion torchao/quantization/__init__.py
@@ -32,5 +32,8 @@
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"register_apply_tensor_subclass",
"int8_dynamic_act_int4_weight",
"int8_dynamic_act_int8_weight",
"int4_weight_only",
"int8_weight_only",
]