diff --git a/torchvision/csrc/PSROIPool.h b/torchvision/csrc/PSROIPool.h index 2adb3aa4196..f075b957263 100644 --- a/torchvision/csrc/PSROIPool.h +++ b/torchvision/csrc/PSROIPool.h @@ -2,63 +2,64 @@ #include "cpu/vision_cpu.h" -#ifdef WITH_CUDA -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include "autocast.h" #endif -std::tuple PSROIPool_forward( +// TODO: put this stuff in torchvision namespace + +std::tuple ps_roi_pool( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { - if (input.is_cuda()) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + #if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIPool_forward_cuda( - input, rois, spatial_scale, pooled_height, pooled_width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIPool_forward_cpu( - input, rois, spatial_scale, pooled_height, pooled_width); +std::tuple PSROIPool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto result = ps_roi_pool( + at::autocast::cached_cast(at::kFloat, input), + at::autocast::cached_cast(at::kFloat, rois), + spatial_scale, + pooled_height, + pooled_width); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); } +#endif -at::Tensor PSROIPool_backward( +at::Tensor _ps_roi_pool_backward( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return PSROIPool_backward_cuda( - grad, - rois, - mapping_channel, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return PSROIPool_backward_cpu( + const at::Tensor& channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( grad, rois, - mapping_channel, + channel_mapping, spatial_scale, pooled_height, pooled_width, @@ -81,8 +82,9 @@ class PSROIPoolFunction : public torch::autograd::Function { ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["input_shape"] = input.sizes(); - auto result = PSROIPool_forward( - input, rois, spatial_scale, pooled_height, pooled_width); + at::AutoNonVariableTypeMode g; + auto result = + ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width); auto output = std::get<0>(result); auto channel_mapping = std::get<1>(result); ctx->save_for_backward({rois, channel_mapping}); @@ -98,7 +100,7 @@ class PSROIPoolFunction : public torch::autograd::Function { auto rois = saved[0]; auto channel_mapping = saved[1]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = PSROIPool_backward( + auto grad_in = _ps_roi_pool_backward( grad_output[0], rois, channel_mapping, @@ -117,7 +119,46 @@ class PSROIPoolFunction : public torch::autograd::Function { } }; -std::tuple ps_roi_pool( +// TODO: There should be an easier way to do this +class PSROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + torch::autograd::Variable grad, + torch::autograd::Variable rois, + torch::autograd::Variable channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + at::AutoNonVariableTypeMode g; + auto grad_in = _ps_roi_pool_backward( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); + } +}; + +std::tuple PSROIPool_autograd( const at::Tensor& input, const at::Tensor& rois, const double spatial_scale, @@ -125,5 +166,30 @@ std::tuple ps_roi_pool( const int64_t pooled_width) { auto result = PSROIPoolFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width); - return std::tuple(result[0], result[1]); + + return std::make_tuple(result[0], result[1]); } + +at::Tensor PSROIPool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { + return PSROIPoolBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/PSROIPool_cpu.cpp b/torchvision/csrc/cpu/PSROIPool_cpu.cpp index 357e08cd3fa..eefd1cf1aba 100644 --- a/torchvision/csrc/cpu/PSROIPool_cpu.cpp +++ b/torchvision/csrc/cpu/PSROIPool_cpu.cpp @@ -146,9 +146,9 @@ void PSROIPoolBackward( std::tuple PSROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width) { // Check if input tensors are CPU tensors TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -204,13 +204,13 @@ at::Tensor PSROIPool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CPU tensors TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 69b1bbf555d..926e76c4583 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -46,21 +46,21 @@ VISION_API at::Tensor ROIAlign_backward_cpu( VISION_API std::tuple PSROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width); VISION_API at::Tensor PSROIPool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, diff --git a/torchvision/csrc/cuda/PSROIPool_cuda.cu b/torchvision/csrc/cuda/PSROIPool_cuda.cu index d880020d8bb..9ab46e93f06 100644 --- a/torchvision/csrc/cuda/PSROIPool_cuda.cu +++ b/torchvision/csrc/cuda/PSROIPool_cuda.cu @@ -135,9 +135,9 @@ __global__ void PSROIPoolBackward( std::tuple PSROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width) { // Check if input tensors are CUDA tensors TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -206,13 +206,13 @@ at::Tensor PSROIPool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width) { // Check if input tensors are CUDA tensors TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2481cfc63c2..1af9db0a5f7 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -46,21 +46,21 @@ VISION_API at::Tensor ROIPool_backward_cuda( VISION_API std::tuple PSROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width); VISION_API at::Tensor PSROIPool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& mapping_channel, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width); + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width); VISION_API std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index f56a671d6e5..b4cd3b610fd 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -53,7 +53,10 @@ TORCH_LIBRARY(torchvision, m) { m.def("roi_pool", &roi_pool); m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("ps_roi_align", &ps_roi_align); - m.def("ps_roi_pool", &ps_roi_pool); + m.def( + "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor mapping_channel, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); m.def( "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); m.def( @@ -66,6 +69,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("_roi_align_backward", ROIAlign_backward_cpu); m.impl("deform_conv2d", DeformConv2d_forward_cpu); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); + m.impl("ps_roi_pool", PSROIPool_forward_cpu); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu); m.impl("nms", nms_cpu); } @@ -76,6 +81,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("_roi_align_backward", ROIAlign_backward_cuda); m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); + m.impl("ps_roi_pool", PSROIPool_forward_cuda); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda); m.impl("nms", nms_cuda); } #endif @@ -85,6 +92,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("roi_align", ROIAlign_autocast); m.impl("deform_conv2d", DeformConv2d_autocast); + m.impl("ps_roi_pool", PSROIPool_autocast); m.impl("nms", nms_autocast); } #endif @@ -94,4 +102,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); + m.impl("ps_roi_pool", PSROIPool_autograd); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd); }