diff --git a/test/tracing/frcnn/test_frcnn_tracing.cpp b/test/tracing/frcnn/test_frcnn_tracing.cpp index a23b95cf88f..95b3a1b5726 100644 --- a/test/tracing/frcnn/test_frcnn_tracing.cpp +++ b/test/tracing/frcnn/test_frcnn_tracing.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/roi_align_kernel.cpp similarity index 96% rename from torchvision/csrc/cpu/ROIAlign_cpu.cpp rename to torchvision/csrc/cpu/roi_align_kernel.cpp index 10ebd8158cc..01d2bca25a3 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/roi_align_kernel.cpp @@ -1,5 +1,6 @@ -#include -#include "vision_cpu.h" +#include "roi_align_kernel.h" + +namespace { // implementation taken from Caffe2 template @@ -111,7 +112,7 @@ void pre_calc_for_bilinear_interpolate( } template -void ROIAlignForward( +void roi_align_forward_kernel_impl( int nthreads, const T* input, const T& spatial_scale, @@ -277,7 +278,7 @@ inline void add(T* address, const T& val) { } template -void ROIAlignBackward( +void roi_align_backward_kernel_impl( int nthreads, const T* grad_output, const T& spatial_scale, @@ -382,9 +383,11 @@ void ROIAlignBackward( } // ix } // iy } // for -} // ROIAlignBackward +} + +} // namespace -at::Tensor ROIAlign_forward_cpu( +at::Tensor roi_align_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -398,7 +401,7 @@ at::Tensor ROIAlign_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIAlign_forward_cpu"; + at::CheckedFrom c = "roi_align_forward_cpu"; at::checkAllSameType(c, {input_t, rois_t}); auto num_rois = rois.size(0); @@ -416,8 +419,8 @@ at::Tensor ROIAlign_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ROIAlign_forward", [&] { - ROIAlignForward( + input.scalar_type(), "roi_align_forward", [&] { + 
roi_align_forward_kernel_impl( output_size, input_.data_ptr(), spatial_scale, @@ -434,7 +437,7 @@ at::Tensor ROIAlign_forward_cpu( return output; } -at::Tensor ROIAlign_backward_cpu( +at::Tensor roi_align_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, @@ -451,7 +454,7 @@ at::Tensor ROIAlign_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIAlign_backward_cpu"; + at::CheckedFrom c = "roi_align_backward_cpu"; at::checkAllSameType(c, {grad_t, rois_t}); at::Tensor grad_input = @@ -470,8 +473,8 @@ at::Tensor ROIAlign_backward_cpu( auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ROIAlign_forward", [&] { - ROIAlignBackward( + grad.scalar_type(), "roi_align_backward", [&] { + roi_align_backward_kernel_impl( grad.numel(), grad.data_ptr(), spatial_scale, diff --git a/torchvision/csrc/cpu/roi_align_kernel.h b/torchvision/csrc/cpu/roi_align_kernel.h new file mode 100644 index 00000000000..79fd46bd44e --- /dev/null +++ b/torchvision/csrc/cpu/roi_align_kernel.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API at::Tensor roi_align_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +VISION_API at::Tensor roi_align_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index baf64f89689..a2647c57aa5 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,28 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API at::Tensor ROIAlign_forward_cpu( - 
const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor ROIAlign_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - VISION_API std::tuple ROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/roi_align_kernel.cu similarity index 94% rename from torchvision/csrc/cuda/ROIAlign_cuda.cu rename to torchvision/csrc/cuda/roi_align_kernel.cu index b773121d2b9..7f763170a9e 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/roi_align_kernel.cu @@ -1,10 +1,11 @@ -#include -#include #include #include #include #include "cuda_helpers.h" +#include "roi_align_kernel.h" + +namespace { template __device__ T bilinear_interpolate( @@ -61,7 +62,7 @@ __device__ T bilinear_interpolate( } template -__global__ void RoIAlignForward( +__global__ void roi_align_forward_kernel_impl( int nthreads, const T* input, const T spatial_scale, @@ -197,7 +198,7 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void RoIAlignBackward( +__global__ void roi_align_backward_kernel_impl( int nthreads, const T* grad_output, const T spatial_scale, @@ -308,9 +309,11 @@ __global__ void RoIAlignBackward( } // ix } // iy } // CUDA_1D_KERNEL_LOOP -} // RoIAlignBackward +} + +} // namespace -at::Tensor ROIAlign_forward_cuda( +at::Tensor roi_align_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -325,7 +328,7 @@ at::Tensor ROIAlign_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIAlign_forward_cuda"; + at::CheckedFrom c = 
"roi_align_forward_cuda"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -354,8 +357,8 @@ at::Tensor ROIAlign_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign_forward", [&] { - RoIAlignForward<<>>( + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "roi_align_forward", [&] { + roi_align_forward_kernel_impl<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -373,7 +376,7 @@ at::Tensor ROIAlign_forward_cuda( return output; } -at::Tensor ROIAlign_backward_cuda( +at::Tensor roi_align_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, @@ -390,7 +393,7 @@ at::Tensor ROIAlign_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIAlign_backward_cuda"; + at::CheckedFrom c = "roi_align_backward_cuda"; at::checkAllSameGPU(c, {grad_t, rois_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -418,8 +421,8 @@ at::Tensor ROIAlign_backward_cuda( int w_stride = grad.stride(3); auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign_backward", [&] { - RoIAlignBackward<<>>( + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "roi_align_backward", [&] { + roi_align_backward_kernel_impl<<>>( grad.numel(), grad.data_ptr(), spatial_scale, diff --git a/torchvision/csrc/cuda/roi_align_kernel.h b/torchvision/csrc/cuda/roi_align_kernel.h new file mode 100644 index 00000000000..46054f04f38 --- /dev/null +++ b/torchvision/csrc/cuda/roi_align_kernel.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API at::Tensor roi_align_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +VISION_API at::Tensor roi_align_backward_cuda( + const at::Tensor& grad, + const 
at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 8d411b9c67e..1ec187c3348 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,28 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API at::Tensor ROIAlign_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor ROIAlign_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - VISION_API std::tuple ROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ROIAlign.h b/torchvision/csrc/roi_align.cpp similarity index 94% rename from torchvision/csrc/ROIAlign.h rename to torchvision/csrc/roi_align.cpp index 708981f061e..30eda8612d2 100644 --- a/torchvision/csrc/ROIAlign.h +++ b/torchvision/csrc/roi_align.cpp @@ -1,19 +1,10 @@ -#pragma once +#include "roi_align.h" +#include -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include #endif -// TODO: put this stuff in torchvision namespace - -// roi_align dispatch nexus at::Tensor roi_align( const at::Tensor& input, // Input feature map. const at::Tensor& rois, // List of ROIs to pool over. 
@@ -39,7 +30,7 @@ at::Tensor roi_align( } #if defined(WITH_CUDA) || defined(WITH_HIP) -at::Tensor ROIAlign_autocast( +at::Tensor roi_align_autocast( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -90,6 +81,8 @@ at::Tensor _roi_align_backward( aligned); } +namespace { + class ROIAlignFunction : public torch::autograd::Function { public: static torch::autograd::variable_list forward( @@ -189,7 +182,9 @@ class ROIAlignBackwardFunction } }; -at::Tensor ROIAlign_autograd( +} // namespace + +at::Tensor roi_align_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -207,7 +202,7 @@ at::Tensor ROIAlign_autograd( aligned)[0]; } -at::Tensor ROIAlign_backward_autograd( +at::Tensor roi_align_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, diff --git a/torchvision/csrc/roi_align.h b/torchvision/csrc/roi_align.h new file mode 100644 index 00000000000..d9bae4ba2a1 --- /dev/null +++ b/torchvision/csrc/roi_align.h @@ -0,0 +1,69 @@ +#pragma once + +#include "cpu/roi_align_kernel.h" + +#ifdef WITH_CUDA +#include "cuda/roi_align_kernel.h" +#endif +#ifdef WITH_HIP +#include "hip/roi_align_kernel.h" +#endif + +// C++ Forward +at::Tensor roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +// Autocast Forward +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor roi_align_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); +#endif + +// C++ Backward +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); + +// 
Autograd Forward and Backward +at::Tensor roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +at::Tensor roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 6f540c6832e..c41663f0736 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -8,13 +8,13 @@ #include #endif -#include "ROIAlign.h" #include "ROIPool.h" #include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" #include "ps_roi_align.h" #include "ps_roi_pool.h" +#include "roi_align.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension @@ -69,8 +69,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu); m.impl("ps_roi_pool", ps_roi_pool_forward_cpu); m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu); - m.impl("roi_align", ROIAlign_forward_cpu); - m.impl("_roi_align_backward", ROIAlign_backward_cpu); + m.impl("roi_align", roi_align_forward_cpu); + m.impl("_roi_align_backward", roi_align_backward_cpu); m.impl("roi_pool", ROIPool_forward_cpu); m.impl("_roi_pool_backward", ROIPool_backward_cpu); } @@ -85,8 +85,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda); m.impl("ps_roi_pool", ps_roi_pool_forward_cuda); m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda); - m.impl("roi_align", ROIAlign_forward_cuda); - m.impl("_roi_align_backward", ROIAlign_backward_cuda); + m.impl("roi_align", roi_align_forward_cuda); + m.impl("_roi_align_backward", 
roi_align_backward_cuda); m.impl("roi_pool", ROIPool_forward_cuda); m.impl("_roi_pool_backward", ROIPool_backward_cuda); } @@ -99,7 +99,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("nms", nms_autocast); m.impl("ps_roi_align", ps_roi_align_autocast); m.impl("ps_roi_pool", ps_roi_pool_autocast); - m.impl("roi_align", ROIAlign_autocast); + m.impl("roi_align", roi_align_autocast); m.impl("roi_pool", ROIPool_autocast); } #endif @@ -111,8 +111,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); m.impl("ps_roi_pool", ps_roi_pool_autograd); m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); - m.impl("roi_align", ROIAlign_autograd); - m.impl("_roi_align_backward", ROIAlign_backward_autograd); + m.impl("roi_align", roi_align_autograd); + m.impl("_roi_align_backward", roi_align_backward_autograd); m.impl("roi_pool", ROIPool_autograd); m.impl("_roi_pool_backward", ROIPool_backward_autograd); }