@@ -9,6 +9,92 @@ namespace ops {
99
1010namespace {
1111
12+ template <typename T>
13+ inline void roi_align_single_framework_forward (
14+ const T* input,
15+ const int count,
16+ int channels,
17+ int height,
18+ int width,
19+ int pooled_height,
20+ int pooled_width,
21+ int roi_bin_grid_h,
22+ int roi_bin_grid_w,
23+ const std::vector<detail::PreCalc<T>>& pre_calc,
24+ T* output) {
25+ for (int c = 0 ; c < channels; c++) {
26+ const T* offset_input = input + c * height * width;
27+ int pre_calc_index = 0 ;
28+
29+ for (int ph = 0 ; ph < pooled_height; ph++) {
30+ for (int pw = 0 ; pw < pooled_width; pw++) {
31+ int index = c * pooled_height * pooled_width + ph * pooled_width + pw;
32+
33+ T output_val = 0 .;
34+ for (int iy = 0 ; iy < roi_bin_grid_h; iy++) {
35+ for (int ix = 0 ; ix < roi_bin_grid_w; ix++) {
36+ detail::PreCalc<T> pc = pre_calc[pre_calc_index];
37+ output_val += pc.w1 * offset_input[pc.pos1 ] +
38+ pc.w2 * offset_input[pc.pos2 ] +
39+ pc.w3 * offset_input[pc.pos3 ] + pc.w4 * offset_input[pc.pos4 ];
40+
41+ pre_calc_index += 1 ;
42+ }
43+ }
44+ output_val /= count; // Average pooling
45+
46+ output[index] = output_val;
47+ } // for pw
48+ } // for ph
49+ } // for c
50+ }
51+
52+ template <typename T>
53+ inline void roi_align_single_framework_channels_last_forward (
54+ const T* input,
55+ const int count,
56+ int channels,
57+ int height,
58+ int width,
59+ int pooled_height,
60+ int pooled_width,
61+ int roi_bin_grid_h,
62+ int roi_bin_grid_w,
63+ const std::vector<detail::PreCalc<T>>& pre_calc,
64+ T* output) {
65+ // for 'normal' size of channels, should be L1 fit;
66+ // otherwise consider blocking on channels.
67+ int pre_calc_index = 0 ;
68+ for (int ph = 0 ; ph < pooled_height; ph++) {
69+ for (int pw = 0 ; pw < pooled_width; pw++) {
70+ T* out = output + (ph * pooled_width + pw) * channels;
71+
72+ // pass I: do accumulation
73+ for (int iy = 0 ; iy < roi_bin_grid_h; iy++) {
74+ for (int ix = 0 ; ix < roi_bin_grid_w; ix++) {
75+ detail::PreCalc<T> pc = pre_calc[pre_calc_index];
76+ const T* in1 = input + pc.pos1 * channels;
77+ const T* in2 = input + pc.pos2 * channels;
78+ const T* in3 = input + pc.pos3 * channels;
79+ const T* in4 = input + pc.pos4 * channels;
80+
81+ #pragma omp simd
82+ for (int c = 0 ; c < channels; c++) {
83+ out[c] += pc.w1 * in1[c] + pc.w2 * in2[c] + pc.w3 * in3[c] + pc.w4 * in4[c];
84+ }
85+ pre_calc_index += 1 ;
86+ }
87+ }
88+
89+ // pass II: do average
90+ #pragma omp simd
91+ for (int c = 0 ; c < channels; c++) {
92+ out[c] /= count;
93+ }
94+ } // for pw
95+ } // for ph
96+ }
97+
1298template <typename T>
1399void roi_align_forward_kernel_impl (
14100 int n_rois,
@@ -22,13 +108,12 @@ void roi_align_forward_kernel_impl(
22108 int sampling_ratio,
23109 bool aligned,
24110 const T* rois,
25- T* output) {
111+ T* output,
112+ bool is_channels_last) {
26113 // (n, c, ph, pw) is an element in the pooled output
27114 // can be parallelized using omp
28115 at::parallel_for (0 , n_rois, 1 , [&](int begin, int end) {
29116 for (int n = begin; n < end; n++) {
30- int index_n = n * channels * pooled_width * pooled_height;
31-
32117 const T* offset_rois = rois + n * 5 ;
33118 int roi_batch_ind = offset_rois[0 ];
34119
@@ -78,33 +163,33 @@ void roi_align_forward_kernel_impl(
78163 roi_bin_grid_w,
79164 pre_calc);
80165
81- for ( int c = 0 ; c < channels; c++ ) {
82- int index_n_c = index_n + c * pooled_width * pooled_height;
83- const T* offset_input =
84- input + (roi_batch_ind * channels + c) * height * width;
85- int pre_calc_index = 0 ;
86-
87- for ( int ph = 0 ; ph < pooled_height; ph++) {
88- for ( int pw = 0 ; pw < pooled_width; pw++) {
89- int index = index_n_c + ph * pooled_width + pw;
90-
91- T output_val = 0 .;
92- for ( int iy = 0 ; iy < roi_bin_grid_h; iy++) {
93- for ( int ix = 0 ; ix < roi_bin_grid_w; ix++) {
94- detail::PreCalc<T> pc = pre_calc[pre_calc_index];
95- output_val += pc. w1 * offset_input[pc. pos1 ] +
96- pc. w2 * offset_input[pc. pos2 ] +
97- pc. w3 * offset_input[pc. pos3 ] + pc. w4 * offset_input[pc. pos4 ];
98-
99- pre_calc_index += 1 ;
100- }
101- }
102- output_val /= count; // Average pooling
103-
104- output[index] = output_val;
105- } // for pw
106- } // for ph
107- } // for c
166+ if (is_channels_last ) {
167+ roi_align_single_framework_channels_last_forward (
168+ input + roi_batch_ind * height * width * channels,
169+ count,
170+ channels,
171+ height,
172+ width,
173+ pooled_height,
174+ pooled_width,
175+ roi_bin_grid_h,
176+ roi_bin_grid_w,
177+ pre_calc,
178+ output + n * pooled_width * pooled_height * channels);
179+ } else {
180+ roi_align_single_framework_forward (
181+ input + roi_batch_ind * channels * height * width,
182+ count,
183+ channels,
184+ height,
185+ width,
186+ pooled_height,
187+ pooled_width,
188+ roi_bin_grid_h,
189+ roi_bin_grid_w,
190+ pre_calc,
191+ output + n * channels * pooled_width * pooled_height);
192+ }
108193 } // for n
109194 });
110195}
@@ -303,13 +388,15 @@ at::Tensor roi_align_forward_kernel(
303388 auto height = input.size (2 );
304389 auto width = input.size (3 );
305390
306- at::Tensor output = at::zeros (
307- {num_rois, channels, pooled_height, pooled_width}, input.options ());
391+ auto memory_format = input.suggest_memory_format ();
392+ bool is_channels_last = memory_format == at::MemoryFormat::ChannelsLast;
393+ at::Tensor output = at::empty ({0 }, input.options ());
394+ output.resize_ ({num_rois, channels, pooled_height, pooled_width}, memory_format).zero_ ();
308395
309396 if (output.numel () == 0 )
310397 return output;
311398
312- auto input_ = input.contiguous (), rois_ = rois.contiguous ();
399+ auto input_ = input.contiguous (memory_format ), rois_ = rois.contiguous ();
313400 AT_DISPATCH_FLOATING_TYPES_AND_HALF (
314401 input.scalar_type (), " roi_align_forward_kernel" , [&] {
315402 roi_align_forward_kernel_impl<scalar_t >(
@@ -324,7 +411,8 @@ at::Tensor roi_align_forward_kernel(
324411 sampling_ratio,
325412 aligned,
326413 rois_.data_ptr <scalar_t >(),
327- output.data_ptr <scalar_t >());
414+ output.data_ptr <scalar_t >(),
415+ is_channels_last);
328416 });
329417 return output;
330418}
0 commit comments