Don't partition per_tensor weights with qd8 (#8787) #8890
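This PR teaches the XNNPACK GEMM partitioner to reject the combination of dynamically quantized activations (qd8) with per-tensor quantized weights, which XNNPACK does not support. As a rough sketch of the configuration affected — the import path below is an assumption and not taken from this PR, while the kwargs mirror the new test — a model quantized like this will no longer have its linear/addmm nodes delegated:

```python
# Sketch only: the import path is assumed and may differ across ExecuTorch
# releases; the kwargs are the ones exercised by the new test in this PR.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)

# Dynamic activation quantization (qd8) paired with per-tensor weight
# quantization is the combination the partitioner now refuses to delegate;
# per-channel weights (is_per_channel=True) remain supported.
quant_config = get_symmetric_quantization_config(
    is_per_channel=False,  # per-tensor quantized weights
    is_dynamic=True,       # dynamically quantized activations
)
```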

Closed · wants to merge 1 commit
37 changes: 28 additions & 9 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -21,6 +21,7 @@
is_dynamic_qdq,
is_per_channel,
is_per_channel_group,
is_per_tensor,
is_qparam,
is_quant,
)
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return False

is_valid, _ = self.get_deps(node, ep)
if not is_valid:
why(node, "Failed to get valid dependent nodes.")
return is_valid

def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
precision = self._detect_precision(node)
if precision not in self.supported_precision_types():
# detected precision but it is either disabled or not supported
why(node, f"Unsupported precision type {precision}")
return (False, [])
_, precision = self._overwrite_precision(node)
valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,27 +143,42 @@ def _get_weight_deps(
# First find the weight
weight_node = get_input_node(node, self.weight_idx)
if not is_param_node(ep, weight_node):
return (False, []) # weight must be a static param
why(node, "Expected weight to be a static param")
return (False, [])
gemm_deps.append(weight_node)

return (True, gemm_deps)
else:
# Quantized Weight deps
dequant_node = get_input_node(node, self.weight_idx)
if not is_dequant(dequant_node):
why(node, "Expected weight to have a dequantized node")
return False, []
gemm_deps.append(dequant_node)
weight = get_input_node(dequant_node, 0)
if not is_param_node(ep, weight):
why(node, "Expected weight to be a static param")
return False, []
gemm_deps.append(weight)

if (
is_per_tensor(dequant_node)
and precision == ConfigPrecisionType.DYNAMIC_QUANT
):
why(
node,
"XNNPACK does not support per tensor quantized weights for dynamic quantization of activations",
)
return False, []

if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
if len(dequant_node.all_input_nodes) < 2:
# Expected channel quantized to have scale/zp nodes
why(node, "Expected channel quantized to have scale/zp nodes")
return False, []

gemm_deps.extend(dequant_node.all_input_nodes[1:3])

return (True, gemm_deps)

def _get_output_deps(
@@ -174,7 +189,7 @@ def _get_output_deps(
# Look for fused activations and tail end quant node
node_users = list(node.users.keys())
if len(node_users) != 1:
# Expect quantized node to have a single output (fused act or dequant)
why(node, "Expected quantized node to have a single output")
return False, []

# Check if the quantized pattern has a fused activation
@@ -190,6 +205,7 @@

if not is_quant(n_output):
# Expected gemm_node --> fused_act (optional) --> dequant
why(node, "Expected output node to have a dequantized node")
return (False, [])
gemm_deps.append(n_output)
elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +235,8 @@ def _get_bias_deps(
bias_node = get_input_node(node, self.bias_idx)
if bias_node:
if not is_param_node(ep, bias_node):
return (False, []) # bias node must be a static param
why(node, "Expected bias to be a static param")
return (False, [])
gemm_deps.append(bias_node)

return (True, gemm_deps)
@@ -233,7 +250,7 @@ def _get_act_deps(
else:
dq_input = get_input_node(node, self.act_idx)
if not is_dequant(dq_input):
# Expected static quant input to be dequant node
why(node, "Expected act input to be dequant node")
return False, []
gemm_deps.append(dq_input)
if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +260,28 @@
# q input node
q_input = get_input_node(dq_input, 0)
if not is_quant(q_input):
why(node, "Expected dequant input to be quant node")
return (False, [])

gemm_deps.append(q_input)
q_input_args = q_input.args
if is_affine_qdq(q_input):
q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
# expected to find getitem node from choose qparam
why(node, "expected to find getitem node from choose qparam")
return (False, [])

getitem1 = q_input_args[1]
getitem2 = q_input_args[2]

if not (is_getitem(getitem1) and is_getitem(getitem2)):
# expected getitem node from choose qparam
why(node, "expected getitem node from choose qparam")
return (False, [])

gemm_deps.extend([getitem1, getitem2])
choose_qparam = get_input_node(getitem1, 0)
if not is_qparam(choose_qparam):
# expected to find choose_qparam node
why(node, "expected to find choose_qparam node")
return (False, [])
gemm_deps.append(choose_qparam)
return (True, gemm_deps)
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
# there can only be a single output node in partition
or len(src_partition.output_nodes) != 1
):
why(node, "invalid source partition")
return (False, [])

# map addmm's args to the source partition linear's inputs and users
74 changes: 74 additions & 0 deletions backends/xnnpack/test/ops/test_linear.py
@@ -539,6 +539,66 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
uses_bias=uses_bias,
)

def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.float):
for uses_bias in (False, True):
module = BaseLinear(
in_size=8,
input_channels=13,
output_channels=17,
dtype=dtype,
use_bias=uses_bias,
)
inputs = module.get_inputs()
dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)

quant_config = get_symmetric_quantization_config(
is_per_channel=False,
is_dynamic=True,
)

for legacy_partitioner in (True, False):
for per_op_mode in (True, False):
# Every combination should fail to partition Linear or [add]mm.
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=per_op_mode,
)

tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
tester.quantize(Quantize(quantization_config=quant_config))
tester.export()

if legacy_partitioner:
tester.to_edge()
tester.partition(
Partition(DynamicallyQuantizedPartitioner)
).dump_artifact()
# should have [add]mm node
if uses_bias:
tester.check(
[
"executorch_exir_dialects_edge__ops_aten_addmm_default",
]
)
else:
tester.check(
[
"executorch_exir_dialects_edge__ops_aten_mm_default",
]
)
else:
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
).dump_artifact()
# should not have a delegate node
tester.check_not(
[
"torch.ops.higher_order.executorch_call_delegate",
]
)
# No need to run the model, since it should fail to partition.
return

def _test_qd8_per_channel_4w_linear(self, dtype: torch.dtype = torch.float):
qconfig = self._get_4b_dqconfig()
input_channels = [2, 63]
@@ -697,10 +757,24 @@ def test_qs8_linear(self):
def test_qd8_f16_per_channel_linear(self):
self._test_qd8_per_channel_linear(dtype=torch.half)

def test_qd8_f16_per_tensor_linear(self):
"""
XNNPACK doesn't support per_tensor quantized weights for dynamically quantized linear ops.
This test verifies that linears with per_tensor quantized weights are not lowered to the XNNPACK delegate.
"""
self._test_qd8_linear_per_tensor_unsupported(dtype=torch.half)

# Tests for q[dp]8-f32-qc8w
def test_qd8_f32_per_channel_linear(self):
self._test_qd8_per_channel_linear(dtype=torch.float)

def test_qd8_f32_per_tensor_linear(self):
"""
XNNPACK doesn't support per_tensor quantized weights for dynamically quantized linear ops.
This test verifies that linears with per_tensor quantized weights are not lowered to the XNNPACK delegate.
"""
self._test_qd8_linear_per_tensor_unsupported(dtype=torch.float)

# Tests for q[dp]8-f16-qc4w
def test_linear_qd8_f16_per_channel_int4(self):
self._test_qd8_per_channel_4w_linear(dtype=torch.half)
9 changes: 9 additions & 0 deletions backends/xnnpack/utils/quant_utils.py
@@ -89,6 +89,15 @@ def is_per_channel(node: torch.fx.Node) -> bool:
return is_per_channel or is_affine_per_channel_group


def is_per_tensor(node: torch.fx.Node) -> bool:
if not (is_quant(node) or is_dequant(node)):
return False

is_per_tensor = "per_tensor" in node.target.__name__ # pyre-ignore

return is_per_tensor and not (is_per_channel(node))


def is_affine_qdq(node: torch.fx.Node) -> bool:
if not (is_quant(node) or is_dequant(node)):
return False
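For reference, a minimal sketch of what the new `is_per_tensor` helper keys off: the target name of the (de)quantize node. The `quantized_decomposed` op handles and the registration import below are assumptions based on the standard PT2E decomposed quantization ops and are not part of this diff.

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (assumed path; registers quantized_decomposed ops)

# is_per_tensor()/is_per_channel() inspect node.target.__name__, so a
# dequantize_per_tensor node matches the former while a dequantize_per_channel
# node matches the latter.
per_tensor_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
per_channel_dq = torch.ops.quantized_decomposed.dequantize_per_channel.default

print("per_tensor" in per_tensor_dq.__name__)    # True
print("per_channel" in per_channel_dq.__name__)  # True
```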