Skip to content

Commit 116f737

Browse files
committed
Add initial support for lowering aten.convolution to tosa.conv2d
1 parent 94119f6 commit 116f737

File tree

5 files changed

+291
-48
lines changed

5 files changed

+291
-48
lines changed

backends/arm/arm_backend.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -476,23 +476,39 @@ def preprocess( # noqa: C901
476476
elif exir_ops.edge.aten.convolution.default == node.target:
477477
input, weight, bias, stride, pad, dilation, _, _, group = inputs
478478

479+
# Currently only int8 is supported for quantized types.
480+
actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype
481+
479482
## Transpose input tensor to NHWC_Order for TOSA
480483
NHWC_Order = [0, 2, 3, 1]
481484
input_transposed = transpose_helper(
482-
tosa_fb, input, NHWC_Order, outp.dtype
485+
tosa_fb, input, NHWC_Order, actual_out_type
483486
)
484487

485-
## CONV2DOp
488+
# Get the attributes of convolution.
486489
attr = ts.TosaSerializerAttribute()
487-
# PAD
488490
pad_attr = [val for val in pad.special for _ in (0, 1)]
489-
# Stride
490491
stride_attr = stride.special
491-
# Dilation
492492
dilation_attr = dilation.special
493493
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
494494

495+
# Non-bias case.
496+
if len(node.all_input_nodes) == 2:
497+
# Create a zero bias tensor if one is not present
498+
out_channels = weight.shape[0]
499+
bias_name = "bias" + node.name.split("default", 1)[1]
500+
bias = tosa_fb.addConst(
501+
[out_channels],
502+
ts.DType.INT32 if is_quant_node else outp.dtype,
503+
[0] * out_channels,
504+
name=bias_name,
505+
)
506+
495507
if group.number > 1:
508+
assert (
509+
is_quant_node is False
510+
), "quantized depthwise convolution is not supported yet in BI mode"
511+
496512
# Transpose weight to [KH, KW, C, M]
497513
weight_HWCM_Order = [2, 3, 0, 1]
498514
weight_transposed = transpose_helper(
@@ -523,14 +539,17 @@ def preprocess( # noqa: C901
523539
# Transpose weight to [OC, H, W, IC]
524540
weight_CHWC_Order = [0, 2, 3, 1]
525541
weight_transposed = transpose_helper(
526-
tosa_fb, weight, weight_CHWC_Order, outp.dtype
542+
tosa_fb, weight, weight_CHWC_Order, actual_out_type
527543
)
528544

529545
## TOSA output shape is [NHWO]
530546
NHWO_Order = [0, 2, 3, 1]
531547
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
548+
549+
# The output type is int32 when the input type is int8.
532550
conv2d_res = tosa_fb.addIntermediate(
533-
out_shape_TOSA_CONV2D, outp.dtype
551+
out_shape_TOSA_CONV2D,
552+
ts.DType.INT32 if is_quant_node else outp.dtype,
534553
)
535554
tosa_fb.addOperator(
536555
TosaOp.Op().CONV2D,
@@ -547,6 +566,24 @@ def preprocess( # noqa: C901
547566
NOHW_Order = [0, 3, 1, 2]
548567
attr_output_transpose = ts.TosaSerializerAttribute()
549568
attr_output_transpose.TransposeAttribute(NOHW_Order)
569+
570+
# For quantized convolution, rescale the output value back to the same
571+
# integer value domain of the next op. Otherwise return float32 output.
572+
if is_quant_node:
573+
# Get scale_factor from input, weight, and output.
574+
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
575+
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
576+
_, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0])
577+
578+
conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
579+
tosa_fb,
580+
conv2d_res,
581+
actual_out_type,
582+
input_scale,
583+
weight_scale,
584+
output_scale,
585+
)
586+
550587
tosa_fb.addOperator(
551588
TosaOp.Op().TRANSPOSE,
552589
[conv2d_res.name],
@@ -802,7 +839,7 @@ def preprocess( # noqa: C901
802839
p_data = edge_program.state_dict[parameter_name]
803840

804841
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
805-
weight_values = p_data.detach().numpy()
842+
parameter_values = p_data.detach().numpy()
806843

807844
# Check if they're for quantized nodes
808845
consumer_node = list(node.users)[0]
@@ -811,14 +848,14 @@ def preprocess( # noqa: C901
811848
consumer_node
812849
)
813850

814-
weight_values_quantized = (
815-
(weight_values / weight_node_scale.number)
851+
parameter_values_quantized = (
852+
(parameter_values / weight_node_scale.number)
816853
+ weight_node_zp.number
817854
).astype(np.int8)
818855
tosa_fb.addConst(
819856
inputs[0].shape,
820857
ts.DType.INT8,
821-
weight_values_quantized,
858+
parameter_values_quantized,
822859
name=out,
823860
)
824861
elif (
@@ -837,30 +874,55 @@ def preprocess( # noqa: C901
837874
weight_node
838875
)
839876

840-
weight_values_quantized = (
841-
weight_values / (input_node_scale * weight_node_scale)
877+
parameter_values_quantized = (
878+
parameter_values / (input_node_scale * weight_node_scale)
879+
).astype(np.int32)
880+
881+
tosa_fb.addConst(
882+
inputs[0].shape,
883+
ts.DType.INT32,
884+
parameter_values_quantized,
885+
name=out,
886+
)
887+
elif (
888+
consumer_node.target == exir_ops.edge.aten.convolution.default
889+
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
890+
):
891+
(
892+
input_node,
893+
weight_node,
894+
bias_node,
895+
) = consumer_node.all_input_nodes
896+
897+
input_node_scale, _ = getQuantNodeArgs(input_node)
898+
weight_node_scale, _ = getQuantNodeArgs(weight_node)
899+
900+
bias_scales = input_node_scale * weight_node_scale
901+
parameter_values_quantized = (
902+
parameter_values / bias_scales
842903
).astype(np.int32)
843904

844905
tosa_fb.addConst(
845906
inputs[0].shape,
846907
ts.DType.INT32,
847-
weight_values_quantized,
908+
parameter_values_quantized,
848909
name=out,
849910
)
850911
else:
851912
tosa_fb.addConst(
852-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
913+
inputs[0].shape, inputs[0].dtype, parameter_values, name=out
853914
)
915+
854916
elif out in edge_program.graph_signature.inputs_to_buffers:
855917
parameter_name = edge_program.graph_signature.inputs_to_buffers[
856918
node.name
857919
]
858920
p_data = edge_program.state_dict[parameter_name]
859921

860922
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
861-
weight_values = p_data.detach().numpy()
923+
buffer_values = p_data.detach().numpy()
862924
tosa_fb.addConst(
863-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
925+
inputs[0].shape, inputs[0].dtype, buffer_values, name=out
864926
)
865927
else:
866928
tensor = ts.TosaSerializerTensor(

0 commit comments

Comments
 (0)