
Commit 522b8c7

encapsulate version check as helpers

remove zero_point_dtype assigning

Signed-off-by: Meng, Hengyu <[email protected]>

fix import lint
enable zp dtype: u8/s8/s16/s32/s64

Signed-off-by: Meng, Hengyu <[email protected]>
1 parent bf6c814 commit 522b8c7

File tree: 10 files changed (+57, -53 lines)

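In short: the commit replaces scattered `device == "cpu" and TORCH_VERSION_AT_LEAST_2_6`-style guards with two helpers, `check_cpu_version` and `check_xpu_version` (defined in torchao/utils.py at the bottom of this diff). A minimal sketch distilling the call-site pattern from the diffs below; `pick_int4_path` is a hypothetical illustration, not library code:

import torch
from torchao.utils import check_cpu_version, check_xpu_version

def pick_int4_path(device: str) -> str:
    # Before this commit, call sites spelled the guard out, e.g.:
    #   if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: ...
    #   elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_8: ...
    # After, the device/version pairing lives in one helper per backend:
    if check_cpu_version(device):    # "cpu" and torch >= 2.6.0
        return "Int4CPULayout"
    elif check_xpu_version(device):  # "xpu" and torch >= 2.8.0
        return "Int4XPULayout"
    return "TensorCoreTiledLayout"   # fallback used by the CUDA paths

print(pick_int4_path("cpu"))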

test/dtypes/test_affine_quantized.py

Lines changed: 4 additions & 4 deletions
@@ -36,8 +36,8 @@
 from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
-    TORCH_VERSION_AT_LEAST_2_8,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
@@ -58,11 +58,11 @@ def get_quantization_functions(
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(device):
             base_functions.append(
                 int4_weight_only(group_size=32, layout=Int4CPULayout())
             )
-        elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_8:
+        elif check_xpu_version(device):
             base_functions.append(
                 int4_weight_only(group_size=32, layout=Int4XPULayout())
             )

test/integration/test_integration.py

Lines changed: 6 additions & 12 deletions
@@ -20,7 +20,6 @@
 
 import torchao
 from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
-from torchao.dtypes.utils import is_device
 from torchao.quantization import safe_int_mm
 from torchao.quantization.autoquant import (
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
@@ -83,8 +82,9 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     TORCH_VERSION_AT_LEAST_2_7,
-    TORCH_VERSION_AT_LEAST_2_8,
     benchmark_model,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
     is_sm_at_least_90,
     unwrap_tensor_subclass,
@@ -147,21 +147,15 @@ def _int8da_int8w_api(
 
 
 def _int4wo_api(mod, use_hqq=False):
-    if (
-        is_device(next(mod.parameters()).device.type, "cpu")
-        and TORCH_VERSION_AT_LEAST_2_6
-    ):
+    if check_cpu_version(next(mod.parameters()).device):
         quantize_(
             mod,
             int4_weight_only(
                 layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False
             ),
         )
         unwrap_tensor_subclass(mod)
-    elif (
-        is_device(next(mod.parameters()).device.type, "xpu")
-        and TORCH_VERSION_AT_LEAST_2_8
-    ):
+    elif check_xpu_version(next(mod.parameters()).device):
         quantize_(
             mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
         )
@@ -1138,9 +1132,9 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         layout_list = []
-        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(device):
             layout_list.append(Int4CPULayout())
-        elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_8:
+        elif check_xpu_version(device):
             layout_list.append(Int4XPULayout())
         else:
             for inner_k_tiles in [4, 2]:

test/quantization/test_quant_primitives.py

Lines changed: 5 additions & 15 deletions
@@ -11,7 +11,6 @@
 import torch
 from parameterized import parameterized
 
-from torchao.dtypes.utils import is_device
 from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -38,7 +37,8 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    TORCH_VERSION_AT_LEAST_2_8,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
 )
 
@@ -136,9 +136,7 @@ def _groupwise_affine_quantize_tensor_from_qparams(
     )
 
     if TORCH_VERSION_AT_LEAST_2_5:
-        if (not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)) and (
-            not (is_device(w.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_8)
-        ):
+        if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
             w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
@@ -747,16 +745,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
         if TORCH_VERSION_AT_LEAST_2_5:
             input_tmp = input
-            if (
-                not (
-                    is_device(input.device.type, "cpu")
-                    and TORCH_VERSION_AT_LEAST_2_6
-                )
-            ) and (
-                not (
-                    is_device(input.device.type, "xpu")
-                    and TORCH_VERSION_AT_LEAST_2_8
-                )
+            if (not (check_cpu_version(input.device))) and (
+                not (check_xpu_version(input.device))
             ):
                 input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
         w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(

torchao/kernel/intmm.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, check_cpu_version
 
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
@@ -154,7 +154,7 @@ def int_scaled_matmul(
     scales1 = scales1.expand((M, N))
     assert scales1.dim() == 2
 
-    if scales1.device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+    if check_cpu_version(scales1.device):
         # CPU prefers decomposed version of int_scaled_matmul
         # to leverage the fusion capability of Inductor
         c = torch._int_mm(a, b)

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@
 from torch import Tensor, nn
 
 from torchao.dtypes.utils import is_device
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, check_cpu_version
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -167,7 +167,7 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(W_q.device):
             self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                 W_q_torch, self.inner_k_tiles
             )
@@ -243,7 +243,7 @@ def pack_scales_and_zeros(self, scales, zeros):
     def matmul(self, x):
         origin_x_size = x.size()
         x = x.reshape(-1, origin_x_size[-1])
-        if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(x.device):
            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
            )

torchao/quantization/quant_api.py

Lines changed: 3 additions & 0 deletions
@@ -1039,6 +1039,9 @@ def _int4_weight_only_transform(
         zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
     ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
 
+    if zero_point_domain == ZeroPointDomain.INT and isinstance(layout, Int4XPULayout):
+        zero_point_dtype = torch.int32
+
     preserve_zero = (
         config.preserve_zero
         if config.preserve_zero is not None

torchao/quantization/quant_primitives.py

Lines changed: 10 additions & 1 deletion
@@ -958,7 +958,16 @@ def _choose_qparams_affine(
     elif zero_point_domain == ZeroPointDomain.INT.name:
         zero_point = quant_min - torch.round(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
-        zero_point_dtype = torch.int32
+        assert (
+            zero_point_dtype
+            in [
+                torch.int8,
+                torch.uint8,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+        ), "zero_point_dtype must be int8/uint8/int16/int32/int64 if ZeroPointDomain.INT"
     else:
         assert (
             zero_point_domain == ZeroPointDomain.FLOAT.name
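The behavioral point of this hunk: the old code silently overwrote the caller's zero_point_dtype with torch.int32 inside the ZeroPointDomain.INT branch; the new code validates it and leaves it alone, so a caller-supplied u8/s8/s16/s32/s64 dtype survives (per the commit message). A self-contained sketch of the difference; `old_zp_dtype`/`new_zp_dtype` are illustrative stand-ins, not the library API:

import torch

def old_zp_dtype(zero_point_dtype):
    # Pre-change behavior: the caller's dtype was silently replaced.
    return torch.int32

def new_zp_dtype(zero_point_dtype):
    # Post-change behavior: the dtype is validated and passed through.
    assert zero_point_dtype in (
        torch.int8,
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
    ), "zero_point_dtype must be int8/uint8/int16/int32/int64 if ZeroPointDomain.INT"
    return zero_point_dtype

assert old_zp_dtype(torch.uint8) is torch.int32  # caller's choice was lost
assert new_zp_dtype(torch.uint8) is torch.uint8  # caller's choice survives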

torchao/quantization/subclass.py

Lines changed: 6 additions & 7 deletions
@@ -8,7 +8,6 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.dtypes.utils import is_device
 from torchao.quantization.utils import (
     dequantize_per_channel,
     dynamically_quantize_per_channel,
@@ -17,8 +16,8 @@
     unpack_tinygemm_scales_and_zeros,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_6,
-    TORCH_VERSION_AT_LEAST_2_8,
+    check_cpu_version,
+    check_xpu_version,
     find_multiple,
 )
 
@@ -473,14 +472,14 @@ def _quantized_op(act_mat, w_qtensor, bias):
             act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))
 
         # matmul
-        if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(act_mat.device):
             y = aten._weight_int4pack_mm_for_cpu(
                 act_mat.contiguous(),
                 w_qtensor.int_data,
                 w_qtensor.groupsize,
                 w_qtensor.scales_and_zeros,
             )
-        elif is_device(act_mat.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_8:
+        elif check_xpu_version(act_mat.device):
             if not w_qtensor.zero_point_domain == ZeroPointDomain.INT:
                 y = aten._weight_int4pack_mm(
                     act_mat.contiguous(),
@@ -694,11 +693,11 @@ def to_qtensor_components(
             zero_point_domain=zero_point_domain,
             preserve_zero=preserve_zero,
         )
-        if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(input_float.device):
             int_data = aten._convert_weight_to_int4pack_for_cpu(
                 input_int4x8, inner_k_tiles
             )
-        if is_device(input_float.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_8:
+        if check_xpu_version(input_float.device):
             from torchao.quantization.utils import convert_weight_to_int4pack_xpu
 
             int_data = convert_weight_to_int4pack_xpu(

torchao/quantization/utils.py

Lines changed: 6 additions & 9 deletions
@@ -9,7 +9,6 @@
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 
-from torchao.dtypes.utils import is_device
 from torchao.kernel import (
     int_scaled_matmul,
 )
@@ -22,8 +21,8 @@
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
-    TORCH_VERSION_AT_LEAST_2_8,
+    check_cpu_version,
+    check_xpu_version,
 )
 
 __all__ = [
@@ -431,10 +430,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         zero_point_domain=zero_point_domain,
     )
     if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
-        if (
-            not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
-        ) and (
-            not (is_device(int_data.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_8)
+        if (not (check_cpu_version(int_data.device))) and (
+            not (check_xpu_version(int_data.device))
         ):
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
@@ -454,8 +451,8 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     if (
         TORCH_VERSION_AT_LEAST_2_5
        and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
-        and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
-        and not (is_device(w_int4x8.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_8)
+        and not (check_cpu_version(w_int4x8.device))
+        and not (check_xpu_version(w_int4x8.device))
     ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4

torchao/utils.py

Lines changed: 12 additions & 0 deletions
@@ -676,6 +676,18 @@ def is_sm_at_least_100():
     )
 
 
+def check_cpu_version(device, version="2.6.0"):
+    if isinstance(device, torch.device):
+        device = device.type
+    return device == "cpu" and compare_versions(torch.__version__, version) >= 0
+
+
+def check_xpu_version(device, version="2.8.0"):
+    if isinstance(device, torch.device):
+        device = device.type
+    return device == "xpu" and compare_versions(torch.__version__, version) >= 0
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
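A short usage sketch for the two new helpers, assuming a torchao build that includes this commit. Both accept either a torch.device or a plain device-type string, thanks to the isinstance normalization above, and the version floor is an ordinary parameter:

import torch
from torchao.utils import check_cpu_version, check_xpu_version

dev = torch.zeros(1).device              # torch.device("cpu")
print(check_cpu_version(dev))            # True iff running torch >= 2.6.0
print(check_cpu_version("cpu"))          # device-type strings also work
print(check_xpu_version(dev))            # False: device type is "cpu", not "xpu"
print(check_cpu_version(dev, "2.7.0"))   # per-call override of the version floor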
