@@ -246,6 +246,10 @@ def preprocess( # noqa: C901
     if path is None:
         path = tempfile.mkdtemp(prefix="arm_tosa_")
 
+    # Check ahead of time whether this is a quantized model so that the tensor
+    # data type of TOSA operations can be determined more easily during lowering.
+    is_quantized_model = tosa_quant_utils.isQuantizedModel(edge_program.graph)
+
     # Converted output for this subgraph, serializer needs path early as it emits
     # const data directly. Path created and data written only in debug builds.
     tosa_fb = ts.TosaSerializer(path)
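For context, a check like isQuantizedModel typically amounts to scanning the FX graph for quantize/dequantize ops. The sketch below is only an illustration of that idea, not the actual tosa_quant_utils implementation; the exact op set it matches on is an assumption.

# Illustrative sketch only, not part of the patch. Assumes the check simply
# looks for quantize/dequantize ops in the edge-dialect graph.
import torch
from executorch.exir.dialects._ops import ops as exir_ops


def is_quantized_model_sketch(graph: torch.fx.Graph) -> bool:
    quant_ops = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    }
    return any(
        node.op == "call_function" and node.target in quant_ops
        for node in graph.nodes
    )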
@@ -476,10 +480,15 @@ def preprocess( # noqa: C901
             elif exir_ops.edge.aten.convolution.default == node.target:
                 input, weight, bias, stride, pad, dilation, _, _, group = inputs
 
+                # Currently only int8 is supported in quantized types.
+                actual_out_type = (
+                    ts.DType.INT8 if is_quantized_model else outp.dtype
+                )
+
                 ## Transpose input tensor to NHWC_Order for TOSA
                 NHWC_Order = [0, 2, 3, 1]
                 input_transposed = transpose_helper(
-                    tosa_fb, input, NHWC_Order, outp.dtype
+                    tosa_fb, input, NHWC_Order, actual_out_type
                 )
 
                 ## CONV2DOp
@@ -523,14 +532,17 @@ def preprocess( # noqa: C901
                 # Transpose weight to [OC, H, W, IC]
                 weight_CHWC_Order = [0, 2, 3, 1]
                 weight_transposed = transpose_helper(
-                    tosa_fb, weight, weight_CHWC_Order, outp.dtype
+                    tosa_fb, weight, weight_CHWC_Order, actual_out_type
                 )
 
                 ## TOSA output shape is [NHWO]
                 NHWO_Order = [0, 2, 3, 1]
                 out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
+
+                # The output type is int32 when the input type is int8.
                 conv2d_res = tosa_fb.addIntermediate(
-                    out_shape_TOSA_CONV2D, outp.dtype
+                    out_shape_TOSA_CONV2D,
+                    ts.DType.INT32 if is_quant_node else outp.dtype,
                 )
                 tosa_fb.addOperator(
                     TosaOp.Op().CONV2D,
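Why the intermediate result is int32: each int8 x int8 product fits in 16 bits, but summing those products across the kernel window and the input channels can exceed the int16 range, so TOSA CONV2D accumulates int8 operands into an int32 accumulator. A rough worst-case bound, using a hypothetical 3x3 kernel over 64 input channels:

# Worst-case accumulator magnitude for an int8 x int8 convolution (bias ignored);
# the 3x3 kernel and 64 input channels are illustrative assumptions.
kernel_h, kernel_w, in_channels = 3, 3, 64
worst_case = 128 * 128 * kernel_h * kernel_w * in_channels
print(worst_case)          # 9437184, about 2**23.2 -> far beyond the int16 range
print(worst_case < 2**31)  # True -> fits comfortably in an int32 accumulator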
@@ -547,12 +559,45 @@ def preprocess( # noqa: C901
                 NOHW_Order = [0, 3, 1, 2]
                 attr_output_transpose = ts.TosaSerializerAttribute()
                 attr_output_transpose.TransposeAttribute(NOHW_Order)
-                tosa_fb.addOperator(
-                    TosaOp.Op().TRANSPOSE,
-                    [conv2d_res.name],
-                    [outp.name],
-                    attr_output_transpose,
-                )
+
+                if len(node.all_input_nodes) == 3:
+                    input_node, weight_node, bias_node = node.all_input_nodes
+                else:
+                    raise AssertionError(
+                        "non-biased conv2d is not supported for now"
+                    )
+
+                output_node = list(node.users)[0]
+
+                # For quantized convolution, rescale the output value back into the
+                # integer value domain of the next op. Otherwise return a float32 output.
+                if is_quant_node:
+                    # Get scale_factor from input, weight, and output.
+                    _, input_scale, _, _, _, _ = getNodeArgs(input_node)
+                    _, weight_scale, _, _, _, _ = getNodeArgs(weight_node)
+                    _, output_scale, _, _, _, _ = getNodeArgs(output_node)
+                    rescaled_conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
+                        tosa_fb,
+                        conv2d_res,
+                        actual_out_type,
+                        input_scale,
+                        weight_scale,
+                        output_scale,
+                    )
+                    tosa_fb.addOperator(
+                        TosaOp.Op().TRANSPOSE,
+                        [rescaled_conv2d_res.name],
+                        [outp.name],
+                        attr_output_transpose,
+                    )
+                else:
+                    tosa_fb.addOperator(
+                        TosaOp.Op().TRANSPOSE,
+                        [conv2d_res.name],
+                        [outp.name],
+                        attr_output_transpose,
+                    )
+
             elif exir_ops.edge.aten.div.Tensor == node.target:
                 # Div is implemented as x/y = x*1/y
                 recip = tosa_fb.addIntermediate(inputs[1].shape, inputs[1].dtype)
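The rescale step relies on the int32 accumulator carrying an effective scale of input_scale * weight_scale, so bringing it into the output's int8 domain means multiplying by input_scale * weight_scale / output_scale and adding the output zero point. The snippet below is a floating-point reference of that arithmetic, not the actual buildRescaleOpConvOutput helper, which presumably emits a TOSA RESCALE with an integer multiplier and shift; the output_zp parameter here is an assumption.

import numpy as np


def rescale_conv_output_reference(
    acc_int32: np.ndarray,  # int32 accumulator produced by CONV2D
    input_scale: float,
    weight_scale: float,
    output_scale: float,
    output_zp: int = 0,  # assumed output zero point
) -> np.ndarray:
    # Fold the three scales into one multiplier and clamp to the int8 range.
    multiplier = (input_scale * weight_scale) / output_scale
    rescaled = np.round(acc_int32.astype(np.float64) * multiplier) + output_zp
    return np.clip(rescaled, -128, 127).astype(np.int8)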
@@ -802,7 +847,7 @@ def preprocess( # noqa: C901
                 p_data = edge_program.state_dict[parameter_name]
 
                 assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
-                weight_values = p_data.detach().numpy()
+                ph_values = p_data.detach().numpy()
 
                 # Check if they're for quantized nodes
                 consumer_node = list(node.users)[0]
@@ -811,14 +856,14 @@ def preprocess( # noqa: C901
                         consumer_node
                     )
 
-                    weight_values_quantized = (
-                        (weight_values / weight_node_scale.number)
+                    ph_values_quantized = (
+                        (ph_values / weight_node_scale.number)
                         + weight_node_zp.number
                     ).astype(np.int8)
                     tosa_fb.addConst(
                         inputs[0].shape,
                         ts.DType.INT8,
-                        weight_values_quantized,
+                        ph_values_quantized,
                         name=out,
                     )
                 elif (
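For reference, the int8 path above is the usual affine quantization q = x / scale + zero_point. A rounded and clamped numpy version of the same formula looks like the following sketch; the names are not from the patch, and the patch itself casts with astype rather than rounding.

import numpy as np


def quantize_to_int8_sketch(values: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # Affine quantization reference: divide by the scale, shift by the zero
    # point, then clamp to the signed 8-bit range.
    q = np.round(values / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)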
@@ -837,30 +882,53 @@ def preprocess( # noqa: C901
                         weight_node
                     )
 
-                    weight_values_quantized = (
-                        weight_values / (input_node_scale * weight_node_scale)
+                    ph_values_quantized = (
+                        ph_values / (input_node_scale * weight_node_scale)
                     ).astype(np.int32)
 
                     tosa_fb.addConst(
                         inputs[0].shape,
                         ts.DType.INT32,
-                        weight_values_quantized,
+                        ph_values_quantized,
+                        name=out,
+                    )
+                elif (
+                    consumer_node.target == exir_ops.edge.aten.convolution.default
+                    and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
+                ):
+                    (
+                        input_node,
+                        weight_node,
+                        bias_node,
+                    ) = consumer_node.all_input_nodes
+
+                    input_node_scale, _ = getQuantNodeArgs(input_node)
+                    weight_node_scale, _ = getQuantNodeArgs(weight_node)
+
+                    bias_scales = input_node_scale * weight_node_scale
+                    ph_values_quantized = (ph_values / bias_scales).astype(np.int32)
+
+                    tosa_fb.addConst(
+                        inputs[0].shape,
+                        ts.DType.INT32,
+                        ph_values_quantized,
                         name=out,
                     )
                 else:
                     tosa_fb.addConst(
-                        inputs[0].shape, inputs[0].dtype, weight_values, name=out
+                        inputs[0].shape, inputs[0].dtype, ph_values, name=out
                     )
+
             elif out in edge_program.graph_signature.inputs_to_buffers:
                 parameter_name = edge_program.graph_signature.inputs_to_buffers[
                     node.name
                 ]
                 p_data = edge_program.state_dict[parameter_name]
 
                 assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
-                weight_values = p_data.detach().numpy()
+                ph_values = p_data.detach().numpy()
                 tosa_fb.addConst(
-                    inputs[0].shape, inputs[0].dtype, weight_values, name=out
+                    inputs[0].shape, inputs[0].dtype, ph_values, name=out
                 )
             else:
                 tensor = ts.TosaSerializerTensor(
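The new convolution branch quantizes the bias the same way the addmm branch does: since the int32 accumulator already has an effective scale of input_scale * weight_scale, the float bias must be expressed in that same scale before it can be added to the accumulator. A minimal reference of that relationship (sketch only; the patch casts directly with astype rather than rounding):

import numpy as np


def quantize_conv_bias_sketch(
    bias_fp32: np.ndarray, input_scale: float, weight_scale: float
) -> np.ndarray:
    # The bias shares the accumulator's scale: input_scale * weight_scale.
    bias_scale = input_scale * weight_scale
    return np.round(bias_fp32 / bias_scale).astype(np.int32)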