Commit b41626f

Remove redundancy pass & add dequant const param check
1 parent 6cc2264

4 files changed: +13 -6 lines


backends/qualcomm/passes/annotate_quant_attrs.py

Lines changed: 5 additions & 2 deletions
@@ -94,9 +94,11 @@ def _dequant_fold_params(self, n, quant_attrs, param):
     def _annotate_quant_attrs(
         self, graph_module: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
+        # Keep track of const params that have been dequantized, so a param
+        # with more than one user does not get dequantized multiple times
+        visited_const_param = set()
         for n in graph_module.graph.nodes:
             self._annotate_requant(n)
-
             # With fold_quant enabled, check if the input of dq op is quantized param.
             param = None
             if n.target in dq_ops:
@@ -106,7 +108,8 @@ def _annotate_quant_attrs(
             quant_attrs = get_quant_attrs(self.edge_program, n)
             self._annotate_source_nodes(n, quant_attrs)
 
-            if param is not None:
+            if param is not None and n.args[0] not in visited_const_param:
+                visited_const_param.add(n.args[0])
                 self._dequant_fold_params(n, quant_attrs, param)
 
         return graph_module
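
To make the dedup idea concrete, here is a minimal, self-contained sketch of the pattern (hypothetical dict-based nodes and a stand-in fold function, not the ExecuTorch API): a const param shared by several dequant users is folded only on its first visit.

# Hypothetical illustration of the visited-set pattern; names are made up for the sketch.
def fold_shared_params(dq_nodes, fold_fn):
    visited_const_param = set()
    for n in dq_nodes:
        param = n["param"]  # assumed: each dq node carries a reference to its const param
        if param is not None and param not in visited_const_param:
            visited_const_param.add(param)
            fold_fn(n, param)  # runs once per param, even when the param has multiple users

folded = []
fold_shared_params(
    [{"param": "weight_a"}, {"param": "weight_a"}, {"param": "bias_b"}],
    lambda n, p: folded.append(p),
)
print(folded)  # ['weight_a', 'bias_b'] -- the shared param is folded only once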

backends/qualcomm/passes/recompose_pixel_unshuffle.py

Lines changed: 7 additions & 1 deletion
@@ -35,7 +35,13 @@ def call(self, graph_module: torch.fx.GraphModule):
         for node in graph.nodes:
             if node.op == "call_function" and node.target == self.reshape_target:
                 with graph.inserting_after(node):
-                    premute_node = node.args[0]
+
+                    # A clone op still exists between permute and reshape_target during
+                    # quantization, so check args[0].args[0] to get the permute node
+                    if self.quantization_capture:
+                        premute_node = node.args[0].args[0]
+                    else:
+                        premute_node = node.args[0]
                     if any(
                         [
                             len(node.args[1]) != 4,
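
To illustrate the graph shapes this branch distinguishes, here is a rough, self-contained sketch with hypothetical dict-based nodes (not real torch.fx nodes): during quantization capture a clone op sits between the permute and the reshape target, so the permute is reached via args[0].args[0].

# Hypothetical node structures; only the argument chain matters for the sketch.
reshape_float = {"target": "reshape", "args": ({"target": "permute"},)}
reshape_quant = {
    "target": "reshape",
    "args": ({"target": "clone", "args": ({"target": "permute"},)},),
}

def find_permute(node, quantization_capture):
    # Mirrors the logic above: hop over the clone only in quantization capture mode.
    if quantization_capture:
        return node["args"][0]["args"][0]
    return node["args"][0]

assert find_permute(reshape_float, False)["target"] == "permute"
assert find_permute(reshape_quant, True)["target"] == "permute"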

backends/qualcomm/quantizer/quantizer.py

Lines changed: 0 additions & 2 deletions
@@ -12,7 +12,6 @@
     RecomposePixelUnshuffle,
 )
 from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
-from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy
 from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
@@ -182,7 +181,6 @@ def set_per_channel_linear_quant(self, enable: bool) -> None:
         self._update_per_channel_weight_quant_ops(linear_ops, enable)
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
-        model = RemoveRedundancy()(model).graph_module
         model = ReduceDynamicRange()(model).graph_module
         model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
         model = DecomposeScaledDotProductAttention()(model).graph_module

examples/qualcomm/scripts/utils.py

Lines changed: 1 addition & 1 deletion
@@ -187,6 +187,7 @@ def build_executorch_binary(
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_annotations)
         quantizer.set_per_channel_linear_quant(per_channel_linear)
+        quantizer.set_per_channel_conv_quant(True)
 
         if quant_dtype == QuantDtype.use_8a8w:
             pass  # default setting
@@ -214,7 +215,6 @@ def build_executorch_binary(
         for data in dataset:
             annotated_model(*data)
         quantized_model = convert_pt2e(annotated_model)
-
         edge_prog = capture_program(quantized_model, inputs)
     else:
         edge_prog = capture_program(model, inputs)
