Skip to content

Commit f1a3a1e

Browse files
committed
Add initial lowering of aten.convolution to tosa.conv2d support
1 parent 17fee78 commit f1a3a1e

File tree

4 files changed

+269
-25
lines changed

4 files changed

+269
-25
lines changed

backends/arm/arm_backend.py

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def preprocess( # noqa: C901
246246
if path is None:
247247
path = tempfile.mkdtemp(prefix="arm_tosa_")
248248

249+
# Verify if this is a quantized model ahead so that the tensor data type of
250+
# tosa operations during lowering can be easier determined.
251+
is_quantized_model = tosa_quant_utils.isQuantizedModel(edge_program.graph)
252+
249253
# Converted output for this subgraph, serializer needs path early as it emits
250254
# const data directly. Path created and data written only in debug builds.
251255
tosa_fb = ts.TosaSerializer(path)
@@ -476,10 +480,15 @@ def preprocess( # noqa: C901
476480
elif exir_ops.edge.aten.convolution.default == node.target:
477481
input, weight, bias, stride, pad, dilation, _, _, group = inputs
478482

483+
# Currently only int8 is supported in quantized types.
484+
actual_out_type = (
485+
ts.DType.INT8 if is_quantized_model else outp.dtype
486+
)
487+
479488
## Transpose input tensor to NHWC_Order for TOSA
480489
NHWC_Order = [0, 2, 3, 1]
481490
input_transposed = transpose_helper(
482-
tosa_fb, input, NHWC_Order, outp.dtype
491+
tosa_fb, input, NHWC_Order, actual_out_type
483492
)
484493

485494
## CONV2DOp
@@ -523,14 +532,17 @@ def preprocess( # noqa: C901
523532
# Transpose weight to [OC, H, W, IC]
524533
weight_CHWC_Order = [0, 2, 3, 1]
525534
weight_transposed = transpose_helper(
526-
tosa_fb, weight, weight_CHWC_Order, outp.dtype
535+
tosa_fb, weight, weight_CHWC_Order, actual_out_type
527536
)
528537

529538
## TOSA output shape is [NHWO]
530539
NHWO_Order = [0, 2, 3, 1]
531540
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
541+
542+
# The output type is int32 when input type is int8.
532543
conv2d_res = tosa_fb.addIntermediate(
533-
out_shape_TOSA_CONV2D, outp.dtype
544+
out_shape_TOSA_CONV2D,
545+
ts.DType.INT32 if is_quant_node else outp.dtype,
534546
)
535547
tosa_fb.addOperator(
536548
TosaOp.Op().CONV2D,
@@ -547,12 +559,45 @@ def preprocess( # noqa: C901
547559
NOHW_Order = [0, 3, 1, 2]
548560
attr_output_transpose = ts.TosaSerializerAttribute()
549561
attr_output_transpose.TransposeAttribute(NOHW_Order)
550-
tosa_fb.addOperator(
551-
TosaOp.Op().TRANSPOSE,
552-
[conv2d_res.name],
553-
[outp.name],
554-
attr_output_transpose,
555-
)
562+
563+
if len(node.all_input_nodes) == 3:
564+
input_node, weight_node, bias_node = node.all_input_nodes
565+
else:
566+
raise AssertionError(
567+
"non-biased conv2d is not supported for now"
568+
)
569+
570+
output_node = list(node.users)[0]
571+
572+
# For quantized convolution, rescale the output value back to the same
573+
# integer value domain of the next op. Otherwise return float32 output.
574+
if is_quant_node:
575+
# Get scale_factor from input, weight, and output.
576+
_, input_scale, _, _, _, _ = getNodeArgs(input_node)
577+
_, weight_scale, _, _, _, _ = getNodeArgs(weight_node)
578+
_, output_scale, _, _, _, _ = getNodeArgs(output_node)
579+
rescaled_conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
580+
tosa_fb,
581+
conv2d_res,
582+
actual_out_type,
583+
input_scale,
584+
weight_scale,
585+
output_scale,
586+
)
587+
tosa_fb.addOperator(
588+
TosaOp.Op().TRANSPOSE,
589+
[rescaled_conv2d_res.name],
590+
[outp.name],
591+
attr_output_transpose,
592+
)
593+
else:
594+
tosa_fb.addOperator(
595+
TosaOp.Op().TRANSPOSE,
596+
[conv2d_res.name],
597+
[outp.name],
598+
attr_output_transpose,
599+
)
600+
556601
elif exir_ops.edge.aten.div.Tensor == node.target:
557602
# Div is implemented as x/y = x*1/y
558603
recip = tosa_fb.addIntermediate(inputs[1].shape, inputs[1].dtype)
@@ -802,7 +847,7 @@ def preprocess( # noqa: C901
802847
p_data = edge_program.state_dict[parameter_name]
803848

804849
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
805-
weight_values = p_data.detach().numpy()
850+
ph_values = p_data.detach().numpy()
806851

807852
# Check if they're for quantized nodes
808853
consumer_node = list(node.users)[0]
@@ -811,14 +856,14 @@ def preprocess( # noqa: C901
811856
consumer_node
812857
)
813858

814-
weight_values_quantized = (
815-
(weight_values / weight_node_scale.number)
859+
ph_values_quantized = (
860+
(ph_values / weight_node_scale.number)
816861
+ weight_node_zp.number
817862
).astype(np.int8)
818863
tosa_fb.addConst(
819864
inputs[0].shape,
820865
ts.DType.INT8,
821-
weight_values_quantized,
866+
ph_values_quantized,
822867
name=out,
823868
)
824869
elif (
@@ -837,30 +882,53 @@ def preprocess( # noqa: C901
837882
weight_node
838883
)
839884

840-
weight_values_quantized = (
841-
weight_values / (input_node_scale * weight_node_scale)
885+
ph_values_quantized = (
886+
ph_values / (input_node_scale * weight_node_scale)
842887
).astype(np.int32)
843888

844889
tosa_fb.addConst(
845890
inputs[0].shape,
846891
ts.DType.INT32,
847-
weight_values_quantized,
892+
ph_values_quantized,
893+
name=out,
894+
)
895+
elif (
896+
consumer_node.target == exir_ops.edge.aten.convolution.default
897+
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
898+
):
899+
(
900+
input_node,
901+
weight_node,
902+
bias_node,
903+
) = consumer_node.all_input_nodes
904+
905+
input_node_scale, _ = getQuantNodeArgs(input_node)
906+
weight_node_scale, _ = getQuantNodeArgs(weight_node)
907+
908+
bias_scales = input_node_scale * weight_node_scale
909+
ph_values_quantized = (ph_values / bias_scales).astype(np.int32)
910+
911+
tosa_fb.addConst(
912+
inputs[0].shape,
913+
ts.DType.INT32,
914+
ph_values_quantized,
848915
name=out,
849916
)
850917
else:
851918
tosa_fb.addConst(
852-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
919+
inputs[0].shape, inputs[0].dtype, ph_values, name=out
853920
)
921+
854922
elif out in edge_program.graph_signature.inputs_to_buffers:
855923
parameter_name = edge_program.graph_signature.inputs_to_buffers[
856924
node.name
857925
]
858926
p_data = edge_program.state_dict[parameter_name]
859927

860928
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
861-
weight_values = p_data.detach().numpy()
929+
ph_values = p_data.detach().numpy()
862930
tosa_fb.addConst(
863-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
931+
inputs[0].shape, inputs[0].dtype, ph_values, name=out
864932
)
865933
else:
866934
tensor = ts.TosaSerializerTensor(

backends/arm/test/test_models.py

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99

1010
from enum import Enum
1111

12+
import numpy as np
13+
1214
import torch
1315

1416
TestList = {}
1517

18+
# Seed the RNG a convenient number so that we get the same random tests for each test each time
19+
seed = 42
20+
rng = np.random.default_rng(seed)
21+
1622

1723
def register_test(cls):
1824
TestList[cls.__name__] = cls()
@@ -103,15 +109,19 @@ class simple_linear(torch.nn.Module):
103109

104110
def __init__(self):
105111
super().__init__()
106-
torch.manual_seed(42)
112+
torch.manual_seed(seed)
107113
self.fc = torch.nn.Linear(20, 30)
108114

109115
def forward(self, x):
110116
x = self.fc(x)
111117
return x
112118

119+
"""Currenly we compare the quantized result directly with the floating point result, to avoid a noticable
120+
precision difference due to wide random numerical distribution, generate small random value range for
121+
convolution testing instead for now"""
122+
113123
@register_test
114-
class simple_conv2d(torch.nn.Module):
124+
class simple_conv2d_3x3_1x3x256x256_st1(torch.nn.Module):
115125
inputs = {
116126
TosaProfile.BI: (
117127
torch.ones(
@@ -129,6 +139,115 @@ def __init__(self):
129139
self.conv2d = torch.nn.Conv2d(
130140
in_channels=3, out_channels=10, kernel_size=3, stride=1
131141
)
142+
with torch.no_grad():
143+
self.conv2d.weight.copy_(
144+
torch.from_numpy(
145+
np.float32(rng.integers(low=1, high=4, size=(10, 3, 3, 3)))
146+
)
147+
)
148+
self.conv2d.bias.copy_(
149+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(10))))
150+
)
151+
152+
def forward(self, x):
153+
x = self.conv2d(x)
154+
return x
155+
156+
@register_test
157+
class simple_conv2d_1x1_1x2x128x128_st1(torch.nn.Module):
158+
inputs = {
159+
TosaProfile.BI: (
160+
torch.from_numpy(
161+
np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128)))
162+
),
163+
),
164+
TosaProfile.MI: (
165+
torch.from_numpy(
166+
np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128)))
167+
),
168+
),
169+
}
170+
171+
def __init__(self):
172+
super().__init__()
173+
self.conv2d = torch.nn.Conv2d(
174+
in_channels=2, out_channels=1, kernel_size=1, stride=1
175+
)
176+
with torch.no_grad():
177+
self.conv2d.weight.copy_(
178+
torch.from_numpy(
179+
np.float32(rng.integers(low=1, high=4, size=(1, 2, 1, 1)))
180+
)
181+
)
182+
self.conv2d.bias.copy_(
183+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
184+
)
185+
186+
def forward(self, x):
187+
x = self.conv2d(x)
188+
return x
189+
190+
@register_test
191+
class simple_conv2d_2x2_1x1x14x14_st2(torch.nn.Module):
192+
inputs = {
193+
TosaProfile.BI: (
194+
torch.from_numpy(
195+
np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14)))
196+
),
197+
),
198+
TosaProfile.MI: (
199+
torch.from_numpy(
200+
np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14)))
201+
),
202+
),
203+
}
204+
205+
def __init__(self):
206+
super().__init__()
207+
self.conv2d = torch.nn.Conv2d(
208+
in_channels=1, out_channels=1, kernel_size=2, stride=2
209+
)
210+
with torch.no_grad():
211+
self.conv2d.weight.copy_(
212+
torch.from_numpy(
213+
np.float32(rng.integers(low=1, high=4, size=(1, 1, 2, 2)))
214+
)
215+
)
216+
self.conv2d.bias.copy_(
217+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
218+
)
219+
220+
def forward(self, x):
221+
x = self.conv2d(x)
222+
return x
223+
224+
@register_test
225+
class simple_conv2d_5x5_3x2x128x128_st1(torch.nn.Module):
226+
inputs = {
227+
TosaProfile.BI: (
228+
torch.from_numpy(
229+
np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128)))
230+
),
231+
),
232+
TosaProfile.MI: (
233+
torch.from_numpy(
234+
np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128)))
235+
),
236+
),
237+
}
238+
239+
def __init__(self):
240+
super().__init__()
241+
self.conv2d = torch.nn.Conv2d(
242+
in_channels=2, out_channels=3, kernel_size=5, stride=1
243+
)
244+
with torch.no_grad():
245+
self.conv2d.weight.copy_(
246+
torch.from_numpy(
247+
np.float32(rng.integers(low=1, high=10, size=(1, 1, 5, 5)))
248+
)
249+
)
250+
self.conv2d.bias.copy_(torch.ones(3, dtype=torch.float))
132251

133252
def forward(self, x):
134253
x = self.conv2d(x)
@@ -137,8 +256,16 @@ def forward(self, x):
137256
@register_test
138257
class block_two_conv2d(torch.nn.Module):
139258
inputs = {
140-
TosaProfile.BI: (torch.ones(1, 3, 256, 256),),
141-
TosaProfile.MI: (torch.ones(1, 3, 256, 256),),
259+
TosaProfile.BI: (
260+
torch.from_numpy(
261+
np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256)))
262+
),
263+
),
264+
TosaProfile.MI: (
265+
torch.from_numpy(
266+
np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256)))
267+
),
268+
),
142269
}
143270

144271
def __init__(self):
@@ -149,6 +276,11 @@ def __init__(self):
149276
self.conv2d_2 = torch.nn.Conv2d(
150277
in_channels=10, out_channels=15, kernel_size=5, stride=1
151278
)
279+
with torch.no_grad():
280+
self.conv2d.weight.copy_(torch.ones(10, 3, 5, 5, dtype=torch.float))
281+
self.conv2d.bias.copy_(torch.ones(10))
282+
self.conv2d_2.weight.copy_(torch.ones(15, 10, 5, 5, dtype=torch.float))
283+
self.conv2d_2.bias.copy_(torch.ones(15))
152284

153285
def forward(self, x):
154286
x = self.conv2d(x)

0 commit comments

Comments
 (0)