diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 4064bff535..9db421ad56 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -15,7 +15,12 @@ ) from torchao.core.config import AOBaseConfig -from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout +from torchao.dtypes import ( + CutlassInt4PackedLayout, + Int4CPULayout, + Int4XPULayout, + SemiSparseLayout, +) from torchao.quantization import ( Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, @@ -31,7 +36,8 @@ from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, + check_cpu_version, + check_xpu_version, is_fbcode, is_ROCM, is_sm_at_least_89, @@ -52,15 +58,19 @@ def get_quantization_functions( int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: - if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(device): base_functions.append( int4_weight_only(group_size=32, layout=Int4CPULayout()) ) + elif check_xpu_version(device): + base_functions.append( + int4_weight_only(group_size=32, layout=Int4XPULayout()) + ) if int4_zp_int: base_functions.append( int4_weight_only( group_size=32, - layout=Int4CPULayout(), + layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, ) ) @@ -77,7 +87,7 @@ def get_quantization_functions( ) base_functions.append(int4_dynamic_activation_int4_weight()) - if do_sparse: + if do_sparse and device != "xpu": base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) @@ -89,6 +99,10 @@ def get_quantization_functions( class TestAffineQuantized(TestCase): + GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + ( + ["xpu"] if torch.xpu.is_available() else [] + ) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_tensor_core_layout_transpose(self): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") @@ -109,51 +123,53 @@ def test_tensor_core_layout_transpose(self): aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @common_utils.parametrize( - "apply_quant", - get_quantization_functions(is_cusparselt_available, True, "cuda", True), - ) - @skip_if_rocm("ROCm enablement in progress") - def test_weights_only(self, apply_quant): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - if isinstance(apply_quant, AOBaseConfig): - quantize_(linear, apply_quant) - ql = linear - else: - # TODO(#1690): delete this once config migration is done - ql = apply_quant(linear) - with tempfile.NamedTemporaryFile() as f: - torch.save(ql.state_dict(), f) - f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") + def test_weights_only(self): + for device in self.GPU_DEVICES: + apply_quant_list = get_quantization_functions( + is_cusparselt_available, True, device, True + ) + for apply_quant in apply_quant_list: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + if isinstance(apply_quant, AOBaseConfig): + quantize_(linear, apply_quant) + ql = linear + else: + # TODO(#1690): delete this once config 
migration is done + ql = apply_quant(linear) + with tempfile.NamedTemporaryFile() as f: + torch.save(ql.state_dict(), f) + f.seek(0) + # `weights_only=True` is enabled for torch 2.5+ + if TORCH_VERSION_AT_LEAST_2_5: + _ = torch.load(f, weights_only=True) + else: + _ = torch.load(f, weights_only=False) + + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) def test_to_device(self, apply_quant): - def _apply(module, config_or_subclass_inserter): - if isinstance(config_or_subclass_inserter, AOBaseConfig): - quantize_(module, config_or_subclass_inserter) - else: - # TODO(#1690): delete this once config migration is done - module = config_or_subclass_inserter(module) - return module + for device in self.GPU_DEVICES: - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = _apply(linear, apply_quant) - ql.to("cuda") + def _apply(module, config_or_subclass_inserter): + if isinstance(config_or_subclass_inserter, AOBaseConfig): + quantize_(module, config_or_subclass_inserter) + else: + # TODO(#1690): delete this once config migration is done + module = config_or_subclass_inserter(module) + return module - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = _apply(linear, apply_quant) - ql.to(device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = _apply(linear, apply_quant) + ql.to(device) - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = _apply(linear, apply_quant) - ql.cuda() + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = _apply(linear, apply_quant) + ql.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = _apply(linear, apply_quant) + ql.to(device) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): @@ -203,20 +219,19 @@ def apply_uint6_weight_only_quant(linear): deregister_aqt_quantized_linear_dispatch(dispatch_condition) - @common_utils.parametrize( - "apply_quant", get_quantization_functions(is_cusparselt_available, True) - ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @skip_if_rocm("ROCm enablement in progress") - def test_print_quantized_module(self, apply_quant): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - if isinstance(apply_quant, AOBaseConfig): - quantize_(linear, apply_quant) - ql = linear - else: - # TODO(#1690): delete this once config migration is done - ql = apply_quant(linear) - assert "AffineQuantizedTensor" in str(ql) + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") + def test_print_quantized_module(self): + for device in self.GPU_DEVICES: + apply_quant_list = get_quantization_functions(True, True, device, True) + for apply_quant in apply_quant_list: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + if isinstance(apply_quant, AOBaseConfig): + quantize_(linear, apply_quant) + ql = linear + else: + # TODO(#1690): delete this once config migration is done + ql = apply_quant(linear) + assert "AffineQuantizedTensor" in str(ql) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( @@ -267,7 +282,11 @@ def test_copy__mismatch_metadata(self, apply_quant): class TestAffineQuantizedBasic(TestCase): - COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DEVICES = ( + ["cpu"] + + (["cuda"] if torch.cuda.is_available() else []) + + (["xpu"] if 
torch.xpu.is_available() else []) + ) COMMON_DTYPES = [torch.bfloat16] @common_utils.parametrize("device", COMMON_DEVICES) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 9bbf625fc4..36cddc7a0f 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -19,8 +19,7 @@ from torch._inductor.utils import run_and_get_code import torchao -from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout -from torchao.dtypes.utils import is_device +from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout from torchao.quantization import safe_int_mm from torchao.quantization.autoquant import ( AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, @@ -84,6 +83,8 @@ TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, benchmark_model, + check_cpu_version, + check_xpu_version, is_fbcode, is_sm_at_least_90, unwrap_tensor_subclass, @@ -146,10 +147,7 @@ def _int8da_int8w_api( def _int4wo_api(mod, use_hqq=False): - if ( - is_device(next(mod.parameters()).device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if check_cpu_version(next(mod.parameters()).device): quantize_( mod, int4_weight_only( @@ -157,6 +155,11 @@ def _int4wo_api(mod, use_hqq=False): ), ) unwrap_tensor_subclass(mod) + elif check_xpu_version(next(mod.parameters()).device): + quantize_( + mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False + ) + unwrap_tensor_subclass(mod) elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(set_inductor_config=False)) if not TORCH_VERSION_AT_LEAST_2_5: @@ -1129,8 +1132,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") layout_list = [] - if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(device): layout_list.append(Int4CPULayout()) + elif check_xpu_version(device): + layout_list.append(Int4XPULayout()) else: for inner_k_tiles in [4, 2]: layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ac8ca1dccb..cce8d17c19 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -24,7 +24,11 @@ from torchao import quantize_ from torchao._models.llama.model import Transformer, prepare_inputs_for_model from torchao._models.llama.tokenizer import get_tokenizer -from torchao.dtypes import AffineQuantizedTensor +from torchao.dtypes import ( + AffineQuantizedTensor, + Int4CPULayout, + Int4XPULayout, +) from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.quant_api import ( Quantizer, @@ -54,6 +58,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, unwrap_tensor_subclass, @@ -189,6 +194,10 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): class TestQuantFlow(TestCase): + GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + ( + ["xpu"] if torch.xpu.is_available() else [] + ) + def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() @@ -229,6 +238,34 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) + @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available") + 
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "only works for torch 2.8+") + def test_int4_wo_quant_save_load(self): + m = ToyLinearModel().eval().cpu() + + def api(model): + quantize_(model, int4_weight_only(layout=Int4XPULayout())) + unwrap_tensor_subclass(model) + + api(m) + + example_inputs = m.example_inputs() + ref = m(*example_inputs) + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + + m2 = ToyLinearModel().eval().cpu() + api(m2) + + m2.load_state_dict(state_dict) + m2 = m2.to(device="xpu") + example_inputs = map(lambda x: x.xpu(), example_inputs) + res = m2(*example_inputs) + + torch.testing.assert_close(ref, res.cpu()) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): @@ -615,25 +652,31 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_quantized_tensor_subclass_int4(self): - # use 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") + for device in self.GPU_DEVICES: + # use 1024 so that we don't need padding + m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to(device) + m_copy = copy.deepcopy(m) + example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - group_size = 32 - quantize_(m, int4_weight_only(group_size=group_size)) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) + group_size = 32 + if device == "xpu": + quantize_( + m, int4_weight_only(group_size=group_size, layout=Int4XPULayout()) + ) + else: + quantize_(m, int4_weight_only(group_size=group_size)) + assert isinstance(m.linear1.weight, AffineQuantizedTensor) + assert isinstance(m.linear2.weight, AffineQuantizedTensor) - # reference - _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size) + # reference + _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size) - res = m(*example_inputs) - ref = m_copy(*example_inputs) + res = m(*example_inputs) + ref = m_copy(*example_inputs) - self.assertTrue(torch.equal(res, ref)) + self.assertTrue(torch.equal(res, ref)) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -799,8 +842,6 @@ def reset_memory(): @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("use_hqq", [True, False]) def test_int4wo_cpu(self, dtype, x_dim, use_hqq): - from torchao.dtypes import Int4CPULayout - device = "cpu" m = ToyLinearModel().eval().to(dtype).to(device) example_inputs = m.example_inputs(dtype=dtype, device=device) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 3ca58ff996..861ebe5e94 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -11,7 +11,6 @@ import torch from parameterized import parameterized -from torchao.dtypes.utils import is_device 
from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( MappingType, @@ -38,6 +37,8 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + check_cpu_version, + check_xpu_version, is_fbcode, ) @@ -68,6 +69,7 @@ def _get_groupwise_affine_qparams( groupsize=128, dtype=torch.bfloat16, zero_point_domain=ZeroPointDomain.FLOAT, + zero_point_dtype=torch.bfloat16, ): if groupsize > w.shape[-1]: groupsize = w.shape[-1] @@ -86,11 +88,11 @@ def _get_groupwise_affine_qparams( scales = (max_val - min_val).clamp(min=1e-6) / max_int if zero_point_domain == ZeroPointDomain.FLOAT: zeros = min_val + scales * (2 ** (n_bit - 1)) - zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1) + zeros = zeros.to(dtype=zero_point_dtype).reshape(w.shape[0], -1) else: zeros = quant_min - torch.round(min_val / scales) zeros = torch.clamp(zeros, quant_min, quant_max) - zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1) + zeros = zeros.to(dtype=zero_point_dtype).reshape(w.shape[0], -1) scales = scales.to(dtype=dtype).reshape(w.shape[0], -1) return scales, zeros @@ -135,7 +137,7 @@ def _groupwise_affine_quantize_tensor_from_qparams( ) if TORCH_VERSION_AT_LEAST_2_5: - if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) return w_int4x8 @@ -682,6 +684,7 @@ def test_get_groupwise_affine_qparams(self): n_bit = 4 zero_point_domains = [ZeroPointDomain.FLOAT, ZeroPointDomain.INT] + zero_point_dtypes = [torch.bfloat16, torch.int32] mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 128) @@ -689,8 +692,9 @@ def test_get_groupwise_affine_qparams(self): quant_max = 2**n_bit - 1 eps = 1e-6 scale_dtype = torch.bfloat16 - zero_point_dtype = torch.bfloat16 - for zero_point_domain in zero_point_domains: + for zero_point_domain, zero_point_dtype in zip( + zero_point_domains, zero_point_dtypes + ): scale_ref, zero_point_ref = _get_groupwise_affine_qparams( input, n_bit=n_bit, @@ -744,8 +748,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) if TORCH_VERSION_AT_LEAST_2_5: input_tmp = input - if not ( - is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6 + if (not (check_cpu_version(input.device))) and ( + not (check_xpu_version(input.device)) ): input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index e4ae0ba1cb..eb253c11bc 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -17,6 +17,7 @@ BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, + Int4XPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -59,4 +60,5 @@ "QDQLayout", "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", + "Int4XPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index b4912523bb..50ef8c9e89 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -43,6 +43,12 @@ _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, ) +from torchao.dtypes.uintx.int4_xpu_layout import ( + 
_linear_bf16_act_uint4_weight_float_zero_check, + _linear_bf16_act_uint4_weight_float_zero_impl, + _linear_bf16_act_uint4_weight_int8_zero_check, + _linear_bf16_act_uint4_weight_int8_zero_impl, +) from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, @@ -225,6 +231,14 @@ def _register_aqt_quantized_linear_dispatches(): _linear_q_dq_check, _linear_q_dq_impl, ), + ( + _linear_bf16_act_uint4_weight_int8_zero_check, + _linear_bf16_act_uint4_weight_int8_zero_impl, + ), + ( + _linear_bf16_act_uint4_weight_float_zero_check, + _linear_bf16_act_uint4_weight_float_zero_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) @@ -384,6 +398,13 @@ def _(func, types, args, kwargs): ) +@implements(aten.copy_.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[1]._apply_fn_to_data(torch.clone) + ) + + @implements(aten._to_copy.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 901cd8215c..fee6141164 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -7,6 +7,9 @@ from .int4_cpu_layout import ( Int4CPULayout, ) +from .int4_xpu_layout import ( + Int4XPULayout, +) from .marlin_qqq_tensor import ( MarlinQQQLayout, MarlinQQQTensor, @@ -44,4 +47,5 @@ "CutlassInt4PackedLayout", "PackedLinearInt8DynamicActivationIntxWeightLayout", "QDQLayout", + "Int4XPULayout", ] diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py new file mode 100644 index 0000000000..38c9d01993 --- /dev/null +++ b/torchao/dtypes/uintx/int4_xpu_layout.py @@ -0,0 +1,445 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device +from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + fill_defaults, +) + +aten = torch.ops.aten + + +def _aqt_is_xpu_layout_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.tensor_impl.dtype == torch.int32 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_bf16_act_uint4_weight_float_zero_check(input_tensor, weight_tensor, bias): + return ( + # input is native bfloat16 tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and + # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_xpu_layout_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and weight_tensor.tensor_impl.scale_and_zero is not None + and weight_tensor.tensor_impl.scale_and_zero.dtype == torch.bfloat16 + and isinstance(weight_tensor._layout, Int4XPULayout) + ) + + +def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, bias): + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: 
{weight_tensor.block_size}"
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"need input_tensor shape: {input_tensor.shape} final"
+        f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
+    )
+
+    # TODO: check groupsize quantization
+    # avoid circular dep, TODO: move this to a common util.py
+    act_mat = input_tensor
+    if not act_mat.is_contiguous():
+        act_mat = act_mat.contiguous()
+    # weight is packed from padded (out_features, in_features) weight tensor
+    # (same dimension requirement as F.linear weight)
+    packed_weight = weight_tensor.tensor_impl.packed_weight
+    scales_and_zeros = weight_tensor.tensor_impl.scale_and_zero
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+
+    # groupwise int4 quantization
+    groupsize = weight_tensor.block_size[1]
+    y = torch.ops.aten._weight_int4pack_mm(
+        act_mat, packed_weight, groupsize, scales_and_zeros
+    )
+
+    # remove out_feature padding
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    if bias is not None:
+        y += bias
+    return y.to(orig_dtype)
+
+
+def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+    return (
+        # input is native bfloat16 tensor
+        not is_traceable_wrapper_subclass(input_tensor)
+        and input_tensor.dtype == torch.bfloat16
+        and
+        # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor)
+        and _aqt_is_xpu_layout_uint4(weight_tensor)
+        and weight_tensor.dtype == torch.bfloat16
+        and len(weight_tensor.shape) == 2
+        and weight_tensor.zero_point_domain == ZeroPointDomain.INT
+        and weight_tensor.tensor_impl.scale_and_zero is None
+        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
+        and weight_tensor.tensor_impl.zero.dtype == torch.int8
+        and isinstance(weight_tensor._layout, Int4XPULayout)
+    )
+
+
+def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+    assert (
+        weight_tensor.block_size[0] == 1
+    ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"need input_tensor shape: {input_tensor.shape} final"
+        f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
+    )
+
+    # TODO: check groupsize quantization
+    # avoid circular dep, TODO: move this to a common util.py
+    act_mat = input_tensor
+    # weight is packed from padded (out_features, in_features) weight tensor
+    # (same dimension requirement as F.linear weight)
+    packed_weight = weight_tensor.tensor_impl.packed_weight
+    scale = weight_tensor.tensor_impl.scale
+    zero = weight_tensor.tensor_impl.zero
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+
+    # groupwise int4 quantization
+    groupsize = weight_tensor.block_size[1]
+
+    y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+        act_mat, packed_weight, groupsize, scale, zero
+    )
+
+    # remove out_feature padding
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    if bias is not None:
+        y += bias
+    return y.to(orig_dtype)
+
+
+@dataclass(frozen=True)
+class Int4XPULayout(Layout):
+    """Only for PyTorch version at least 2.7"""
+
+    pass
+
+
+@register_layout(Int4XPULayout)
+class 
Int4XPUAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for int4 XPU layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm_xpu` and `_weight_int4pack_mm_with_zeros_and_scales` (TBD) + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 8] (int32 dtype) + (unpacked Tensor shape is n * k) + Note: we also pack scale and zero point together here for tinygemm kernel + Note: technically Int4 XPU layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 XPU layout + [Optional] scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + [Optional] scale (torch.Tensor): scale tensors, should be the same dtype of packed weight + [Optional] zeros (torch.Tensor): can be of the same dtype of packed weight or different dtype + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + scale: torch.Tensor = None, + zero: torch.Tensor = None, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + scale: torch.Tensor = None, + zero: torch.Tensor = None, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + self.scale = scale + self.zero = zero + + def __tensor_flatten__(self): + if self.scale_and_zero is not None: + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + else: + return ["packed_weight", "scale", "zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight = tensor_data_dict["packed_weight"] + scale_and_zero = ( + tensor_data_dict.get("scale_and_zero") + if "scale_and_zero" in tensor_data_dict + else None + ) + scale = tensor_data_dict.get("scale") if "scale" in tensor_data_dict else None + zero = tensor_data_dict.get("zero") if "zero" in tensor_data_dict else None + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout, scale, zero) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Int4XPULayout) + + from torchao.quantization.utils import convert_weight_to_int4pack_xpu + + if TORCH_VERSION_AT_LEAST_2_8: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + packed_weight = convert_weight_to_int4pack_xpu( + int_data, zero_point.dtype != scale.dtype + ) + else: + assert False, "INT4 not 
supported on XPU until 2.8" + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + if zero_point.dtype == scale.dtype: + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout, None, None) + else: + return cls( + packed_weight, + None, + False, + _layout, + scale.transpose(0, 1).contiguous(), + zero_point.transpose(0, 1).contiguous().to(torch.int8), + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if not is_device(torch.device(self.device).type, device): + raise ValueError( + f"Int4XPUAQTTensorImpl does not support conversion from {self.device} to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device) if self.scale_and_zero is not None else None, + self.transposed, + self._layout, + self.scale.to(device) if self.scale is not None else None, + self.zero.to(device) if self.zero is not None else None, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero) if self.scale_and_zero is not None else None, + self.transposed, + self._layout, + fn(self.scale) if self.scale is not None else None, + fn(self.zero) if self.zero is not None else None, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = Int4XPUAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + args[0].scale, + args[0].zero, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is torch.ops.aten.copy_.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise 
NotImplementedError( + f"Int4XPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"Int4XPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + if self.scale_and_zero is not None: + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + else: + scale = self.scale.transpose(0, 1).contiguous() + zero = self.zero.transpose(0, 1).contiguous() + + cur_shape = self.shape + assert len(cur_shape) == 2 + original_shape = (cur_shape[0], cur_shape[1] * 8) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[1]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + assert len(block_size) == 2 and block_size[0] == 1 + if self.scale_and_zero is None: + zero_point_domain = ZeroPointDomain.INT + dequantized = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale, + self.zero, + ) + dequantized = dequantized.t().contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + else: + zero_point_domain = ZeroPointDomain.FLOAT + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
+ scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index f7fcf46f36..5864db6448 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -8,7 +8,7 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, check_cpu_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -154,7 +154,7 @@ def int_scaled_matmul( scales1 = scales1.expand((M, N)) assert scales1.dim() == 2 - if scales1.device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(scales1.device): # CPU prefers decomposed version of int_scaled_matmul # to leverage the fusion capability of Inductor c = torch._int_mm(a, b) diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 2e11210601..f15c9a8104 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -17,7 +17,7 @@ from torch import Tensor, nn from torchao.dtypes.utils import is_device -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, check_cpu_version class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -167,7 +167,7 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(W_q.device): self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( W_q_torch, self.inner_k_tiles ) @@ -243,7 +243,7 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(x.device): c = torch.ops.aten._weight_int4pack_mm_for_cpu( x, self.weight_int4pack, self.groupsize, self.scales_and_zeros ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index fd9fb97363..995030df67 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -33,6 +33,7 @@ CutlassSemiSparseLayout, Float8Layout, Int4CPULayout, + Int4XPULayout, MarlinQQQLayout, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, @@ -140,12 +141,14 @@ TensorCoreTiledLayout: [ZeroPointDomain.FLOAT], MarlinSparseLayout: [ZeroPointDomain.INT], Int4CPULayout: [ZeroPointDomain.FLOAT], + Int4XPULayout: [ZeroPointDomain.FLOAT, ZeroPointDomain.INT], } LAYOUT_TO_PRESERVE_ZEROS = { TensorCoreTiledLayout: False, MarlinSparseLayout: True, Int4CPULayout: False, + Int4XPULayout: False, } @@ -203,7 +206,12 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): def change_linear_weights_to_int4_woqtensors( - model, groupsize=128, inner_k_tiles=8, filter_fn=None + model, + groupsize=128, + inner_k_tiles=8, + filter_fn=None, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, ): """ Converts all linear weight tensors to the @@ -214,6 +222,11 @@ def 
change_linear_weights_to_int4_woqtensors( `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] + `filter_fn`: function that takes a nn.Module instance and fully qualified name of the module, \ + returns True if we want to run `config` on + `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, \ + ZeroPointDomain.INT, ZeroPointDomain.NONE] + `preserve_zero`: whether to preserve zero, default is False """ if TORCH_VERSION_AT_LEAST_2_4: raise ImportError( @@ -230,6 +243,8 @@ def change_linear_weights_to_int4_woqtensors( enable_parametrization=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, ), filter_fn, ) @@ -963,6 +978,7 @@ class Int4WeightOnlyConfig(AOBaseConfig): `use_hqq`: whether to use hqq or default quantization mode, default is False `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. + `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT """ group_size: int = 128 @@ -970,6 +986,7 @@ class Int4WeightOnlyConfig(AOBaseConfig): use_hqq: bool = False zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE set_inductor_config: bool = True + preserve_zero: Optional[bool] = None # for BC @@ -1006,7 +1023,6 @@ def _int4_weight_only_transform( quant_min = 0 quant_max = 15 eps = 1e-6 - preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] zero_point_dtype = ( weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 ) @@ -1023,6 +1039,14 @@ def _int4_weight_only_transform( zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" + if zero_point_domain == ZeroPointDomain.INT and isinstance(layout, Int4XPULayout): + zero_point_dtype = torch.int32 + + preserve_zero = ( + config.preserve_zero + if config.preserve_zero is not None + else LAYOUT_TO_PRESERVE_ZEROS[type(layout)] + ) # Sparse Marlin only supports symmetric quantization. # NOTE: If we start having lots of layouts that require different configurations, # we should consider moving this logic somewhere else. 
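[Reviewer note, not part of the patch] A minimal sketch of how the new `preserve_zero` / `zero_point_domain` options on `Int4WeightOnlyConfig` are expected to compose with `Int4XPULayout`, assuming PyTorch 2.8+ with a working XPU device; the toy `nn.Linear`, shapes, and batch size are illustrative only, while the imported names come from this diff.

import torch

from torchao.dtypes import Int4XPULayout
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

# bf16 weights/activations match the XPU dispatch checks registered in this patch
linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="xpu")

# Integer zero points should take the _weight_int4pack_mm_with_scales_and_zeros path;
# leaving preserve_zero=None falls back to the layout default (False for Int4XPULayout).
config = Int4WeightOnlyConfig(
    group_size=32,
    layout=Int4XPULayout(),
    zero_point_domain=ZeroPointDomain.INT,
)
quantize_(linear, config)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="xpu")
with torch.no_grad():
    y = linear(x)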
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index bc176c9d17..df94fa2f8a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -779,6 +779,7 @@ def choose_qparams_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") + return _choose_qparams_affine( input, mapping_type.name, @@ -875,8 +876,6 @@ def _choose_qparams_affine( if input is not None: if scale_dtype is None: scale_dtype = input.dtype - if zero_point_dtype is None: - zero_point_dtype = input.dtype if eps is None: eps = torch.finfo(input.dtype).eps @@ -900,8 +899,6 @@ def _choose_qparams_affine( if scale_dtype is None: scale_dtype = min_val.dtype - if zero_point_dtype is None: - zero_point_dtype = min_val.dtype if eps is None: eps = torch.finfo(min_val.dtype).eps @@ -955,19 +952,20 @@ def _choose_qparams_affine( scale = torch.clamp(scale, min=eps) if zero_point_domain == ZeroPointDomain.NONE.name: zero_point = None + elif zero_point_domain == ZeroPointDomain.INT.name: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + if zero_point_dtype is None: + zero_point_dtype = torch.int32 else: - if preserve_zero: - zero_point = quant_min - torch.round(min_val_neg / scale) - zero_point = torch.clamp(zero_point, quant_min, quant_max) - else: - assert ( - zero_point_domain == ZeroPointDomain.FLOAT.name - ), "if not preserve_zero, zero_point must be in FLOAT domain" - mid_point = (quant_max + quant_min + 1) / 2 - # this is not preserving zero_point, this is converting to TensorCoreTiledFormat - # TODO move the conversion of zero_point out of quant_primitives - # and into TensorCoreTiledLayout.from_plain - zero_point = min_val_neg + scale * mid_point + assert ( + zero_point_domain == ZeroPointDomain.FLOAT.name + ), "zero_point must be in FLOAT/INT/None domain for asymmetric quantization" + mid_point = (quant_max + quant_min + 1) / 2 + # this is not preserving zero_point, this is converting to TensorCoreTiledFormat + # TODO move the conversion of zero_point out of quant_primitives + # and into TensorCoreTiledLayout.from_plain + zero_point = min_val_neg + scale * mid_point if zero_point is not None: zero_point = zero_point.to(dtype=zero_point_dtype) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 9715d99e08..abaad317eb 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -8,7 +8,6 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.utils import is_device from torchao.quantization.utils import ( dequantize_per_channel, dynamically_quantize_per_channel, @@ -16,7 +15,15 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple +from torchao.utils import ( + check_cpu_version, + check_xpu_version, + find_multiple, +) + +from .quant_primitives import ( + ZeroPointDomain, +) __all__ = [ "Int8DynamicallyQuantizedLinearWeight", @@ -419,6 +426,8 @@ def __new__( shape, groupsize=128, inner_k_tiles=8, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, dtype=None, **kwargs, ): @@ -435,6 +444,8 @@ def __init__( shape, groupsize, inner_k_tiles, + zero_point_domain, + preserve_zero, dtype, **kwargs, ): @@ -446,6 +457,8 @@ def __init__( self.scales_and_zeros = scales_and_zeros self.groupsize = groupsize 
self.inner_k_tiles = inner_k_tiles + self.zero_point_domain = zero_point_domain + self.preserve_zero = preserve_zero super().__init__(int_data, transposed) @staticmethod @@ -459,13 +472,29 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # matmul - if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(act_mat.device): y = aten._weight_int4pack_mm_for_cpu( act_mat.contiguous(), w_qtensor.int_data, w_qtensor.groupsize, w_qtensor.scales_and_zeros, ) + elif check_xpu_version(act_mat.device): + if not w_qtensor.zero_point_domain == ZeroPointDomain.INT: + y = aten._weight_int4pack_mm( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) + else: + y = aten._weight_int4pack_mm_with_scales_and_zeros( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros[0], + w_qtensor.scales_and_zeros[1], + ) else: y = aten._weight_int4pack_mm( act_mat.contiguous(), @@ -513,6 +542,8 @@ def to(self, *args, **kwargs): self.shape, self.groupsize, self.inner_k_tiles, + self.zero_point_domain, + self.preserve_zero, **kwargs, ) @@ -524,6 +555,8 @@ def _apply_fn_to_data(self, fn): self.shape, self.groupsize, self.inner_k_tiles, + self.zero_point_domain, + self.preserve_zero, dtype=self.dtype, ) @@ -537,6 +570,8 @@ def _change_shape(self, shape): shape, self.groupsize, self.inner_k_tiles, + self.zero_point_domain, + self.preserve_zero, dtype=self.dtype, ) @@ -546,6 +581,8 @@ def __tensor_flatten__(self): self.shape, self.groupsize, self.inner_k_tiles, + self.zero_point_domain, + self.preserve_zero, self.dtype, ) @@ -560,7 +597,15 @@ def __tensor_unflatten__( tensor_data_dict["int_data"], tensor_data_dict["scales_and_zeros"], ) - transposed, shape, groupsize, inner_k_tiles, dtype = attributes + ( + transposed, + shape, + groupsize, + inner_k_tiles, + zero_point_domain, + preserve_zero, + dtype, + ) = attributes return cls( int_data, scales_and_zeros, @@ -568,12 +613,22 @@ def __tensor_unflatten__( shape if outer_size is None else outer_size, groupsize, inner_k_tiles, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, dtype=dtype, strides=outer_stride, ) @classmethod - def from_float(cls, input_float, groupsize=128, inner_k_tiles=8, dtype=None): + def from_float( + cls, + input_float, + groupsize=128, + inner_k_tiles=8, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, + dtype=None, + ): """ Method used to convert a linear weight tensor to an instance of the Int4WeightOnlyQuantizedLinearWeight subclass. 
@@ -588,7 +643,13 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8, dtype=None): dtype = input_float.dtype int_data, scales_and_zeros, transposed, groupsize, inner_k_tils = ( - cls.to_qtensor_components(input_float, groupsize, inner_k_tiles) + cls.to_qtensor_components( + input_float, + groupsize, + inner_k_tiles, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + ) ) return cls( int_data, @@ -597,11 +658,20 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8, dtype=None): input_float.shape, groupsize, inner_k_tiles, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, dtype=dtype, ) @classmethod - def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): + def to_qtensor_components( + cls, + input_float, + groupsize=128, + inner_k_tiles=8, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, + ): assert groupsize in [256, 128, 64, 32] assert inner_k_tiles in [8, 4, 2] orig_out_features, orig_in_features = input_float.shape @@ -616,12 +686,24 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): # quantization and packing input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( - input_float, 4, groupsize, dtype=input_float.dtype + input_float, + 4, + groupsize, + dtype=input_float.dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, ) - if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if check_cpu_version(input_float.device): int_data = aten._convert_weight_to_int4pack_for_cpu( input_int4x8, inner_k_tiles ) + if check_xpu_version(input_float.device): + from torchao.quantization.utils import convert_weight_to_int4pack_xpu + + int_data = convert_weight_to_int4pack_xpu( + input_int4x8, + zero_point_domain_is_int=zero_point_domain == ZeroPointDomain.INT, + ) else: int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index b23f39c6d7..f5bdfa9193 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -9,7 +9,6 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode -from torchao.dtypes.utils import is_device from torchao.kernel import ( int_scaled_matmul, ) @@ -20,7 +19,11 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + check_cpu_version, + check_xpu_version, +) __all__ = [ "compute_error", @@ -315,7 +318,14 @@ def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float3 return dequantized -def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): +def get_groupwise_affine_qparams( + w, + n_bit=4, + groupsize=128, + dtype=torch.bfloat16, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, +): if groupsize > w.shape[-1]: groupsize = w.shape[-1] assert groupsize > 1 @@ -330,7 +340,9 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16 quant_max = 2**n_bit - 1 eps = 1e-6 scale_dtype = dtype - zero_point_dtype = dtype + zero_point_dtype = ( + dtype if zero_point_domain != ZeroPointDomain.INT else torch.int32 + ) scale, zero_point = choose_qparams_affine( w, @@ -342,12 +354,12 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16 eps, 
scale_dtype=scale_dtype, zero_point_dtype=zero_point_dtype, - preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, ) return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to( - dtype=dtype + dtype=zero_point_dtype ).reshape(w.shape[0], -1) @@ -372,6 +384,25 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros): return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) +def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False): + assert weight.device.type == "xpu" + + if zero_point_domain_is_int: + # int_data = weight.to(dtype=torch.uint8) + int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8) + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, + 8, # TODO:remove + ) + else: + out = weight.to(dtype=torch.uint8) + out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8) + packed_weight = out.view(torch.int32) + + # Second, N * K/2 uint8 -> N * K/8 int32 + return packed_weight + + def groupwise_affine_quantize_tensor_from_qparams( w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT ): @@ -399,7 +430,9 @@ def groupwise_affine_quantize_tensor_from_qparams( zero_point_domain=zero_point_domain, ) if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: - if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + if (not (check_cpu_version(int_data.device))) and ( + not (check_xpu_version(int_data.device)) + ): int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) return int_data @@ -418,7 +451,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( if ( TORCH_VERSION_AT_LEAST_2_5 and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) - and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6) + and not (check_cpu_version(w_int4x8.device)) + and not (check_xpu_version(w_int4x8.device)) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 @@ -454,10 +488,24 @@ def groupwise_affine_dequantize_tensor_from_qparams( ) -def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): - scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype) +def groupwise_affine_quantize_tensor( + w, + n_bit=4, + groupsize=128, + dtype=torch.bfloat16, + zero_point_domain=ZeroPointDomain.FLOAT, + preserve_zero=False, +): + scales, zeros = get_groupwise_affine_qparams( + w, + n_bit, + groupsize, + dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + ) w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( - w, scales, zeros, n_bit, groupsize + w, scales, zeros, n_bit, groupsize, zero_point_domain=zero_point_domain ) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros, dtype) return w_int4x8, scales_and_zeros diff --git a/torchao/utils.py b/torchao/utils.py index 5577a66637..c8465274ea 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -676,6 +676,18 @@ def is_sm_at_least_100(): ) +def check_cpu_version(device, version="2.6.0"): + if isinstance(device, torch.device): + device = device.type + return device == "cpu" and compare_versions(torch.__version__, version) >= 0 + + +def check_xpu_version(device, version="2.8.0"): + if isinstance(device, torch.device): + device = device.type + return device == "xpu" and compare_versions(torch.__version__, version) >= 0 + + TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") 
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
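[Reviewer note, not part of the patch] End-to-end usage sketch of the default float-zero-point XPU path enabled by this PR, assuming PyTorch 2.8+ built with XPU support; the two-layer toy model and group_size=32 mirror the tests above and are illustrative only.

import torch

from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import check_xpu_version

model = (
    torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))
    .eval()
    .to(torch.bfloat16)
)

device = torch.device("xpu")
if torch.xpu.is_available() and check_xpu_version(device):
    model = model.to(device)
    # weight-only int4 with the packed XPU layout introduced in this patch
    quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout()))
    x = torch.randn(16, 1024, dtype=torch.bfloat16, device=device)
    with torch.no_grad():
        y = model(x)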