Skip to content

Commit 63017e4

Browse files
authored
Search graph for quantization nodes (#6452)
* Search graph for quantization nodes Generalizes the search for quantization parameters. The idea is to make a graph like this a valid quantized graph: dq -> view -> transpose -> some_op ^ / dq ------> expand -------/ For a subset of operations 'passable_op' it is allowed to "pass through" the op when searching for qparams. If multiple qparams are encountered in one search, they are asserted to be equal. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I6dbb82fb39164c246ea74a9642d907dba22ab2c3 * Add searching qparams to op_tanh Signed-off-by: Erik Lundell <[email protected]> Change-Id: I029d95ecdfa85f5cdc63997ad1eb7515a016bae4 --------- Signed-off-by: Erik Lundell <[email protected]>
1 parent 244546b commit 63017e4

24 files changed

+292
-175
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_first_fake_tensor,
1515
insert_q_dq_pair,
1616
)
17-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
17+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
1818
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass, PassResult
@@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
4242
return args[0]
4343

4444

45+
register_passable_op(torch.ops.passthrough_to_tosa._transpose)
46+
47+
4548
class AnnotateChannelsLastDimOrder(ExportPass):
4649
"""
4750
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order

backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
import torch
1010
import torch.fx
11-
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
12-
13-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
1412
from executorch.exir.dialects._ops import ops as exir_ops
1513
from executorch.exir.pass_base import ExportPass, PassResult
1614

@@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
2826
sum(dims, keep_dim = False)
2927
After pass:
3028
sum(dims, keep_dim = True)
31-
(q)
32-
(dq)
3329
squeeze(dim = dims)
3430
"""
3531

@@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
4541
continue
4642

4743
dim_list = cast(list[int], sum_node.args[1])
48-
quantized = is_quant_node(sum_node)
49-
if quantized:
50-
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51-
qparams = qparams + (torch.int8,)
52-
else:
53-
qparams = None
5444

5545
# Add keep_dim = True arg to sum node.
5646
sum_node.args = sum_node.args[0:2] + (True,)
@@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
6151
)
6252
sum_node.replace_all_uses_with(squeeze_node)
6353
squeeze_node.args = (sum_node, dim_list)
64-
if quantized:
65-
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
6654
graph_module.graph.eliminate_dead_code()
6755
graph_module.recompile()
6856
graph_module = super().call(graph_module).graph_module

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import cast, Optional
1010

1111
import torch.fx
12-
from executorch.backends.arm.tosa_quant_utils import is_quant_node
12+
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch._ops import OpOverload
@@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
113113
slice_node = graph.create_node(
114114
"call_function", self.slice_op, (last_node,) + args
115115
)
116-
if is_quant_node(last_node):
116+
if is_node_quantized(last_node):
117117
q_params = last_node.args[1:]
118118
dq_node = insert_q_dq_pair(
119119
graph_module.graph, slice_node, q_params

backends/arm/operators/op_addmm.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
search_quant_arg_downstream,
20+
search_quant_arg_upstream,
21+
)
1822

1923
from executorch.backends.arm.tosa_utils import build_reshape
20-
from executorch.exir.dialects._ops import ops as exir_ops
2124
from serializer.tosa_serializer import TosaOp
2225

2326

@@ -67,12 +70,7 @@ def define_node(
6770
input_zp = 0
6871
if is_quant_node:
6972
input_node = node.all_input_nodes[1]
70-
# rank > 2 linear layer
71-
if input_node.target == exir_ops.edge.aten.view_copy.default:
72-
quant_node = input_node.all_input_nodes[0]
73-
else:
74-
quant_node = input_node
75-
input_zp = get_quant_node_args(quant_node).zp
73+
input_zp = search_quant_arg_upstream(input_node).zp
7674
attr.ConvAttribute(
7775
pad=pad_attr,
7876
stride=stride_attr,
@@ -107,24 +105,16 @@ def define_node(
107105
# Read inputs' parent nodes
108106
_, input_node, weight_node = node.all_input_nodes
109107

110-
# rank > 2 linear layer
111-
if input_node.target == exir_ops.edge.aten.view_copy.default:
112-
quant_node = input_node.all_input_nodes[0]
113-
input_scale = get_quant_node_args(quant_node).scale
114-
consumer_node = list(node.users)[0]
115-
consumer_consumer_node = list(consumer_node.users)[0]
116-
quant_args = get_quant_node_args(consumer_consumer_node)
117-
consumer_node_scale = quant_args.scale
118-
consumer_node_node_zp = quant_args.zp
119-
else:
120-
input_scale = get_quant_node_args(input_node).scale
121-
consumer_node = list(node.users)[0]
122-
quant_args = get_quant_node_args(consumer_node)
123-
consumer_node_scale = quant_args.scale
124-
consumer_node_node_zp = quant_args.zp
108+
qargs = search_quant_arg_upstream(input_node)
109+
input_scale = qargs.scale
110+
consumer_node = list(node.users)[0]
111+
quant_args = search_quant_arg_downstream(consumer_node)
112+
113+
consumer_node_scale = quant_args.scale
114+
consumer_node_node_zp = quant_args.zp
125115

126116
weight_node_q_node = weight_node.all_input_nodes[0]
127-
weight_scale = get_quant_node_args(weight_node_q_node).scale
117+
weight_scale = search_quant_arg_upstream(weight_node_q_node).scale
128118

129119
output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
130120

backends/arm/operators/op_bmm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
search_quant_arg_downstream,
20+
search_quant_arg_upstream,
21+
)
1822
from executorch.backends.arm.tosa_utils import get_two_inputs
1923
from serializer.tosa_serializer import TosaOp
2024

@@ -42,8 +46,10 @@ def define_node(
4246
# For INT8, we need to get the zero points and add an intermediate tensor
4347
# for a later rescale.
4448
if is_quant_node:
45-
input0_zp = get_quant_node_args(input0).zp
46-
input1_zp = get_quant_node_args(input1).zp
49+
input0_q_params = search_quant_arg_upstream(input0)
50+
input1_q_params = search_quant_arg_upstream(input1)
51+
input0_zp = input0_q_params.zp
52+
input1_zp = input1_q_params.zp
4753
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
4854
bmm_output_name = bmm_result.name
4955
else:
@@ -63,9 +69,7 @@ def define_node(
6369

6470
# As INT8 accumulates into INT32, we need to rescale it back to INT8
6571
if is_quant_node:
66-
input0_q_params = get_quant_node_args(input0)
67-
input1_q_params = get_quant_node_args(input1)
68-
output_q_params = get_quant_node_args(list(node.users)[0])
72+
output_q_params = search_quant_arg_downstream(list(node.users)[0])
6973

7074
final_output_scale = (
7175
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_conv2d.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import cast, List
7+
from typing import List
88

99
import serializer.tosa_serializer as ts
1010
import torch
@@ -15,9 +15,10 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_quant_utils import (
1717
build_rescale_conv_output,
18-
get_quant_node_args,
18+
search_quant_arg_downstream,
19+
search_quant_arg_upstream,
1920
)
20-
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
21+
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
2122

2223
from serializer.tosa_serializer import TosaOp
2324

@@ -82,7 +83,9 @@ def define_node(
8283
)
8384

8485
input_zp = (
85-
get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
86+
search_quant_arg_upstream(node.all_input_nodes[0]).zp
87+
if is_quant_node
88+
else 0
8689
)
8790

8891
attr.ConvAttribute(
@@ -158,9 +161,10 @@ def define_node(
158161
# integer value domain of the next op. Otherwise return float32 output.
159162
if is_quant_node:
160163
# Get scale_factor from input, weight, and output.
161-
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
162-
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
163-
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
164+
input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale
165+
weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale
166+
output_qargs = search_quant_arg_downstream(list(node.users)[0])
167+
164168
build_rescale_conv_output(
165169
tosa_graph,
166170
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
@@ -169,6 +173,6 @@ def define_node(
169173
actual_out_type,
170174
input_scale,
171175
weight_scale,
172-
output_scale,
173-
output_zp,
176+
output_qargs.scale,
177+
output_qargs.zp,
174178
)

backends/arm/operators/op_exp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20-
get_quant_node_args,
2120
QuantArgs,
2221
quantize_value,
22+
search_quant_arg_downstream,
23+
search_quant_arg_upstream,
2324
)
2425
from serializer.tosa_serializer import TosaOp
2526
from torch.fx import Node
@@ -48,9 +49,9 @@ def define_node(
4849

4950
# Create attribute for 8 bit table lookup.
5051
input_node = node.all_input_nodes[0]
51-
in_quantargs = get_quant_node_args(input_node)
52+
in_quantargs = search_quant_arg_upstream(input_node)
5253
output_node = list(node.users)[0]
53-
out_quantargs = get_quant_node_args(output_node)
54+
out_quantargs = search_quant_arg_downstream(output_node)
5455

5556
table = exp_table_8bit(in_quantargs, out_quantargs)
5657
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_full.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
quantize_value,
19+
search_quant_arg_downstream,
20+
)
1821
from executorch.backends.arm.tosa_utils import tosa_shape
1922
from torch.fx import Node
2023

@@ -39,10 +42,8 @@ def define_node(
3942

4043
value = inputs[1].number
4144
if is_quant_node:
42-
qargs = get_quant_node_args(list(node.users)[0])
43-
qvalue = np.clip(
44-
np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
45-
)
45+
qargs = search_quant_arg_downstream(list(node.users)[0])
46+
qvalue = quantize_value(value, qargs)
4647
dtype = ts.DType.INT8
4748
data = np.full(shape, qvalue, dtype=np.int8)
4849
else:

backends/arm/operators/op_hardtanh.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
)
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616

17-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
quantize_value,
19+
search_quant_arg_upstream,
20+
)
1821
from serializer.tosa_serializer import TosaOp
1922

2023

@@ -37,12 +40,10 @@ def define_node(
3740

3841
if is_quant_node:
3942
# Get quant parameters
40-
scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
43+
qargs = search_quant_arg_upstream(node.all_input_nodes[0])
4144
# Convert to quantized representation
42-
clamp_min_qs = round((inputs[1].number / scale) + zp)
43-
clamp_min_qs = max(clamp_min_qs, qmin)
44-
clamp_max_qs = round((inputs[2].number / scale) + zp)
45-
clamp_max_qs = min(clamp_max_qs, qmax)
45+
clamp_min_qs = quantize_value(inputs[1].number, qargs)
46+
clamp_max_qs = quantize_value(inputs[2].number, qargs)
4647
# Set fp values to 0.0 since they are not used
4748
clamp_min_fp = 0.0
4849
clamp_max_fp = 0.0

backends/arm/operators/op_log.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
from executorch.backends.arm.tosa_quant_utils import (
1919
dequantize_value,
20-
get_quant_node_args,
2120
QuantArgs,
2221
quantize_value,
22+
search_quant_arg_downstream,
23+
search_quant_arg_upstream,
2324
)
2425
from serializer.tosa_serializer import TosaOp
2526
from torch.fx import Node
@@ -49,9 +50,9 @@ def define_node(
4950

5051
# Create attribute for 8 bit table lookup.
5152
input_node = node.all_input_nodes[0]
52-
in_quantargs = get_quant_node_args(input_node)
53+
in_quantargs = search_quant_arg_upstream(input_node)
5354
output_node = list(node.users)[0]
54-
out_quantargs = get_quant_node_args(output_node)
55+
out_quantargs = search_quant_arg_downstream(output_node)
5556

5657
table = log_table_8bit(in_quantargs, out_quantargs)
5758
table_attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_mm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
search_quant_arg_downstream,
20+
search_quant_arg_upstream,
21+
)
1822
from executorch.backends.arm.tosa_utils import (
1923
build_reshape,
2024
expand_dims,
@@ -54,8 +58,8 @@ def define_node(
5458
# For INT8, we need to get the zero point, otherwise it is 0
5559
input0_zp, input1_zp = 0, 0
5660
if is_quant_node:
57-
input0_zp = get_quant_node_args(input0).zp
58-
input1_zp = get_quant_node_args(input1).zp
61+
input0_zp = search_quant_arg_upstream(input0).zp
62+
input1_zp = search_quant_arg_upstream(input1).zp
5963

6064
mat_mul_result = tosa_graph.addIntermediate(
6165
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
@@ -86,9 +90,9 @@ def define_node(
8690

8791
# As INT8 accumulates into INT32, we need to rescale it back to INT8
8892
if is_quant_node:
89-
input0_q_params = get_quant_node_args(input0)
90-
input1_q_params = get_quant_node_args(input1)
91-
output_q_params = get_quant_node_args(list(node.users)[0])
93+
input0_q_params = search_quant_arg_upstream(input0)
94+
input1_q_params = search_quant_arg_upstream(input1)
95+
output_q_params = search_quant_arg_downstream(list(node.users)[0])
9296

9397
final_output_scale = (
9498
input0_q_params.scale * input1_q_params.scale

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def define_node(
3737
if is_quant_node:
3838
input_A = inputs[0]
3939
input_B = inputs[1]
40-
input_A_qargs = tqutils.get_quant_node_args(
40+
input_A_qargs = tqutils.search_quant_arg_upstream(
4141
cast(torch.fx.Node, node.args[0])
4242
)
43-
input_B_qargs = tqutils.get_quant_node_args(
43+
input_B_qargs = tqutils.search_quant_arg_upstream(
4444
cast(torch.fx.Node, node.args[1])
4545
)
4646

0 commit comments

Comments
 (0)