diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 872ba355c70..9b10c3be530 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid

     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,7 +143,8 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)

             return (True, gemm_deps)
@@ -151,19 +152,33 @@ def _get_weight_deps(
         # Quantized Weight deps
         dequant_node = get_input_node(node, self.weight_idx)
         if not is_dequant(dequant_node):
+            why(node, "Expected weight to have a dequantized node")
             return False, []
         gemm_deps.append(dequant_node)
         weight = get_input_node(dequant_node, 0)
         if not is_param_node(ep, weight):
+            why(node, "Expected weight to be a static param")
             return False, []
         gemm_deps.append(weight)

+        if (
+            is_per_tensor(dequant_node)
+            and precision == ConfigPrecisionType.DYNAMIC_QUANT
+        ):
+            why(
+                node,
+                "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations",
+            )
+            return False, []
+
         if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
             if len(dequant_node.all_input_nodes) < 2:
                 # Expected channel quantized to have scale/zp nodes
+                why(node, "Expected channel quantized to have scale/zp nodes")
                 return False, []
             gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
         return (True, gemm_deps)

     def _get_output_deps(
@@ -174,7 +189,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []

             # Check if the quantized pattern has a fused activation
@@ -190,6 +205,7 @@ def _get_output_deps(

             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +235,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)

         return (True, gemm_deps)
@@ -233,7 +250,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,6 +260,7 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])
             gemm_deps.append(q_input)

@@ -250,20 +268,20 @@ def _get_act_deps(
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
                 if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                    # expected to find getitem node from choose qparam
+                    why(node, "expected to find getitem node from choose qparam")
                     return (False, [])

                 getitem1 = q_input_args[1]
                 getitem2 = q_input_args[2]
                 if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                    # expected getitem node from choose qparam
+                    why(node, "expected getitem node from choose qparam")
                     return (False, [])

                 gemm_deps.extend([getitem1, getitem2])
                 choose_qparam = get_input_node(getitem1, 0)
                 if not is_qparam(choose_qparam):
-                    # expected to find choose_qparam node
+                    why(node, "expected to find choose_qparam node")
                     return (False, [])
                 gemm_deps.append(choose_qparam)

         return (True, gemm_deps)
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
                 # there can only be a single output node in partition
                 or len(src_partition.output_nodes) != 1
             ):
+                why(node, "invalid source partition")
                 return (False, [])

             # map addmm's args to the source partition linear's inputs and users
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 30bb4f0aba2..b56a746651c 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -539,6 +539,66 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
                 uses_bias=uses_bias,
             )

+    def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.float):
+        for uses_bias in (False, True):
+            module = BaseLinear(
+                in_size=8,
+                input_channels=13,
+                output_channels=17,
+                dtype=dtype,
+                use_bias=uses_bias,
+            )
+            inputs = module.get_inputs()
+            dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)
+
+            quant_config = get_symmetric_quantization_config(
+                is_per_channel=False,
+                is_dynamic=True,
+            )
+
+            for legacy_partitioner in (True, False):
+                for per_op_mode in (True, False):
+                    # Every combination should fail to partition Linear or [add]mm.
+                    DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+                        config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+                        per_op_mode=per_op_mode,
+                    )
+
+                    tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
+                    tester.quantize(Quantize(quantization_config=quant_config))
+                    tester.export()
+
+                    if legacy_partitioner:
+                        tester.to_edge()
+                        tester.partition(
+                            Partition(DynamicallyQuantizedPartitioner)
+                        ).dump_artifact()
+                        # should have [add]mm node
+                        if uses_bias:
+                            tester.check(
+                                [
+                                    "executorch_exir_dialects_edge__ops_aten_addmm_default",
+                                ]
+                            )
+                        else:
+                            tester.check(
+                                [
+                                    "executorch_exir_dialects_edge__ops_aten_mm_default",
+                                ]
+                            )
+                    else:
+                        tester.to_edge_transform_and_lower(
+                            ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
+                        ).dump_artifact()
+                        # should not have a delegate node
+                        tester.check_not(
+                            [
+                                "torch.ops.higher_order.executorch_call_delegate",
+                            ]
+                        )
+                    # No need to run the model, since it should fail to partition.
+        return
+
     def _test_qd8_per_channel_4w_linear(self, dtype: torch.dtype = torch.float):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
@@ -697,10 +757,24 @@ def test_qs8_linear(self):
     def test_qd8_f16_per_channel_linear(self):
         self._test_qd8_per_channel_linear(dtype=torch.half)

+    def test_qd8_f16_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that such a linear is not partitioned or lowered to XNNPACK.
+        """
+        self._test_qd8_linear_per_tensor_unsupported(dtype=torch.half)
+
     # Tests for q[dp]8-f32-qc8w
     def test_qd8_f32_per_channel_linear(self):
         self._test_qd8_per_channel_linear(dtype=torch.float)

+    def test_qd8_f32_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that such a linear is not partitioned or lowered to XNNPACK.
+        """
+        self._test_qd8_linear_per_tensor_unsupported(dtype=torch.float)
+
     # Tests for q[dp]8-f16-qc4w
     def test_linear_qd8_f16_per_channel_int4(self):
         self._test_qd8_per_channel_4w_linear(dtype=torch.half)
diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py
index 7c035757a6f..49c5a963161 100644
--- a/backends/xnnpack/utils/quant_utils.py
+++ b/backends/xnnpack/utils/quant_utils.py
@@ -89,6 +89,15 @@ def is_per_channel(node: torch.fx.Node) -> bool:
     return is_per_channel or is_affine_per_channel_group


+def is_per_tensor(node: torch.fx.Node) -> bool:
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    is_per_tensor = "per_tensor" in node.target.__name__  # pyre-ignore
+
+    return is_per_tensor and not (is_per_channel(node))
+
+
 def is_affine_qdq(node: torch.fx.Node) -> bool:
     if not (is_quant(node) or is_dequant(node)):
         return False