@@ -76,27 +76,45 @@ Tensor& add_out(
   int inp2_shape[kTensorDimensionLimit];
   int out_shape[kTensorDimensionLimit];
 
-  /* input shapes and output shapes */
-  for (auto i = 0; i < a_size.size(); i++) {
-    inp1_shape[i] = a_size[i];
-  }
-
-  for (auto i = 0; i < b_size.size(); i++) {
-    inp2_shape[i] = b_size[i];
-  }
-
-  for (auto i = 0; i < out_size.size(); i++) {
-    out_shape[i] = out_size[i];
-  }
-
   /* find broadcast */
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
 
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
+  max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = 1;
+
+  if ((a.dim() == 0) || (b.dim() == 0)) {
+    optimized = 0;
+  }
+
+  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
+    optimized = 0;
+  }
+
+  for (int i = 0; i < max_dim; i++) {
+    out_shape[i] = 1;
+    inp1_shape[i] = 1;
+    inp2_shape[i] = 1;
+  }
+
+  int offset_out = max_dim - out.dim();
+  int offset_inp1 = max_dim - a.dim();
+  int offset_inp2 = max_dim - b.dim();
+
+  for (int i = 0; i < out.dim(); i++) {
+    out_shape[i + offset_out] = out.size(i);
+  }
+  for (int i = 0; i < a.dim(); i++) {
+    inp1_shape[i + offset_inp1] = a.size(i);
+  }
+  for (int i = 0; i < b.dim(); i++) {
+    inp2_shape[i + offset_inp2] = b.size(i);
+  }
+
+  if ((compute_type == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
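For reference, the shape-alignment loops added above implement the standard broadcasting convention: every shape is padded on the left with 1s up to the common rank max_dim, so a lower-rank tensor's dimensions line up with the trailing dimensions of the higher-rank ones. A minimal self-contained sketch of the same idea follows; the helper name align_shape and the use of std::vector are illustrative, not part of the kernel.

#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the diff's logic: right-align `sizes`
// into a fixed-rank `shape` array, padding leading dimensions with 1
// so they broadcast against a higher-rank tensor.
static void align_shape(const std::vector<int>& sizes, int max_dim, int* shape) {
  for (int i = 0; i < max_dim; i++) {
    shape[i] = 1;  // default: broadcastable size-1 dimension
  }
  const int offset = max_dim - static_cast<int>(sizes.size());
  for (int i = 0; i < static_cast<int>(sizes.size()); i++) {
    shape[i + offset] = sizes[i];  // real sizes occupy the trailing slots
  }
}

int main() {
  int shape[4];
  align_shape({3, 5}, 4, shape);  // a rank-2 tensor aligned to rank 4
  std::printf("%d %d %d %d\n", shape[0], shape[1], shape[2], shape[3]);  // 1 1 3 5
  return 0;
}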
@@ -117,7 +135,7 @@ Tensor& add_out(
       xa_nn_elm_add_32x32_32(
           out_data, inp1_data, inp2_data, alpha_val, out.numel());
     }
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((compute_type == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
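Net effect of the change: both Xtensa fast paths (xa_nn_elm_add_32x32_32 and its float counterpart) are now gated on the `optimized` flag, which holds only when neither input is zero-dimensional and any broadcast stays within kTensorDimensionLimit; cases that fail those checks presumably fall through to the generic element-wise path that follows this hunk.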