@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
3838 std::vector<int >& dilations) {
3939 bool filter_1 = true , strides_1 = true , padding_0 = true , dilation_1 = true ;
4040 for (size_t j = 0 ; j < strides.size (); ++j) {
41- filter_1 = filter_1 && (static_cast <int >(filter_dim[j]) == 1 );
41+ filter_1 = filter_1 && (static_cast <int >(filter_dim[j + 2 ]) == 1 );
4242 strides_1 = strides_1 && (strides[j] == 1 );
4343 padding_0 = padding_0 && (paddings[j] == 0 );
4444 dilation_1 = dilation_1 && (dilations[j] == 1 );
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
9191
9292 const int batch_size = static_cast <int >(input->dims ()[0 ]);
9393
94- // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
94+ // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
9595 std::vector<int64_t > filter_shape_vec (framework::vectorize (filter.dims ()));
96- filter_shape_vec.erase (filter_shape_vec.begin (),
97- filter_shape_vec.begin () + 2 );
98-
99- // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
96+ // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
10097 std::vector<int64_t > output_shape_vec (framework::vectorize (output->dims ()));
101- output_shape_vec.erase (output_shape_vec.begin (),
102- output_shape_vec.begin () + 2 );
10398
10499 // use col_shape in the im2col calculation
105100 // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
106101 // o_h, o_w}
107- std::vector<int64_t > col_shape_vec;
108- col_shape_vec.push_back (input->dims ()[1 ] / groups);
109- col_shape_vec.insert (col_shape_vec.end (), filter_shape_vec.begin (),
110- filter_shape_vec.end ());
111- col_shape_vec.insert (col_shape_vec.end (), output_shape_vec.begin (),
112- output_shape_vec.end ());
102+ size_t data_dim = filter_shape_vec.size () - 2 ;
103+ std::vector<int64_t > col_shape_vec (1 + 2 * data_dim);
104+ col_shape_vec[0 ] = input->dims ()[1 ] / groups;
105+ for (size_t j = 0 ; j < data_dim; ++j) {
106+ col_shape_vec[j + 1 ] = filter_shape_vec[j + 2 ];
107+ col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2 ];
108+ }
113109 framework::DDim col_shape (framework::make_ddim (col_shape_vec));
114110
115111 // use col_matrix_shape in the gemm calculation
116112 // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
117113 // o_h * o_w)
118114 framework::DDim col_matrix_shape =
119- framework::flatten_to_2d (col_shape, filter_shape_vec. size () + 1 );
115+ framework::flatten_to_2d (col_shape, data_dim + 1 );
120116
121117 bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
122118 Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
159155 col.ShareDataWith (in_slice);
160156 col_matrix.ShareDataWith (col);
161157 col_matrix.Resize (col_matrix_shape);
162- } else if (filter_shape_vec. size () == 2 ) {
158+ } else if (data_dim == 2U ) {
163159 // im2col
164160 im2col (context.device_context (), in_slice, dilations, strides,
165161 std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
166162 paddings[1 ]},
167163 &col);
168- } else if (filter_shape_vec. size () == 3 ) {
164+ } else if (data_dim == 3U ) {
169165 // vol2col
170166 vol2col (context.device_context (), in_slice, dilations, strides,
171167 paddings, &col);
@@ -206,34 +202,30 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
206202
207203 const int batch_size = static_cast <int >(input->dims ()[0 ]);
208204
209- // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
205+ // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
210206 std::vector<int64_t > filter_shape_vec (framework::vectorize (filter.dims ()));
211- filter_shape_vec.erase (filter_shape_vec.begin (),
212- filter_shape_vec.begin () + 2 );
213-
214- // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
207+ // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
215208 std::vector<int64_t > output_shape_vec (
216209 framework::vectorize (output_grad->dims ()));
217- output_shape_vec.erase (output_shape_vec.begin (),
218- output_shape_vec.begin () + 2 );
219210
220211 // use col_shape in the im2col calculation
221212 // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
222213 // o_h, o_w}
223- std::vector<int64_t > col_shape_vec;
224- col_shape_vec.push_back (input->dims ()[1 ] / groups);
225- col_shape_vec.insert (col_shape_vec.end (), filter_shape_vec.begin (),
226- filter_shape_vec.end ());
227- col_shape_vec.insert (col_shape_vec.end (), output_shape_vec.begin (),
228- output_shape_vec.end ());
214+ size_t data_dim = filter_shape_vec.size () - 2 ;
215+ std::vector<int64_t > col_shape_vec (1 + 2 * data_dim);
216+ col_shape_vec[0 ] = input->dims ()[1 ] / groups;
217+ for (size_t j = 0 ; j < data_dim; ++j) {
218+ col_shape_vec[j + 1 ] = filter_shape_vec[j + 2 ];
219+ col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2 ];
220+ }
229221 framework::DDim col_shape (framework::make_ddim (col_shape_vec));
230222
231223 // use col_matrix_shape in the gemm calculation
232224 // size: (i_c/g * k_h * k_w, o_h * o_w)
233225 // or
234226 // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
235227 framework::DDim col_matrix_shape =
236- framework::flatten_to_2d (col_shape, filter_shape_vec. size () + 1 );
228+ framework::flatten_to_2d (col_shape, data_dim + 1 );
237229
238230 framework::DDim input_shape = framework::slice_ddim (
239231 input->dims (), 1 , static_cast <int >(input->dims ().size ()));
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
294286 out_grad_slice, false , T (1.0 ), &col_matrix,
295287 T (0.0 ));
296288
297- if (is_expand && filter_shape_vec. size () == 2 ) {
289+ if (is_expand && data_dim == 2U ) {
298290 col2im (context.device_context (), col, dilations, strides,
299291 std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
300292 paddings[1 ]},
301293 &in_grad_slice);
302- } else if (is_expand && filter_shape_vec. size () == 3 ) {
294+ } else if (is_expand && data_dim == 3U ) {
303295 col2vol (context.device_context (), col, dilations, strides, paddings,
304296 &in_grad_slice);
305297 }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
328320 col.ShareDataWith (in_slice);
329321 col_matrix.ShareDataWith (col);
330322 col_matrix.Resize (col_matrix_shape);
331- } else if (filter_shape_vec. size () == 2 ) {
323+ } else if (data_dim == 2U ) {
332324 im2col (context.device_context (), in_slice, dilations, strides,
333325 std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
334326 paddings[1 ]},
335327 &col);
336- } else if (filter_shape_vec. size () == 3 ) {
328+ } else if (data_dim == 3U ) {
337329 vol2col (context.device_context (), in_slice, dilations, strides,
338330 paddings, &col);
339331 }
0 commit comments