diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/ROIPool.h
index 38748c7f57b..00aa1b63d3c 100644
--- a/torchvision/csrc/ROIPool.h
+++ b/torchvision/csrc/ROIPool.h
@@ -2,60 +2,60 @@
 #include "cpu/vision_cpu.h"

-#ifdef WITH_CUDA
-#include "cuda/vision_cuda.h"
-#endif
-#ifdef WITH_HIP
-#include "hip/vision_cuda.h"
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+#include "autocast.h"
 #endif

-std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
+// TODO: put this stuff in torchvision namespace
+
+std::tuple<at::Tensor, at::Tensor> roi_pool(
     const at::Tensor& input,
     const at::Tensor& rois,
     const double spatial_scale,
     const int64_t pooled_height,
     const int64_t pooled_width) {
-  if (input.is_cuda()) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::roi_pool", "")
+                       .typed<decltype(roi_pool)>();
+  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
 #if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIPool_forward_cuda(
-        input, rois, spatial_scale, pooled_height, pooled_width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return ROIPool_forward_cpu(
-      input, rois, spatial_scale, pooled_height, pooled_width);
+std::tuple<at::Tensor, at::Tensor> ROIPool_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  auto result = roi_pool(
+      at::autocast::cached_cast(at::kFloat, input),
+      at::autocast::cached_cast(at::kFloat, rois),
+      spatial_scale,
+      pooled_height,
+      pooled_width);
+
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
 }
+#endif

-at::Tensor ROIPool_backward(
+at::Tensor _roi_pool_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& argmax,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIPool_backward_cuda(
-        grad,
-        rois,
-        argmax,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return ROIPool_backward_cpu(
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::_roi_pool_backward", "")
+                       .typed<decltype(_roi_pool_backward)>();
+  return op.call(
       grad,
       rois,
       argmax,
@@ -81,8 +81,9 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
     ctx->saved_data["pooled_height"] = pooled_height;
     ctx->saved_data["pooled_width"] = pooled_width;
     ctx->saved_data["input_shape"] = input.sizes();
-    auto result = ROIPool_forward(
-        input, rois, spatial_scale, pooled_height, pooled_width);
+    at::AutoNonVariableTypeMode g;
+    auto result =
+        roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
     auto output = std::get<0>(result);
     auto argmax = std::get<1>(result);
     ctx->save_for_backward({rois, argmax});
@@ -98,7 +99,7 @@
     auto rois = saved[0];
     auto argmax = saved[1];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = ROIPool_backward(
+    auto grad_in = _roi_pool_backward(
         grad_output[0],
         rois,
         argmax,
@@ -117,7 +118,46 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
   }
 };

-std::tuple<at::Tensor, at::Tensor> roi_pool(
+// TODO: There should be an easier way to do this
+class ROIPoolBackwardFunction
+    : public torch::autograd::Function<ROIPoolBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable rois,
+      torch::autograd::Variable argmax,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t batch_size,
+      const int64_t channels,
+      const int64_t height,
+      const int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = _roi_pool_backward(
+        grad,
+        rois,
+        argmax,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_pool not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
     const double spatial_scale,
@@ -125,5 +165,30 @@ std::tuple<at::Tensor, at::Tensor> roi_pool(
     const int64_t pooled_width) {
   auto result = ROIPoolFunction::apply(
       input, rois, spatial_scale, pooled_height, pooled_width);
-  return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
+
+  return std::make_tuple(result[0], result[1]);
 }
+
+at::Tensor ROIPool_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width) {
+  return ROIPoolBackwardFunction::apply(
+      grad,
+      rois,
+      argmax,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width)[0];
+}
\ No newline at end of file
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
index b13f1de6646..af83257a3cc 100644
--- a/torchvision/csrc/cpu/ROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -123,9 +123,9 @@ void RoIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
   TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");

@@ -172,13 +172,13 @@ at::Tensor ROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& argmax,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width) {
   // Check if input tensors are CPU tensors
   TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h
index 69b1bbf555d..cae700f8521 100644
--- a/torchvision/csrc/cpu/vision_cpu.h
+++ b/torchvision/csrc/cpu/vision_cpu.h
@@ -5,21 +5,21 @@
 VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width);

 VISION_API at::Tensor ROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& argmax,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width);

 VISION_API at::Tensor ROIAlign_forward_cpu(
     const at::Tensor& input,
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index a35dabbeb39..d62fcabed21 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -118,9 +118,9 @@ __global__ void RoIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
   TORCH_CHECK(
@@ -182,13 +182,13 @@ at::Tensor ROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& argmax,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width) {
   // Check if input tensors are CUDA tensors
   TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h
index 2481cfc63c2..af82e40d2a4 100644
--- a/torchvision/csrc/cuda/vision_cuda.h
+++ b/torchvision/csrc/cuda/vision_cuda.h
@@ -27,21 +27,21 @@ VISION_API at::Tensor ROIAlign_backward_cuda(
 VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width);

 VISION_API at::Tensor ROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& argmax,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width);

 VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
     const at::Tensor& input,
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index f56a671d6e5..eea1cf2ec9c 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -50,7 +50,10 @@ TORCH_LIBRARY(torchvision, m) {
       "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
       "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
-  m.def("roi_pool", &roi_pool);
+  m.def(
+      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
+  m.def(
+      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
   m.def("_new_empty_tensor_op", &new_empty_tensor);
   m.def("ps_roi_align", &ps_roi_align);
   m.def("ps_roi_pool", &ps_roi_pool);
@@ -64,6 +67,8 @@ TORCH_LIBRARY(torchvision, m) {
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl("roi_align", ROIAlign_forward_cpu);
   m.impl("_roi_align_backward", ROIAlign_backward_cpu);
+  m.impl("roi_pool", ROIPool_forward_cpu);
+  m.impl("_roi_pool_backward", ROIPool_backward_cpu);
   m.impl("deform_conv2d", DeformConv2d_forward_cpu);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
   m.impl("nms", nms_cpu);
@@ -74,6 +79,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl("roi_align", ROIAlign_forward_cuda);
   m.impl("_roi_align_backward", ROIAlign_backward_cuda);
+  m.impl("roi_pool", ROIPool_forward_cuda);
+  m.impl("_roi_pool_backward", ROIPool_backward_cuda);
   m.impl("deform_conv2d", DeformConv2d_forward_cuda);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
   m.impl("nms", nms_cuda);
@@ -84,6 +91,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
 #if defined(WITH_CUDA) || defined(WITH_HIP)
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl("roi_align", ROIAlign_autocast);
+  m.impl("roi_pool", ROIPool_autocast);
   m.impl("deform_conv2d", DeformConv2d_autocast);
   m.impl("nms", nms_autocast);
 }
@@ -92,6 +100,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
 TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
   m.impl("roi_align", ROIAlign_autograd);
   m.impl("_roi_align_backward", ROIAlign_backward_autograd);
+  m.impl("roi_pool", ROIPool_autograd);
+  m.impl("_roi_pool_backward", ROIPool_backward_autograd);
   m.impl("deform_conv2d", DeformConv2d_autograd);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
 }
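Not part of the patch: a minimal C++ sketch of how the newly registered `torchvision::roi_pool` operator can be exercised through the dispatcher, mirroring the `findSchemaOrThrow(...).typed(...)` pattern the patch adds to ROIPool.h. It assumes the program links against the built torchvision library so the `TORCH_LIBRARY`/`TORCH_LIBRARY_IMPL` blocks above have already registered the op; the standalone `main` and the example tensors are illustrative only.

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // One 3-channel 32x32 image and a single ROI in (batch_index, x1, y1, x2, y2) format.
  auto input = torch::rand({1, 3, 32, 32});
  auto rois = torch::tensor({0.f, 0.f, 0.f, 16.f, 16.f}).unsqueeze(0);

  // Look up the operator defined by the schema in vision.cpp:
  // roi_pool(Tensor, Tensor, float, int, int) -> (Tensor, Tensor)
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::roi_pool", "")
          .typed<std::tuple<at::Tensor, at::Tensor>(
              const at::Tensor&, const at::Tensor&, double, int64_t, int64_t)>();

  // Dispatch lands on ROIPool_forward_cpu here (or ROIPool_forward_cuda for
  // CUDA inputs) and returns the pooled output plus the argmax indices that
  // _roi_pool_backward consumes.
  auto result = op.call(
      input, rois, /*spatial_scale=*/1.0, /*pooled_height=*/7, /*pooled_width=*/7);

  std::cout << std::get<0>(result).sizes() << " "
            << std::get<1>(result).sizes() << std::endl;  // [1, 3, 7, 7] twice
  return 0;
}
```

The `roi_pool` wrapper in ROIPool.h performs exactly this lookup, so existing callers that go through it (or through `torch.ops.torchvision.roi_pool` from Python) keep working while autograd and autocast are now handled by the dispatcher keys registered in vision.cpp.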