Skip to content

INT4 XPU enabling #1577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 79 additions & 60 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
)

from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.dtypes import (
CutlassInt4PackedLayout,
Int4CPULayout,
Int4XPULayout,
SemiSparseLayout,
)
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Expand All @@ -31,7 +36,8 @@
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
check_cpu_version,
check_xpu_version,
is_fbcode,
is_ROCM,
is_sm_at_least_89,
Expand All @@ -52,15 +58,19 @@ def get_quantization_functions(
int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
]
if do_int4:
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
if check_cpu_version(device):
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout())
)
elif check_xpu_version(device):
base_functions.append(
int4_weight_only(group_size=32, layout=Int4XPULayout())
)
if int4_zp_int:
base_functions.append(
int4_weight_only(
group_size=32,
layout=Int4CPULayout(),
layout=Int4XPULayout(),
zero_point_domain=ZeroPointDomain.INT,
)
)
Expand All @@ -77,7 +87,7 @@ def get_quantization_functions(
)
base_functions.append(int4_dynamic_activation_int4_weight())

if do_sparse:
if do_sparse and device != "xpu":
base_functions.append(
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
)
Expand All @@ -89,6 +99,10 @@ def get_quantization_functions(


class TestAffineQuantized(TestCase):
GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
["xpu"] if torch.xpu.is_available() else []
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
Expand All @@ -109,51 +123,53 @@ def test_tensor_core_layout_transpose(self):
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant",
get_quantization_functions(is_cusparselt_available, True, "cuda", True),
)
@skip_if_rocm("ROCm enablement in progress")
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
def test_weights_only(self):
for device in self.GPU_DEVICES:
apply_quant_list = get_quantization_functions(
is_cusparselt_available, True, device, True
)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module
for device in self.GPU_DEVICES:

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.to("cuda")
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.to(device="cuda")
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.cuda()
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = _apply(linear, apply_quant)
ql.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_register_new_dispatch(self):
Expand Down Expand Up @@ -203,20 +219,19 @@ def apply_uint6_weight_only_quant(linear):

deregister_aqt_quantized_linear_dispatch(dispatch_condition)

@common_utils.parametrize(
"apply_quant", get_quantization_functions(is_cusparselt_available, True)
)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
def test_print_quantized_module(self):
for device in self.GPU_DEVICES:
apply_quant_list = get_quantization_functions(True, True, device, True)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
Expand Down Expand Up @@ -267,7 +282,11 @@ def test_copy__mismatch_metadata(self, apply_quant):


class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DEVICES = (
["cpu"]
+ (["cuda"] if torch.cuda.is_available() else [])
+ (["xpu"] if torch.xpu.is_available() else [])
)
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("device", COMMON_DEVICES)
Expand Down
19 changes: 12 additions & 7 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from torch._inductor.utils import run_and_get_code

import torchao
from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout
from torchao.dtypes.utils import is_device
from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
from torchao.quantization import safe_int_mm
from torchao.quantization.autoquant import (
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
Expand Down Expand Up @@ -84,6 +83,8 @@
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
benchmark_model,
check_cpu_version,
check_xpu_version,
is_fbcode,
is_sm_at_least_90,
unwrap_tensor_subclass,
Expand Down Expand Up @@ -146,17 +147,19 @@ def _int8da_int8w_api(


def _int4wo_api(mod, use_hqq=False):
if (
is_device(next(mod.parameters()).device.type, "cpu")
and TORCH_VERSION_AT_LEAST_2_6
):
if check_cpu_version(next(mod.parameters()).device):
quantize_(
mod,
int4_weight_only(
layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False
),
)
unwrap_tensor_subclass(mod)
elif check_xpu_version(next(mod.parameters()).device):
quantize_(
mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
)
unwrap_tensor_subclass(mod)
elif TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int4_weight_only(set_inductor_config=False))
if not TORCH_VERSION_AT_LEAST_2_5:
Expand Down Expand Up @@ -1129,8 +1132,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
layout_list = []
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
if check_cpu_version(device):
layout_list.append(Int4CPULayout())
elif check_xpu_version(device):
layout_list.append(Int4XPULayout())
else:
for inner_k_tiles in [4, 2]:
layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
Expand Down
75 changes: 58 additions & 17 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from torchao import quantize_
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes import (
AffineQuantizedTensor,
Int4CPULayout,
Int4XPULayout,
)
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
Quantizer,
Expand Down Expand Up @@ -54,6 +58,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_90,
unwrap_tensor_subclass,
Expand Down Expand Up @@ -189,6 +194,10 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):


class TestQuantFlow(TestCase):
GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
["xpu"] if torch.xpu.is_available() else []
)

def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
Expand Down Expand Up @@ -229,6 +238,34 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
compiled = m(*example_inputs)
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

@unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "only works for torch 2.8+")
def test_int4_wo_quant_save_load(self):
m = ToyLinearModel().eval().cpu()

def api(model):
quantize_(model, int4_weight_only(layout=Int4XPULayout()))
unwrap_tensor_subclass(model)

api(m)

example_inputs = m.example_inputs()
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

m2 = ToyLinearModel().eval().cpu()
api(m2)

m2.load_state_dict(state_dict)
m2 = m2.to(device="xpu")
example_inputs = map(lambda x: x.xpu(), example_inputs)
res = m2(*example_inputs)

torch.testing.assert_close(ref, res.cpu())

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+")
def test_int8_wo_quant_save_load(self):
Expand Down Expand Up @@ -615,25 +652,31 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
def test_quantized_tensor_subclass_int4(self):
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
for device in self.GPU_DEVICES:
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to(device)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)

group_size = 32
quantize_(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
group_size = 32
if device == "xpu":
quantize_(
m, int4_weight_only(group_size=group_size, layout=Int4XPULayout())
)
else:
quantize_(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
# reference
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
res = m(*example_inputs)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down Expand Up @@ -799,8 +842,6 @@ def reset_memory():
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("use_hqq", [True, False])
def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
from torchao.dtypes import Int4CPULayout

device = "cpu"
m = ToyLinearModel().eval().to(dtype).to(device)
example_inputs = m.example_inputs(dtype=dtype, device=device)
Expand Down
Loading
Loading