Commit 4febe74

[XNNPACK] resolve ambiguity around 2d affine quantized tensors
1 parent ee1d7c3 commit 4febe74

2 files changed (+33, -22 lines)

backends/xnnpack/test/ops/test_linear.py

Lines changed: 28 additions & 21 deletions
@@ -423,10 +423,16 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
                     use_bias=use_bias,
                 )
 
+                # rank 3
                 inputs = (torch.randn(1, M, K),)
                 self._test_groupwise_dq_linear(
                     lin_mod, inputs, group_size=bl, use_bias=use_bias
                 )
+                # rank 2
+                inputs = (torch.randn(1, K),)
+                self._test_groupwise_dq_linear(
+                    lin_mod, inputs, group_size=bl, use_bias=use_bias
+                )
 
     @unittest.skipIf(
         not torchao_installed, "Per Channel Group Quantization Required TorchAO"
@@ -437,28 +443,29 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
         bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=torch.float16,
-                    use_bias=use_bias,
-                )
+        for input_rank in range(2, 4):
+            for use_bias in [True, False]:
+                for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                    lin_mod = BaseLinear(
+                        in_size=M,
+                        input_channels=K,
+                        output_channels=N,
+                        dtype=torch.float16,
+                        use_bias=use_bias,
+                    )
 
-                inputs = lin_mod.get_inputs()
-                # This requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
-                )
+                    inputs = lin_mod.get_inputs(rank=input_rank)
+                    # This requires slightly higher atol, but if you look at error it is not that bad:
+                    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                    # -- Model vs. Reference --
+                    # Numel: 4, 4
+                    # Median: -0.05023193359375, -0.0516357421875
+                    # Mean: 0.2373046875, 0.237060546875
+                    # Max: 1.0078125, 1.0078125
+                    # Min: -0.08465576171875, -0.08441162109375
+                    self._test_groupwise_dq_linear(
+                        lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
+                    )
 
     @unittest.skipIf(
         not torchao_installed, "Per Channel Group Quantization Required TorchAO"
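
A note on the helper used above: the fp16 test now calls lin_mod.get_inputs(rank=input_rank), but the BaseLinear.get_inputs implementation is not part of this diff. Purely as a hypothetical sketch (the attribute names in_size, input_channels, and dtype are assumed from the constructor arguments above, not confirmed by this commit), a rank-aware helper might look like:

# Hypothetical sketch of a rank-aware input helper; not code from this commit.
# Assumes the module stores in_size (M), input_channels (K), and dtype.
def get_inputs(self, rank: int = 3):
    if rank == 2:
        # rank-2 activation of shape (1, K) -- the case this commit adds coverage for
        return (torch.randn(1, self.input_channels, dtype=self.dtype),)
    # default rank-3 activation of shape (1, M, K)
    return (torch.randn(1, self.in_size, self.input_channels, dtype=self.dtype),)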

backends/xnnpack/utils/quant_utils.py

Lines changed: 5 additions & 1 deletion
@@ -50,7 +50,7 @@ def is_dynamic_qdq(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
         return False
     node_name = format_target_name(node.target.__name__)  # pyre-ignore
-    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)
+    is_dynamic_affine = is_per_token(node)
 
     return node_name in _DYNAMIC_OPS or is_dynamic_affine

@@ -120,6 +120,9 @@ def is_per_token(node: torch.fx.Node):
 
         flag &= block_size[-1] == input_val.shape[-1]
         flag &= scale_val.numel() == scale_numel_expected
+        scale_node = node.all_input_nodes[1]
+        # per token must have dynamically chosen scale
+        flag &= scale_node.target == operator.getitem
         return flag
 
     return False
@@ -140,6 +143,7 @@ def is_per_channel_group(node: torch.fx.Node):
         scale_numel = list(accumulate(scale_val.shape, operator.mul))[-1]
         input_numel = list(accumulate(input_val.shape, operator.mul))[-1]
         flag &= input_numel == group_size * scale_numel
+        flag &= not is_per_token(node)
         return flag
 
     return False
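
For context on the ambiguity the commit title refers to: a dynamically per-token-quantized rank-2 activation of shape (1, K) has a block size of (1, K) and a single scale, so it can also satisfy a purely size-based per-channel-group check (input_numel == group_size * scale_numel when the group size is taken from the last block dimension). A standalone sketch of that arithmetic, with the shapes and the group-size assumption chosen for illustration rather than taken from the repository:

# Standalone illustration (assumed shapes, not repository code): a rank-2
# per-token activation passes both size-based heuristics.
K = 64
input_shape = (1, K)
block_size = (1, K)   # per-token: one block spanning the last dimension
scale_numel = 1       # product of the leading dims is 1, so a single scale

# per-token heuristic: block covers the last dim, scale count matches leading dims
looks_per_token = block_size[-1] == input_shape[-1] and scale_numel == 1

# per-channel-group size heuristic: numel == group_size * scale_numel
group_size = block_size[-1]
input_numel = input_shape[0] * input_shape[1]
looks_per_channel_group = input_numel == group_size * scale_numel

print(looks_per_token, looks_per_channel_group)  # True True -> ambiguous

The commit breaks the tie in two places: is_per_token additionally requires that the scale is produced dynamically (an operator.getitem, i.e. taken from the output of an upstream qparams computation), and is_per_channel_group explicitly excludes anything classified as per token, which in turn lets is_dynamic_qdq drop its own "and not is_per_channel_group(node)" exclusion.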
