@@ -476,23 +476,39 @@ def preprocess( # noqa: C901
476
476
elif exir_ops .edge .aten .convolution .default == node .target :
477
477
input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
478
478
479
+ # Currently, int8 is the only supported quantized type.
480
+ actual_out_type = ts .DType .INT8 if is_quant_node else outp .dtype
481
+
479
482
## Transpose input tensor to NHWC_Order for TOSA
480
483
NHWC_Order = [0 , 2 , 3 , 1 ]
481
484
input_transposed = transpose_helper (
482
- tosa_fb , input , NHWC_Order , outp . dtype
485
+ tosa_fb , input , NHWC_Order , actual_out_type
483
486
)
484
487
485
- ## CONV2DOp
488
+ # Get the attributes of convolution.
486
489
attr = ts .TosaSerializerAttribute ()
487
- # PAD
488
490
pad_attr = [val for val in pad .special for _ in (0 , 1 )]
489
- # Stride
490
491
stride_attr = stride .special
491
- # Dilation
492
492
dilation_attr = dilation .special
493
493
attr .ConvAttribute (pad_attr , stride_attr , dilation_attr , 0 , 0 )
494
494
495
+ # Non-bias case.
496
+ if len (node .all_input_nodes ) == 2 :
497
+ # Create a zero bias tensor if one is not provided
498
+ out_channels = weight .shape [0 ]
499
+ bias_name = "bias" + node .name .split ("default" , 1 )[1 ]
500
+ bias = tosa_fb .addConst (
501
+ [out_channels ],
502
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
503
+ [0 ] * out_channels ,
504
+ name = bias_name ,
505
+ )
506
+
495
507
if group .number > 1 :
508
+ assert (
509
+ is_quant_node is False
510
+ ), "quantized depthwise convolution is not supported yet in BI mode"
511
+
496
512
# Transpose weight to [KH, KW, C, M]
497
513
weight_HWCM_Order = [2 , 3 , 0 , 1 ]
498
514
weight_transposed = transpose_helper (
@@ -523,14 +539,17 @@ def preprocess( # noqa: C901
523
539
# Transpose weight to [OC, H, W, IC]
524
540
weight_CHWC_Order = [0 , 2 , 3 , 1 ]
525
541
weight_transposed = transpose_helper (
526
- tosa_fb , weight , weight_CHWC_Order , outp . dtype
542
+ tosa_fb , weight , weight_CHWC_Order , actual_out_type
527
543
)
528
544
529
545
## TOSA output shape is [NHWO]
530
546
NHWO_Order = [0 , 2 , 3 , 1 ]
531
547
out_shape_TOSA_CONV2D = [outp .shape [i ] for i in NHWO_Order ]
548
+
549
+ # The output type is int32 when the input type is int8.
532
550
conv2d_res = tosa_fb .addIntermediate (
533
- out_shape_TOSA_CONV2D , outp .dtype
551
+ out_shape_TOSA_CONV2D ,
552
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
534
553
)
535
554
tosa_fb .addOperator (
536
555
TosaOp .Op ().CONV2D ,
@@ -547,6 +566,24 @@ def preprocess( # noqa: C901
547
566
NOHW_Order = [0 , 3 , 1 , 2 ]
548
567
attr_output_transpose = ts .TosaSerializerAttribute ()
549
568
attr_output_transpose .TransposeAttribute (NOHW_Order )
569
+
570
+ # For quantized convolution, rescale the output value back to the same
571
+ # integer value domain of the next op. Otherwise return float32 output.
572
+ if is_quant_node :
573
+ # Get scale_factor from input, weight, and output.
574
+ _ , input_scale , _ , _ , _ , _ = getNodeArgs (node .args [0 ])
575
+ _ , weight_scale , _ , _ , _ , _ = getNodeArgs (node .args [1 ])
576
+ _ , output_scale , _ , _ , _ , _ = getNodeArgs (list (node .users )[0 ])
577
+
578
+ conv2d_res = tosa_quant_utils .buildRescaleOpConvOutput (
579
+ tosa_fb ,
580
+ conv2d_res ,
581
+ actual_out_type ,
582
+ input_scale ,
583
+ weight_scale ,
584
+ output_scale ,
585
+ )
586
+
550
587
tosa_fb .addOperator (
551
588
TosaOp .Op ().TRANSPOSE ,
552
589
[conv2d_res .name ],
@@ -802,7 +839,7 @@ def preprocess( # noqa: C901
802
839
p_data = edge_program .state_dict [parameter_name ]
803
840
804
841
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
805
- weight_values = p_data .detach ().numpy ()
842
+ parameter_values = p_data .detach ().numpy ()
806
843
807
844
# Check if they're for quantized nodes
808
845
consumer_node = list (node .users )[0 ]
@@ -811,14 +848,14 @@ def preprocess( # noqa: C901
811
848
consumer_node
812
849
)
813
850
814
- weight_values_quantized = (
815
- (weight_values / weight_node_scale .number )
851
+ parameter_values_quantized = (
852
+ (parameter_values / weight_node_scale .number )
816
853
+ weight_node_zp .number
817
854
).astype (np .int8 )
818
855
tosa_fb .addConst (
819
856
inputs [0 ].shape ,
820
857
ts .DType .INT8 ,
821
- weight_values_quantized ,
858
+ parameter_values_quantized ,
822
859
name = out ,
823
860
)
824
861
elif (
@@ -837,30 +874,55 @@ def preprocess( # noqa: C901
837
874
weight_node
838
875
)
839
876
840
- weight_values_quantized = (
841
- weight_values / (input_node_scale * weight_node_scale )
877
+ parameter_values_quantized = (
878
+ parameter_values / (input_node_scale * weight_node_scale )
879
+ ).astype (np .int32 )
880
+
881
+ tosa_fb .addConst (
882
+ inputs [0 ].shape ,
883
+ ts .DType .INT32 ,
884
+ parameter_values_quantized ,
885
+ name = out ,
886
+ )
887
+ elif (
888
+ consumer_node .target == exir_ops .edge .aten .convolution .default
889
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
890
+ ):
891
+ (
892
+ input_node ,
893
+ weight_node ,
894
+ bias_node ,
895
+ ) = consumer_node .all_input_nodes
896
+
897
+ input_node_scale , _ = getQuantNodeArgs (input_node )
898
+ weight_node_scale , _ = getQuantNodeArgs (weight_node )
899
+
900
+ bias_scales = input_node_scale * weight_node_scale
901
+ parameter_values_quantized = (
902
+ parameter_values / bias_scales
842
903
).astype (np .int32 )
843
904
844
905
tosa_fb .addConst (
845
906
inputs [0 ].shape ,
846
907
ts .DType .INT32 ,
847
- weight_values_quantized ,
908
+ parameter_values_quantized ,
848
909
name = out ,
849
910
)
850
911
else :
851
912
tosa_fb .addConst (
852
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
913
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
853
914
)
915
+
854
916
elif out in edge_program .graph_signature .inputs_to_buffers :
855
917
parameter_name = edge_program .graph_signature .inputs_to_buffers [
856
918
node .name
857
919
]
858
920
p_data = edge_program .state_dict [parameter_name ]
859
921
860
922
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
861
- weight_values = p_data .detach ().numpy ()
923
+ buffer_values = p_data .detach ().numpy ()
862
924
tosa_fb .addConst (
863
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
925
+ inputs [0 ].shape , inputs [0 ].dtype , buffer_values , name = out
864
926
)
865
927
else :
866
928
tensor = ts .TosaSerializerTensor (
0 commit comments