Skip to content

Commit ba02b2f

Browse files
authored
Encapsulate and standardize roi_align (#3085)
* Renaming C++ files & methods according to recommended naming conventions and aligning them with Python's API. * Adding all internal functions in anonymous namespaces. * Renaming C++/CUDA kernel files and moving operator code from header to cpp file. * Create foreach cpp file a separate header file with "public" functions. * Removing unnecessary repeated includes.
1 parent 750bde3 commit ba02b2f

10 files changed

+173
-95
lines changed

test/tracing/frcnn/test_frcnn_tracing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <ATen/ATen.h>
22
#include <torch/script.h>
33
#include <torch/torch.h>
4-
#include <torchvision/ROIAlign.h>
4+
#include <torchvision/roi_align.h>
55
#include <torchvision/cpu/vision_cpu.h>
66
#include <torchvision/nms.h>
77

torchvision/csrc/cpu/ROIAlign_cpu.cpp renamed to torchvision/csrc/cpu/roi_align_kernel.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
#include <ATen/TensorUtils.h>
2-
#include "vision_cpu.h"
1+
#include "roi_align_kernel.h"
2+
3+
namespace {
34

45
// implementation taken from Caffe2
56
template <typename T>
@@ -111,7 +112,7 @@ void pre_calc_for_bilinear_interpolate(
111112
}
112113

113114
template <typename T>
114-
void ROIAlignForward(
115+
void roi_align_forward_kernel_impl(
115116
int nthreads,
116117
const T* input,
117118
const T& spatial_scale,
@@ -277,7 +278,7 @@ inline void add(T* address, const T& val) {
277278
}
278279

279280
template <typename T>
280-
void ROIAlignBackward(
281+
void roi_align_backward_kernel_impl(
281282
int nthreads,
282283
const T* grad_output,
283284
const T& spatial_scale,
@@ -382,9 +383,11 @@ void ROIAlignBackward(
382383
} // ix
383384
} // iy
384385
} // for
385-
} // ROIAlignBackward
386+
}
387+
388+
} // namespace
386389

387-
at::Tensor ROIAlign_forward_cpu(
390+
at::Tensor roi_align_forward_cpu(
388391
const at::Tensor& input,
389392
const at::Tensor& rois,
390393
double spatial_scale,
@@ -398,7 +401,7 @@ at::Tensor ROIAlign_forward_cpu(
398401

399402
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
400403

401-
at::CheckedFrom c = "ROIAlign_forward_cpu";
404+
at::CheckedFrom c = "roi_align_forward_cpu";
402405
at::checkAllSameType(c, {input_t, rois_t});
403406

404407
auto num_rois = rois.size(0);
@@ -416,8 +419,8 @@ at::Tensor ROIAlign_forward_cpu(
416419

417420
auto input_ = input.contiguous(), rois_ = rois.contiguous();
418421
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
419-
input.scalar_type(), "ROIAlign_forward", [&] {
420-
ROIAlignForward<scalar_t>(
422+
input.scalar_type(), "roi_align_forward", [&] {
423+
roi_align_forward_kernel_impl<scalar_t>(
421424
output_size,
422425
input_.data_ptr<scalar_t>(),
423426
spatial_scale,
@@ -434,7 +437,7 @@ at::Tensor ROIAlign_forward_cpu(
434437
return output;
435438
}
436439

437-
at::Tensor ROIAlign_backward_cpu(
440+
at::Tensor roi_align_backward_cpu(
438441
const at::Tensor& grad,
439442
const at::Tensor& rois,
440443
double spatial_scale,
@@ -451,7 +454,7 @@ at::Tensor ROIAlign_backward_cpu(
451454

452455
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
453456

454-
at::CheckedFrom c = "ROIAlign_backward_cpu";
457+
at::CheckedFrom c = "roi_align_backward_cpu";
455458
at::checkAllSameType(c, {grad_t, rois_t});
456459

457460
at::Tensor grad_input =
@@ -470,8 +473,8 @@ at::Tensor ROIAlign_backward_cpu(
470473

471474
auto rois_ = rois.contiguous();
472475
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
473-
grad.scalar_type(), "ROIAlign_forward", [&] {
474-
ROIAlignBackward<scalar_t>(
476+
grad.scalar_type(), "roi_align_forward", [&] {
477+
roi_align_backward_kernel_impl<scalar_t>(
475478
grad.numel(),
476479
grad.data_ptr<scalar_t>(),
477480
spatial_scale,
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor roi_align_forward_cpu(
7+
const at::Tensor& input,
8+
const at::Tensor& rois,
9+
double spatial_scale,
10+
int64_t pooled_height,
11+
int64_t pooled_width,
12+
int64_t sampling_ratio,
13+
bool aligned);
14+
15+
VISION_API at::Tensor roi_align_backward_cpu(
16+
const at::Tensor& grad,
17+
const at::Tensor& rois,
18+
double spatial_scale,
19+
int64_t pooled_height,
20+
int64_t pooled_width,
21+
int64_t batch_size,
22+
int64_t channels,
23+
int64_t height,
24+
int64_t width,
25+
int64_t sampling_ratio,
26+
bool aligned);

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor ROIAlign_forward_cpu(
8-
const at::Tensor& input,
9-
const at::Tensor& rois,
10-
double spatial_scale,
11-
int64_t pooled_height,
12-
int64_t pooled_width,
13-
int64_t sampling_ratio,
14-
bool aligned);
15-
16-
VISION_API at::Tensor ROIAlign_backward_cpu(
17-
const at::Tensor& grad,
18-
const at::Tensor& rois,
19-
double spatial_scale,
20-
int64_t pooled_height,
21-
int64_t pooled_width,
22-
int64_t batch_size,
23-
int64_t channels,
24-
int64_t height,
25-
int64_t width,
26-
int64_t sampling_ratio,
27-
bool aligned);
28-
297
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
308
const at::Tensor& input,
319
const at::Tensor& rois,

torchvision/csrc/cuda/ROIAlign_cuda.cu renamed to torchvision/csrc/cuda/roi_align_kernel.cu

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
#include <ATen/ATen.h>
2-
#include <ATen/TensorUtils.h>
31
#include <ATen/cuda/CUDAContext.h>
42
#include <c10/cuda/CUDAGuard.h>
53
#include <THC/THCAtomics.cuh>
64

75
#include "cuda_helpers.h"
6+
#include "roi_align_kernel.h"
7+
8+
namespace {
89

910
template <typename T>
1011
__device__ T bilinear_interpolate(
@@ -61,7 +62,7 @@ __device__ T bilinear_interpolate(
6162
}
6263

6364
template <typename T>
64-
__global__ void RoIAlignForward(
65+
__global__ void roi_align_forward_kernel_impl(
6566
int nthreads,
6667
const T* input,
6768
const T spatial_scale,
@@ -197,7 +198,7 @@ __device__ void bilinear_interpolate_gradient(
197198
}
198199

199200
template <typename T>
200-
__global__ void RoIAlignBackward(
201+
__global__ void roi_align_backward_kernel_impl(
201202
int nthreads,
202203
const T* grad_output,
203204
const T spatial_scale,
@@ -308,9 +309,11 @@ __global__ void RoIAlignBackward(
308309
} // ix
309310
} // iy
310311
} // CUDA_1D_KERNEL_LOOP
311-
} // RoIAlignBackward
312+
}
313+
314+
} // namespace
312315

313-
at::Tensor ROIAlign_forward_cuda(
316+
at::Tensor roi_align_forward_cuda(
314317
const at::Tensor& input,
315318
const at::Tensor& rois,
316319
double spatial_scale,
@@ -325,7 +328,7 @@ at::Tensor ROIAlign_forward_cuda(
325328

326329
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
327330

328-
at::CheckedFrom c = "ROIAlign_forward_cuda";
331+
at::CheckedFrom c = "roi_align_forward_cuda";
329332
at::checkAllSameGPU(c, {input_t, rois_t});
330333
at::checkAllSameType(c, {input_t, rois_t});
331334

@@ -354,8 +357,8 @@ at::Tensor ROIAlign_forward_cuda(
354357

355358
auto input_ = input.contiguous(),
356359
rois_ = rois.contiguous();
357-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign_forward", [&] {
358-
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
360+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "roi_align_forward", [&] {
361+
roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
359362
output_size,
360363
input_.data_ptr<scalar_t>(),
361364
spatial_scale,
@@ -373,7 +376,7 @@ at::Tensor ROIAlign_forward_cuda(
373376
return output;
374377
}
375378

376-
at::Tensor ROIAlign_backward_cuda(
379+
at::Tensor roi_align_backward_cuda(
377380
const at::Tensor& grad,
378381
const at::Tensor& rois,
379382
double spatial_scale,
@@ -390,7 +393,7 @@ at::Tensor ROIAlign_backward_cuda(
390393

391394
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
392395

393-
at::CheckedFrom c = "ROIAlign_backward_cuda";
396+
at::CheckedFrom c = "roi_align_backward_cuda";
394397
at::checkAllSameGPU(c, {grad_t, rois_t});
395398
at::checkAllSameType(c, {grad_t, rois_t});
396399

@@ -418,8 +421,8 @@ at::Tensor ROIAlign_backward_cuda(
418421
int w_stride = grad.stride(3);
419422

420423
auto rois_ = rois.contiguous();
421-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign_backward", [&] {
422-
RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
424+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "roi_align_backward", [&] {
425+
roi_align_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
423426
grad.numel(),
424427
grad.data_ptr<scalar_t>(),
425428
spatial_scale,
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor roi_align_forward_cuda(
7+
const at::Tensor& input,
8+
const at::Tensor& rois,
9+
double spatial_scale,
10+
int64_t pooled_height,
11+
int64_t pooled_width,
12+
int64_t sampling_ratio,
13+
bool aligned);
14+
15+
VISION_API at::Tensor roi_align_backward_cuda(
16+
const at::Tensor& grad,
17+
const at::Tensor& rois,
18+
double spatial_scale,
19+
int64_t pooled_height,
20+
int64_t pooled_width,
21+
int64_t batch_size,
22+
int64_t channels,
23+
int64_t height,
24+
int64_t width,
25+
int64_t sampling_ratio,
26+
bool aligned);

torchvision/csrc/cuda/vision_cuda.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor ROIAlign_forward_cuda(
8-
const at::Tensor& input,
9-
const at::Tensor& rois,
10-
double spatial_scale,
11-
int64_t pooled_height,
12-
int64_t pooled_width,
13-
int64_t sampling_ratio,
14-
bool aligned);
15-
16-
VISION_API at::Tensor ROIAlign_backward_cuda(
17-
const at::Tensor& grad,
18-
const at::Tensor& rois,
19-
double spatial_scale,
20-
int64_t pooled_height,
21-
int64_t pooled_width,
22-
int64_t batch_size,
23-
int64_t channels,
24-
int64_t height,
25-
int64_t width,
26-
int64_t sampling_ratio,
27-
bool aligned);
28-
297
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
308
const at::Tensor& input,
319
const at::Tensor& rois,

torchvision/csrc/ROIAlign.h renamed to torchvision/csrc/roi_align.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
1-
#pragma once
1+
#include "roi_align.h"
2+
#include <torch/extension.h>
23

3-
#include "cpu/vision_cpu.h"
4-
5-
#ifdef WITH_CUDA
6-
#include "autocast.h"
7-
#include "cuda/vision_cuda.h"
8-
#endif
9-
#ifdef WITH_HIP
10-
#include "autocast.h"
11-
#include "hip/vision_cuda.h"
4+
#if defined(WITH_CUDA) || defined(WITH_HIP)
5+
#include <ATen/autocast_mode.h>
126
#endif
137

14-
// TODO: put this stuff in torchvision namespace
15-
16-
// roi_align dispatch nexus
178
at::Tensor roi_align(
189
const at::Tensor& input, // Input feature map.
1910
const at::Tensor& rois, // List of ROIs to pool over.
@@ -39,7 +30,7 @@ at::Tensor roi_align(
3930
}
4031

4132
#if defined(WITH_CUDA) || defined(WITH_HIP)
42-
at::Tensor ROIAlign_autocast(
33+
at::Tensor roi_align_autocast(
4334
const at::Tensor& input,
4435
const at::Tensor& rois,
4536
double spatial_scale,
@@ -90,6 +81,8 @@ at::Tensor _roi_align_backward(
9081
aligned);
9182
}
9283

84+
namespace {
85+
9386
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
9487
public:
9588
static torch::autograd::variable_list forward(
@@ -189,7 +182,9 @@ class ROIAlignBackwardFunction
189182
}
190183
};
191184

192-
at::Tensor ROIAlign_autograd(
185+
} // namespace
186+
187+
at::Tensor roi_align_autograd(
193188
const at::Tensor& input,
194189
const at::Tensor& rois,
195190
double spatial_scale,
@@ -207,7 +202,7 @@ at::Tensor ROIAlign_autograd(
207202
aligned)[0];
208203
}
209204

210-
at::Tensor ROIAlign_backward_autograd(
205+
at::Tensor roi_align_backward_autograd(
211206
const at::Tensor& grad,
212207
const at::Tensor& rois,
213208
double spatial_scale,

0 commit comments

Comments
 (0)