Skip to content

Commit 1fa27d0

Browse files
committed
parallel roi_align on CPU
1 parent a4f5330 commit 1fa27d0

2 files changed

Lines changed: 87 additions & 79 deletions

File tree

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def get_extensions():
156156
print(f" TORCHVISION_USE_FFMPEG: {use_ffmpeg}")
157157
use_video_codec = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
158158
print(f" TORCHVISION_USE_VIDEO_CODEC: {use_video_codec}")
159+
use_openmp = os.getenv("USE_OPENMP", "0") == "1"
160+
print(f" USE_OPEN: {use_openmp}")
159161

160162
nvcc_flags = os.getenv("NVCC_FLAGS", "")
161163
print(f" NVCC_FLAGS: {nvcc_flags}")
@@ -223,6 +225,9 @@ def get_extensions():
223225
define_macros += [("USE_PYTHON", None)]
224226
extra_compile_args["cxx"].append("/MP")
225227

228+
if use_openmp:
229+
extra_compile_args["cxx"].append("-fopenmp")
230+
226231
if debug_mode:
227232
print("Compiling in debug mode")
228233
extra_compile_args["cxx"].append("-g")

torchvision/csrc/ops/cpu/roi_align_kernel.cpp

Lines changed: 82 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/ATen.h>
2+
#include <ATen/Parallel.h>
23
#include <torch/library.h>
34

45
#include "./roi_align_common.h"
@@ -25,86 +26,88 @@ void roi_align_forward_kernel_impl(
2526
// (n, c, ph, pw) is an element in the pooled output
2627
// can be parallelized using omp
2728
// #pragma omp parallel for num_threads(32)
28-
for (int n = 0; n < n_rois; n++) {
29-
int index_n = n * channels * pooled_width * pooled_height;
30-
31-
const T* offset_rois = rois + n * 5;
32-
int roi_batch_ind = offset_rois[0];
33-
34-
// Do not using rounding; this implementation detail is critical
35-
T offset = aligned ? (T)0.5 : (T)0.0;
36-
T roi_start_w = offset_rois[1] * spatial_scale - offset;
37-
T roi_start_h = offset_rois[2] * spatial_scale - offset;
38-
T roi_end_w = offset_rois[3] * spatial_scale - offset;
39-
T roi_end_h = offset_rois[4] * spatial_scale - offset;
40-
41-
T roi_width = roi_end_w - roi_start_w;
42-
T roi_height = roi_end_h - roi_start_h;
43-
if (!aligned) {
44-
// Force malformed ROIs to be 1x1
45-
roi_width = std::max(roi_width, (T)1.);
46-
roi_height = std::max(roi_height, (T)1.);
47-
}
48-
49-
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
50-
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
51-
52-
// We use roi_bin_grid to sample the grid and mimic integral
53-
int roi_bin_grid_h = (sampling_ratio > 0)
54-
? sampling_ratio
55-
: ceil(roi_height / pooled_height); // e.g., = 2
56-
int roi_bin_grid_w =
57-
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
58-
59-
// We do average (integral) pooling inside a bin
60-
// When the grid is empty, output zeros.
61-
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
62-
63-
// we want to precalculate indices and weights shared by all chanels,
64-
// this is the key point of optimization
65-
std::vector<detail::PreCalc<T>> pre_calc(
66-
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
67-
detail::pre_calc_for_bilinear_interpolate(
68-
height,
69-
width,
70-
pooled_height,
71-
pooled_width,
72-
roi_start_h,
73-
roi_start_w,
74-
bin_size_h,
75-
bin_size_w,
76-
roi_bin_grid_h,
77-
roi_bin_grid_w,
78-
pre_calc);
79-
80-
for (int c = 0; c < channels; c++) {
81-
int index_n_c = index_n + c * pooled_width * pooled_height;
82-
const T* offset_input =
83-
input + (roi_batch_ind * channels + c) * height * width;
84-
int pre_calc_index = 0;
85-
86-
for (int ph = 0; ph < pooled_height; ph++) {
87-
for (int pw = 0; pw < pooled_width; pw++) {
88-
int index = index_n_c + ph * pooled_width + pw;
89-
90-
T output_val = 0.;
91-
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
92-
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
93-
detail::PreCalc<T> pc = pre_calc[pre_calc_index];
94-
output_val += pc.w1 * offset_input[pc.pos1] +
95-
pc.w2 * offset_input[pc.pos2] +
96-
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
97-
98-
pre_calc_index += 1;
29+
at::parallel_for(0, n_rois, 1, [&](int begin, int end) {
30+
for (int n = begin; n < end; n++) {
31+
int index_n = n * channels * pooled_width * pooled_height;
32+
33+
const T* offset_rois = rois + n * 5;
34+
int roi_batch_ind = offset_rois[0];
35+
36+
// Do not use rounding; this implementation detail is critical
37+
T offset = aligned ? (T)0.5 : (T)0.0;
38+
T roi_start_w = offset_rois[1] * spatial_scale - offset;
39+
T roi_start_h = offset_rois[2] * spatial_scale - offset;
40+
T roi_end_w = offset_rois[3] * spatial_scale - offset;
41+
T roi_end_h = offset_rois[4] * spatial_scale - offset;
42+
43+
T roi_width = roi_end_w - roi_start_w;
44+
T roi_height = roi_end_h - roi_start_h;
45+
if (!aligned) {
46+
// Force malformed ROIs to be 1x1
47+
roi_width = std::max(roi_width, (T)1.);
48+
roi_height = std::max(roi_height, (T)1.);
49+
}
50+
51+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
52+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
53+
54+
// We use roi_bin_grid to sample the grid and mimic integral
55+
int roi_bin_grid_h = (sampling_ratio > 0)
56+
? sampling_ratio
57+
: ceil(roi_height / pooled_height); // e.g., = 2
58+
int roi_bin_grid_w =
59+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
60+
61+
// We do average (integral) pooling inside a bin
62+
// When the grid is empty, output zeros.
63+
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
64+
65+
// we want to precalculate indices and weights shared by all channels,
66+
// this is the key point of optimization
67+
std::vector<detail::PreCalc<T>> pre_calc(
68+
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
69+
detail::pre_calc_for_bilinear_interpolate(
70+
height,
71+
width,
72+
pooled_height,
73+
pooled_width,
74+
roi_start_h,
75+
roi_start_w,
76+
bin_size_h,
77+
bin_size_w,
78+
roi_bin_grid_h,
79+
roi_bin_grid_w,
80+
pre_calc);
81+
82+
for (int c = 0; c < channels; c++) {
83+
int index_n_c = index_n + c * pooled_width * pooled_height;
84+
const T* offset_input =
85+
input + (roi_batch_ind * channels + c) * height * width;
86+
int pre_calc_index = 0;
87+
88+
for (int ph = 0; ph < pooled_height; ph++) {
89+
for (int pw = 0; pw < pooled_width; pw++) {
90+
int index = index_n_c + ph * pooled_width + pw;
91+
92+
T output_val = 0.;
93+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
94+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
95+
detail::PreCalc<T> pc = pre_calc[pre_calc_index];
96+
output_val += pc.w1 * offset_input[pc.pos1] +
97+
pc.w2 * offset_input[pc.pos2] +
98+
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
99+
100+
pre_calc_index += 1;
101+
}
99102
}
100-
}
101-
output_val /= count; // Average pooling
102-
103-
output[index] = output_val;
104-
} // for pw
105-
} // for ph
106-
} // for c
107-
} // for n
103+
output_val /= count; // Average pooling
104+
105+
output[index] = output_val;
106+
} // for pw
107+
} // for ph
108+
} // for c
109+
} // for n
110+
});
108111
}
109112

110113
template <typename T>

0 commit comments

Comments
 (0)