Skip to content

Commit e50cd53

Browse files
committed
add channels_last support for roi_align on CPU
1 parent 2862d6b commit e50cd53

2 files changed

Lines changed: 135 additions & 34 deletions

File tree

setup.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ def get_extensions():
158158
print(f" TORCHVISION_USE_VIDEO_CODEC: {use_video_codec}")
159159
use_openmp = os.getenv("USE_OPENMP", "0") == "1"
160160
print(f" USE_OPEN: {use_openmp}")
161+
use_avx2 = os.getenv('USE_AVX2', '0') == '1'
162+
print(f" USE_AVX2: {use_avx2}")
163+
use_avx512 = os.getenv('USE_AVX512', '0') == '1'
164+
print(f" USE_AVX512: {use_avx512}")
161165

162166
nvcc_flags = os.getenv("NVCC_FLAGS", "")
163167
print(f" NVCC_FLAGS: {nvcc_flags}")
@@ -228,6 +232,15 @@ def get_extensions():
228232
if use_openmp:
229233
extra_compile_args["cxx"].append("-fopenmp")
230234

235+
if use_avx2:
236+
extra_compile_args["cxx"].append("-O3")
237+
extra_compile_args["cxx"].append("-mavx2")
238+
239+
if use_avx512:
240+
extra_compile_args["cxx"].append("-O3")
241+
extra_compile_args["cxx"].append("-march=skylake-avx512")
242+
extra_compile_args["cxx"].append("-mavx512f")
243+
231244
if debug_mode:
232245
print("Compiling in debug mode")
233246
extra_compile_args["cxx"].append("-g")

torchvision/csrc/ops/cpu/roi_align_kernel.cpp

Lines changed: 122 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,92 @@ namespace ops {
99

1010
namespace {
1111

12+
template <typename T>
13+
inline void roi_align_single_framework_forward(
14+
const T* input,
15+
const int count,
16+
int channels,
17+
int height,
18+
int width,
19+
int pooled_height,
20+
int pooled_width,
21+
int roi_bin_grid_h,
22+
int roi_bin_grid_w,
23+
const std::vector<detail::PreCalc<T>>& pre_calc,
24+
T* output) {
25+
for (int c = 0; c < channels; c++) {
26+
const T* offset_input = input + c * height * width;
27+
int pre_calc_index = 0;
28+
29+
for (int ph = 0; ph < pooled_height; ph++) {
30+
for (int pw = 0; pw < pooled_width; pw++) {
31+
int index = c * pooled_height * pooled_width + ph * pooled_width + pw;
32+
33+
T output_val = 0.;
34+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
35+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
36+
detail::PreCalc<T> pc = pre_calc[pre_calc_index];
37+
output_val += pc.w1 * offset_input[pc.pos1] +
38+
pc.w2 * offset_input[pc.pos2] +
39+
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
40+
41+
pre_calc_index += 1;
42+
}
43+
}
44+
output_val /= count; // Average pooling
45+
46+
output[index] = output_val;
47+
} // for pw
48+
} // for ph
49+
} // for c
50+
}
51+
52+
template <typename T>
53+
inline void roi_align_single_framework_channels_last_forward(
54+
const T* input,
55+
const int count,
56+
int channels,
57+
int height,
58+
int width,
59+
int pooled_height,
60+
int pooled_width,
61+
int roi_bin_grid_h,
62+
int roi_bin_grid_w,
63+
const std::vector<detail::PreCalc<T>>& pre_calc,
64+
T* output) {
65+
// for 'normal' size of channels, should be L1 fit;
66+
// otherwise consider blocking on channels.
67+
int pre_calc_index = 0;
68+
for (int ph = 0; ph < pooled_height; ph++) {
69+
for (int pw = 0; pw < pooled_width; pw++) {
70+
T* out = output + (ph * pooled_width + pw) * channels;
71+
72+
// pass I: do accumulation
73+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
74+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
75+
detail::PreCalc<T> pc = pre_calc[pre_calc_index];
76+
const T* in1 = input + pc.pos1 * channels;
77+
const T* in2 = input + pc.pos2 * channels;
78+
const T* in3 = input + pc.pos3 * channels;
79+
const T* in4 = input + pc.pos4 * channels;
80+
81+
#pragma omp simd
82+
for (int c = 0; c < channels; c++) {
83+
out[c] += pc.w1 * in1[c] + pc.w2 * in2[c] + pc.w3 * in3[c] + pc.w4 * in4[c];
84+
}
85+
pre_calc_index += 1;
86+
}
87+
}
88+
89+
// pass II: do average
90+
#pragma omp simd
91+
for (int c = 0; c < channels; c++) {
92+
out[c] /= count;
93+
}
94+
} // for pw
95+
} // for ph
96+
}
97+
1298
template <typename T>
1399
void roi_align_forward_kernel_impl(
14100
int n_rois,
@@ -22,13 +108,12 @@ void roi_align_forward_kernel_impl(
22108
int sampling_ratio,
23109
bool aligned,
24110
const T* rois,
25-
T* output) {
111+
T* output,
112+
bool is_channels_last) {
26113
// (n, c, ph, pw) is an element in the pooled output
27114
// can be parallelized using omp
28115
at::parallel_for(0, n_rois, 1, [&](int begin, int end) {
29116
for (int n = begin; n < end; n++) {
30-
int index_n = n * channels * pooled_width * pooled_height;
31-
32117
const T* offset_rois = rois + n * 5;
33118
int roi_batch_ind = offset_rois[0];
34119

@@ -78,33 +163,33 @@ void roi_align_forward_kernel_impl(
78163
roi_bin_grid_w,
79164
pre_calc);
80165

81-
for (int c = 0; c < channels; c++) {
82-
int index_n_c = index_n + c * pooled_width * pooled_height;
83-
const T* offset_input =
84-
input + (roi_batch_ind * channels + c) * height * width;
85-
int pre_calc_index = 0;
86-
87-
for (int ph = 0; ph < pooled_height; ph++) {
88-
for (int pw = 0; pw < pooled_width; pw++) {
89-
int index = index_n_c + ph * pooled_width + pw;
90-
91-
T output_val = 0.;
92-
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
93-
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
94-
detail::PreCalc<T> pc = pre_calc[pre_calc_index];
95-
output_val += pc.w1 * offset_input[pc.pos1] +
96-
pc.w2 * offset_input[pc.pos2] +
97-
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
98-
99-
pre_calc_index += 1;
100-
}
101-
}
102-
output_val /= count; // Average pooling
103-
104-
output[index] = output_val;
105-
} // for pw
106-
} // for ph
107-
} // for c
166+
if (is_channels_last) {
167+
roi_align_single_framework_channels_last_forward(
168+
input + roi_batch_ind * height * width * channels,
169+
count,
170+
channels,
171+
height,
172+
width,
173+
pooled_height,
174+
pooled_width,
175+
roi_bin_grid_h,
176+
roi_bin_grid_w,
177+
pre_calc,
178+
output + n * pooled_width * pooled_height * channels);
179+
} else {
180+
roi_align_single_framework_forward(
181+
input + roi_batch_ind * channels * height * width,
182+
count,
183+
channels,
184+
height,
185+
width,
186+
pooled_height,
187+
pooled_width,
188+
roi_bin_grid_h,
189+
roi_bin_grid_w,
190+
pre_calc,
191+
output + n * channels * pooled_width * pooled_height);
192+
}
108193
} // for n
109194
});
110195
}
@@ -303,13 +388,15 @@ at::Tensor roi_align_forward_kernel(
303388
auto height = input.size(2);
304389
auto width = input.size(3);
305390

306-
at::Tensor output = at::zeros(
307-
{num_rois, channels, pooled_height, pooled_width}, input.options());
391+
auto memory_format = input.suggest_memory_format();
392+
bool is_channels_last = memory_format == at::MemoryFormat::ChannelsLast;
393+
at::Tensor output = at::empty({0}, input.options());
394+
output.resize_({num_rois, channels, pooled_height, pooled_width}, memory_format).zero_();
308395

309396
if (output.numel() == 0)
310397
return output;
311398

312-
auto input_ = input.contiguous(), rois_ = rois.contiguous();
399+
auto input_ = input.contiguous(memory_format), rois_ = rois.contiguous();
313400
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
314401
input.scalar_type(), "roi_align_forward_kernel", [&] {
315402
roi_align_forward_kernel_impl<scalar_t>(
@@ -324,7 +411,8 @@ at::Tensor roi_align_forward_kernel(
324411
sampling_ratio,
325412
aligned,
326413
rois_.data_ptr<scalar_t>(),
327-
output.data_ptr<scalar_t>());
414+
output.data_ptr<scalar_t>(),
415+
is_channels_last);
328416
});
329417
return output;
330418
}

0 commit comments

Comments
 (0)