Skip to content

Add Int4CPULayout and update int4 woq #1278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

- name: CPU 2.3
runs-on: linux.4xlarge
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
Expand Down
55 changes: 29 additions & 26 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
run_tests,
)

from torchao.dtypes import SemiSparseLayout
from torchao.dtypes import Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
Expand All @@ -17,20 +17,25 @@
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_quantization_functions(do_sparse: bool, do_int4: bool):
def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
]
if do_int4:
base_functions.append(int4_weight_only(group_size=32))
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout())
)
else:
base_functions.append(int4_weight_only(group_size=32))

if do_sparse:
base_functions.append(
Expand Down Expand Up @@ -152,30 +157,28 @@ class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_flatten_unflatten(self, apply_quant, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")

linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(
tensor_data_dict, tensor_attributes, outer_size, outer_stride
)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)
def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(
tensor_data_dict, tensor_attributes, outer_size, outer_stride
)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
Expand Down
18 changes: 14 additions & 4 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torchao.quantization.dynamic_quant import (
DynamicallyPerAxisQuantizedLinear,
)
from torchao.dtypes import TensorCoreTiledLayout
from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
Expand Down Expand Up @@ -93,6 +93,7 @@
is_fbcode,
benchmark_model
)
from torchao.dtypes.utils import is_device

logger = logging.getLogger("INFO")

Expand Down Expand Up @@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
unwrap_tensor_subclass(mod)
elif TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int4_weight_only(), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
Expand Down Expand Up @@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
layout_list = []
if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
layout_list.append(Int4CPULayout())
else:
for inner_k_tiles in [4, 2]:
layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
for groupsize in [64, 32]:
for inner_k_tiles in [4, 2]:
kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
for layout in layout_list:
kwargs = {"groupsize": groupsize, "layout": layout}

def api(mod):
kwargs_copy = kwargs.copy()
Expand Down
10 changes: 7 additions & 3 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TORCH_VERSION_AT_LEAST_2_6,
is_fbcode,
)
from torchao.dtypes.utils import is_device

_SEED = 1234
torch.manual_seed(_SEED)
Expand Down Expand Up @@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.reshape_as(w)
)
if TORCH_VERSION_AT_LEAST_2_5:
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

return w_int4x8

Expand Down Expand Up @@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
groupsize = 128

if TORCH_VERSION_AT_LEAST_2_5:
input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
input_tmp = input
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
Expand Down
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
Int4CPULayout,
MarlinQQQLayout,
MarlinSparseLayout,
SemiSparseLayout,
Expand Down Expand Up @@ -48,4 +49,5 @@
"UintxLayout",
"MarlinQQQTensor",
"MarlinQQQLayout",
"Int4CPULayout",
]
2 changes: 2 additions & 0 deletions torchao/dtypes/uintx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SemiSparseLayout,
)
from .tensor_core_tiled_layout import (
Int4CPULayout,
TensorCoreTiledLayout,
)
from .uintx_layout import (
Expand All @@ -23,5 +24,6 @@
"MarlinSparseLayout",
"SemiSparseLayout",
"TensorCoreTiledLayout",
"Int4CPULayout",
"MarlinQQQLayout",
]
Loading