diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h
index 92f4390a0f9..1e5dd17aabc 100644
--- a/torchvision/csrc/PSROIAlign.h
+++ b/torchvision/csrc/PSROIAlign.h
@@ -223,4 +223,4 @@ at::Tensor PSROIAlign_backward_autograd(
       channels,
       height,
       width)[0];
-}
\ No newline at end of file
+}
diff --git a/torchvision/csrc/PSROIPool.h b/torchvision/csrc/PSROIPool.h
index 2adb3aa4196..c3ced9e7842 100644
--- a/torchvision/csrc/PSROIPool.h
+++ b/torchvision/csrc/PSROIPool.h
@@ -3,62 +3,68 @@
 #include "cpu/vision_cpu.h"
 
 #ifdef WITH_CUDA
+#include "autocast.h"
 #include "cuda/vision_cuda.h"
 #endif
 #ifdef WITH_HIP
+#include "autocast.h"
 #include "hip/vision_cuda.h"
 #endif
 
-std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
+// TODO: put this stuff in torchvision namespace
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
-  if (input.is_cuda()) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::ps_roi_pool", "")
+                       .typed<decltype(ps_roi_pool)>();
+  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
 #if defined(WITH_CUDA) || defined(WITH_HIP)
-    return PSROIPool_forward_cuda(
-        input, rois, spatial_scale, pooled_height, pooled_width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return PSROIPool_forward_cpu(
-      input, rois, spatial_scale, pooled_height, pooled_width);
+std::tuple<at::Tensor, at::Tensor> PSROIPool_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  auto result = ps_roi_pool(
+      at::autocast::cached_cast(at::kFloat, input),
+      at::autocast::cached_cast(at::kFloat, rois),
+      spatial_scale,
+      pooled_height,
+      pooled_width);
+
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
 }
+#endif
 
-at::Tensor PSROIPool_backward(
+at::Tensor _ps_roi_pool_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return PSROIPool_backward_cuda(
-        grad,
-        rois,
-        mapping_channel,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return PSROIPool_backward_cpu(
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "")
+          .typed<decltype(_ps_roi_pool_backward)>();
+  return op.call(
       grad,
       rois,
-      mapping_channel,
+      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
@@ -72,33 +78,36 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      torch::autograd::Variable input,
-      torch::autograd::Variable rois,
-      const double spatial_scale,
-      const int64_t pooled_height,
-      const int64_t pooled_width) {
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width) {
     ctx->saved_data["spatial_scale"] = spatial_scale;
     ctx->saved_data["pooled_height"] = pooled_height;
     ctx->saved_data["pooled_width"] = pooled_width;
     ctx->saved_data["input_shape"] = input.sizes();
-    auto result = PSROIPool_forward(
-        input, rois, spatial_scale, pooled_height, pooled_width);
+    at::AutoNonVariableTypeMode g;
+    auto result =
+        ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
+
     auto output = std::get<0>(result);
     auto channel_mapping = std::get<1>(result);
     ctx->save_for_backward({rois, channel_mapping});
     ctx->mark_non_differentiable({channel_mapping});
+
     return {output, channel_mapping};
   }
 
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_output) {
+      const torch::autograd::variable_list& grad_output) {
     // Use data saved in forward
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto channel_mapping = saved[1];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = PSROIPool_backward(
+    auto grad_in = _ps_roi_pool_backward(
         grad_output[0],
         rois,
         channel_mapping,
@@ -109,6 +118,7 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
         input_shape[1],
         input_shape[2],
         input_shape[3]);
+
     return {grad_in,
             torch::autograd::Variable(),
             torch::autograd::Variable(),
@@ -117,13 +127,77 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
   }
 };
 
-std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
+// TODO: There should be an easier way to do this
+class PSROIPoolBackwardFunction
+    : public torch::autograd::Function<PSROIPoolBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      const torch::autograd::Variable& channel_mapping,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = _ps_roi_pool_backward(
+        grad,
+        rois,
+        channel_mapping,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> PSROIPool_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   auto result = PSROIPoolFunction::apply(
       input, rois, spatial_scale, pooled_height, pooled_width);
-  return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
+
+  return std::make_tuple(result[0], result[1]);
+}
+
+at::Tensor PSROIPool_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  return PSROIPoolBackwardFunction::apply(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width)[0];
 }
diff --git a/torchvision/csrc/cpu/PSROIPool_cpu.cpp b/torchvision/csrc/cpu/PSROIPool_cpu.cpp
index 357e08cd3fa..c6e0a64cac3 100644
--- a/torchvision/csrc/cpu/PSROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/PSROIPool_cpu.cpp
@@ -12,14 +12,14 @@ template <typename T>
 void PSROIPoolForward(
     const T* input,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
     const T* rois,
-    const int channels_out,
-    const int num_rois,
+    int channels_out,
+    int num_rois,
     T* output,
     int* channel_mapping) {
   for (int n = 0; n < num_rois; ++n) {
@@ -82,14 +82,14 @@ template <typename T>
 void PSROIPoolBackward(
     const T* grad_output,
     const int* channel_mapping,
-    const int num_rois,
+    int num_rois,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int channels_out,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    int channels_out,
     T* grad_input,
     const T* rois) {
   for (int n = 0; n < num_rois; ++n) {
@@ -146,9 +146,9 @@ void PSROIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   // Check if input tensors are CPU tensors
   TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
@@ -204,13 +204,13 @@ at::Tensor PSROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& channel_mapping,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
   // Check if input tensors are CPU tensors
   TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h
index 0d09bf72715..14def9d324f 100644
--- a/torchvision/csrc/cpu/vision_cpu.h
+++ b/torchvision/csrc/cpu/vision_cpu.h
@@ -2,17 +2,73 @@
 #include <torch/extension.h>
 #include "../macros.h"
 
-VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
+VISION_API at::Tensor DeformConv2d_forward_cpu(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_cpu(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API at::Tensor nms_cpu(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio);
+
+VISION_API at::Tensor PSROIAlign_backward_cpu(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width);
 
-VISION_API at::Tensor ROIPool_backward_cpu(
+VISION_API at::Tensor PSROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& argmax,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
@@ -43,77 +99,21 @@ VISION_API at::Tensor ROIAlign_backward_cpu(
     int64_t sampling_ratio,
     bool aligned);
 
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
-
-VISION_API at::Tensor PSROIPool_backward_cpu(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
-
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
+VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t sampling_ratio);
+    int64_t pooled_width);
 
-VISION_API at::Tensor PSROIAlign_backward_cpu(
+VISION_API at::Tensor ROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
+    const at::Tensor& argmax,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
     int64_t width);
-
-VISION_API at::Tensor nms_cpu(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold);
-
-VISION_API at::Tensor DeformConv2d_forward_cpu(
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
-
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cpu(
-    const at::Tensor& grad_out,
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
diff --git a/torchvision/csrc/cuda/PSROIPool_cuda.cu b/torchvision/csrc/cuda/PSROIPool_cuda.cu
index d880020d8bb..ab6a50b009c 100644
--- a/torchvision/csrc/cuda/PSROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/PSROIPool_cuda.cu
@@ -8,16 +8,16 @@ template <typename T>
 __global__ void PSROIPoolForward(
-    const int nthreads,
+    int nthreads,
     const T* input,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
     const T* rois,
-    const int channels_out,
+    int channels_out,
     T* output,
     int* channel_mapping) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
@@ -74,17 +74,17 @@ __global__ void PSROIPoolForward(
 
 template <typename T>
 __global__ void PSROIPoolBackward(
-    const int nthreads,
+    int nthreads,
     const T* grad_output,
     const int* channel_mapping,
-    const int num_rois,
+    int num_rois,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int channels_out,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    int channels_out,
     T* grad_input,
     const T* rois) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
@@ -135,9 +135,9 @@ __global__ void PSROIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   // Check if input tensors are CUDA tensors
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
@@ -206,13 +206,13 @@ at::Tensor PSROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& channel_mapping,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
   // Check if input tensors are CUDA tensors
   TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h
index bc1a26731e8..731d119cf75 100644
--- a/torchvision/csrc/cuda/vision_cuda.h
+++ b/torchvision/csrc/cuda/vision_cuda.h
@@ -2,118 +2,118 @@
 #include <torch/extension.h>
 #include "../macros.h"
 
-VISION_API at::Tensor ROIAlign_forward_cuda(
+VISION_API at::Tensor DeformConv2d_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_cuda(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API at::Tensor nms_cuda(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
-    bool aligned);
+    int64_t sampling_ratio);
 
-VISION_API at::Tensor ROIAlign_backward_cuda(
+VISION_API at::Tensor PSROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
+    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned);
-
-VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width);
-
-VISION_API at::Tensor ROIPool_backward_cuda(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& argmax,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width,
-    const int64_t batch_size,
-    const int64_t channels,
-    const int64_t height,
-    const int64_t width);
+    int64_t width);
 
 VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width);
 
 VISION_API at::Tensor PSROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
 
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
+VISION_API at::Tensor ROIAlign_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio);
+    int64_t sampling_ratio,
+    bool aligned);
 
-VISION_API at::Tensor PSROIAlign_backward_cuda(
+VISION_API at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width);
-
-VISION_API at::Tensor nms_cuda(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold);
+    int64_t width,
+    int64_t sampling_ratio,
+    bool aligned);
 
-VISION_API at::Tensor DeformConv2d_forward_cuda(
+VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
     const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width);
 
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cuda(
-    const at::Tensor& grad_out,
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
+VISION_API at::Tensor ROIPool_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width);
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index c764d63cf8c..bd9a770473d 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -45,73 +45,83 @@ int64_t cuda_version() noexcept {
 } // namespace vision
 
 TORCH_LIBRARY(torchvision, m) {
+  m.def(
+      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+  m.def(
+      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
   m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
   m.def(
-      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
+      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
   m.def(
-      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
+      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
   m.def(
-      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
+      "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
   m.def(
-      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
-  m.def("_new_empty_tensor_op", &new_empty_tensor);
+      "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
   m.def(
-      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
+      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
-      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
-  m.def("ps_roi_pool", &ps_roi_pool);
+      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
-      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
   m.def(
-      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
+      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
   m.def("_cuda_version", &vision::cuda_version);
+  m.def("_new_empty_tensor_op", &new_empty_tensor);
 }
 
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
-  m.impl("roi_align", ROIAlign_forward_cpu);
-  m.impl("_roi_align_backward", ROIAlign_backward_cpu);
-  m.impl("roi_pool", ROIPool_forward_cpu);
-  m.impl("_roi_pool_backward", ROIPool_backward_cpu);
   m.impl("deform_conv2d", DeformConv2d_forward_cpu);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
   m.impl("nms", nms_cpu);
m.impl("ps_roi_align", PSROIAlign_forward_cpu); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); + m.impl("ps_roi_pool", PSROIPool_forward_cpu); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu); + m.impl("roi_align", ROIAlign_forward_cpu); + m.impl("_roi_align_backward", ROIAlign_backward_cpu); + m.impl("roi_pool", ROIPool_forward_cpu); + m.impl("_roi_pool_backward", ROIPool_backward_cpu); } // TODO: Place this in a hypothetical separate torchvision_cuda library #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl("roi_align", ROIAlign_forward_cuda); - m.impl("_roi_align_backward", ROIAlign_backward_cuda); - m.impl("roi_pool", ROIPool_forward_cuda); - m.impl("_roi_pool_backward", ROIPool_backward_cuda); m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("nms", nms_cuda); m.impl("ps_roi_align", PSROIAlign_forward_cuda); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); + m.impl("ps_roi_pool", PSROIPool_forward_cuda); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda); + m.impl("roi_align", ROIAlign_forward_cuda); + m.impl("_roi_align_backward", ROIAlign_backward_cuda); + m.impl("roi_pool", ROIPool_forward_cuda); + m.impl("_roi_pool_backward", ROIPool_backward_cuda); } #endif // Autocast only needs to wrap forward pass ops. #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl("roi_align", ROIAlign_autocast); - m.impl("roi_pool", ROIPool_autocast); m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("nms", nms_autocast); m.impl("ps_roi_align", PSROIAlign_autocast); + m.impl("ps_roi_pool", PSROIPool_autocast); + m.impl("roi_align", ROIAlign_autocast); + m.impl("roi_pool", ROIPool_autocast); } #endif TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl("roi_align", ROIAlign_autograd); - m.impl("_roi_align_backward", ROIAlign_backward_autograd); - m.impl("roi_pool", ROIPool_autograd); - m.impl("_roi_pool_backward", ROIPool_backward_autograd); m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); m.impl("ps_roi_align", PSROIAlign_autograd); m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); + m.impl("ps_roi_pool", PSROIPool_autograd); + m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd); + m.impl("roi_align", ROIAlign_autograd); + m.impl("_roi_align_backward", ROIAlign_backward_autograd); + m.impl("roi_pool", ROIPool_autograd); + m.impl("_roi_pool_backward", ROIPool_backward_autograd); }