diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/DeformConv.h index 3d5636eb2c6..2712cdc9492 100644 --- a/torchvision/csrc/DeformConv.h +++ b/torchvision/csrc/DeformConv.h @@ -1,89 +1,105 @@ #pragma once -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include "autocast.h" #endif -at::Tensor DeformConv2d_forward( +// TODO: put this stuff in torchvision namespace + +at::Tensor deform_conv2d( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - const std::pair& stride, - const std::pair& padding, - const std::pair& dilation, - const int groups, - const int offset_groups) { - if (input.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return DeformConv2d_forward_cuda( - input.contiguous(), - weight.contiguous(), - offset.contiguous(), - bias.contiguous(), - stride, - padding, - dilation, - groups, - offset_groups); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); -#endif - } - return DeformConv2d_forward_cpu( - input.contiguous(), - weight.contiguous(), - offset.contiguous(), - bias.contiguous(), - stride, - padding, - dilation, + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, groups, offset_groups); } -std::tuple DeformConv2d_backward( - const at::Tensor& grad, +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor DeformConv2d_autocast( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - const std::pair& stride, - const std::pair& padding, - const std::pair& dilation, - const int groups, - const int offset_groups) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return DeformConv2d_backward_cuda( - grad.contiguous(), - input.contiguous(), - weight.contiguous(), - offset.contiguous(), - bias.contiguous(), - stride, - padding, - dilation, - groups, - offset_groups); -#else - TORCH_CHECK(false, "Not compiled with GPU support"); + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + return deform_conv2d( + at::autocast::cached_cast(at::kFloat, input), + at::autocast::cached_cast(at::kFloat, weight), + at::autocast::cached_cast(at::kFloat, offset), + at::autocast::cached_cast(at::kFloat, bias), + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups) + .to(input.scalar_type()); +} #endif - } - return DeformConv2d_backward_cpu( - grad.contiguous(), - input.contiguous(), - weight.contiguous(), - offset.contiguous(), - bias.contiguous(), - stride, - padding, - dilation, + +std::tuple +_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, groups, offset_groups); } @@ -105,14 +121,18 @@ class DeformConv2dFunction int64_t dilation_w, int64_t groups, int64_t offset_groups) { - auto output = DeformConv2d_forward( + at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary + auto output = deform_conv2d( input, weight, offset, bias, - {stride_h, stride_w}, - {pad_h, pad_w}, - {dilation_h, dilation_w}, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, groups, offset_groups); @@ -149,15 +169,18 @@ class DeformConv2dFunction auto groups = ctx->saved_data["groups"].toInt(); auto offset_groups = ctx->saved_data["offset_groups"].toInt(); - auto grads = DeformConv2d_backward( + auto grads = _deform_conv2d_backward( grad_output[0], input, weight, offset, bias, - {stride_h, stride_w}, - {pad_h, pad_w}, - {dilation_h, dilation_w}, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, groups, offset_groups); auto grad_input = std::get<0>(grads); @@ -182,20 +205,75 @@ class DeformConv2dFunction } }; -at::Tensor deform_conv2d( +// TODO: There should be an easier way to do this +class DeformConv2dBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + torch::autograd::Variable grad, + torch::autograd::Variable input, + torch::autograd::Variable weight, + torch::autograd::Variable offset, + torch::autograd::Variable bias, + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + at::AutoNonVariableTypeMode g; + auto result = _deform_conv2d_backward( + grad, + input, + weight, + offset, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups); + + auto grad_input = std::get<0>(result); + auto grad_weight = std::get<1>(result); + auto grad_offset = std::get<2>(result); + auto grad_bias = std::get<3>(result); + + return { + grad_input, + grad_weight, + grad_offset, + grad_bias, + }; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); + } +}; + +at::Tensor DeformConv2d_autograd( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups) { - auto result = DeformConv2dFunction::apply( + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + return DeformConv2dFunction::apply( input, weight, offset, @@ -207,6 +285,37 @@ at::Tensor deform_conv2d( dilation_h, dilation_w, groups, - offset_groups); - return result[0]; + offset_groups)[0]; } + +std::tuple +DeformConv2d_backward_autograd( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, + const int64_t stride_h, + const int64_t stride_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t groups, + const int64_t offset_groups) { + auto result = DeformConv2dBackwardFunction::apply( + grad, + input, + weight, + offset, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups); + return std::make_tuple(result[0], result[1], result[2], result[3]); +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index c1580f7228a..f25ca2b8d55 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -232,22 +232,23 @@ at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, - const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int n_weight_grps, - int n_offset_grps) { - at::Tensor input = input_param; - at::Tensor offset = offset_param; - at::Tensor weight = weight_param; + const at::Tensor& bias_param, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dil_h, + int64_t dil_w, + int64_t n_weight_grps, + int64_t n_offset_grps) { + at::Tensor input = input_param.contiguous(); + at::Tensor offset = offset_param.contiguous(); + at::Tensor weight = weight_param.contiguous(); + at::Tensor bias = bias_param.contiguous(); TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4); - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(offset.is_contiguous()); - TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); int batch_sz = input.size(0); @@ -263,15 +264,6 @@ at::Tensor DeformConv2d_forward_cpu( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - int ker_h = dil_h * (weight_h - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1; int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; @@ -683,9 +675,12 @@ static std::tuple deform_conv2d_backward_input_cpu( at::Tensor weight, at::Tensor offset, at::Tensor grad_out, - std::pair stride, - std::pair pad, - std::pair dilation, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dil_h, + int dil_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { @@ -700,15 +695,6 @@ static std::tuple deform_conv2d_backward_input_cpu( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; @@ -813,9 +799,12 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( const at::Tensor& weight, at::Tensor offset, const at::Tensor& grad_out, - std::pair stride, - std::pair pad, - std::pair dilation, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dil_h, + int dil_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { @@ -830,15 +819,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - long out_h = grad_out.size(2); long out_w = grad_out.size(3); @@ -917,16 +897,25 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( std::tuple DeformConv2d_backward_cpu( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int n_weight_grps, - int n_offset_grps) { + const at::Tensor& grad_out_param, + const at::Tensor& input_param, + const at::Tensor& weight_param, + const at::Tensor& offset_param, + const at::Tensor& bias_param, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dil_h, + int64_t dil_w, + int64_t n_weight_grps, + int64_t n_offset_grps) { + at::Tensor grad_out = grad_out_param.contiguous(); + at::Tensor input = input_param.contiguous(); + at::Tensor weight = weight_param.contiguous(); + at::Tensor offset = offset_param.contiguous(); + at::Tensor bias = bias_param.contiguous(); + const int batch_sz = input.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); @@ -936,9 +925,12 @@ DeformConv2d_backward_cpu( weight, offset, grad_out, - stride, - pad, - dilation, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, n_parallel_imgs); @@ -951,9 +943,12 @@ DeformConv2d_backward_cpu( weight, offset, grad_out, - stride, - pad, - dilation, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, n_parallel_imgs); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index c2a2c36ce99..69b1bbf555d 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cpu( const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int groups, - int deformable_groups); + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t deformable_groups); VISION_API std::tuple DeformConv2d_backward_cpu( @@ -106,8 +109,11 @@ DeformConv2d_backward_cpu( const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int groups, - int deformable_groups); + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t deformable_groups); diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index 89516ae8454..7f18414e77e 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -248,22 +248,23 @@ at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, - const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int n_weight_grps, - int n_offset_grps) { - at::Tensor input = input_param; - at::Tensor weight = weight_param; - at::Tensor offset = offset_param; + const at::Tensor& bias_param, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dil_h, + int64_t dil_w, + int64_t n_weight_grps, + int64_t n_offset_grps) { + at::Tensor input = input_param.contiguous(); + at::Tensor offset = offset_param.contiguous(); + at::Tensor weight = weight_param.contiguous(); + at::Tensor bias = bias_param.contiguous(); TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4); - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(offset.is_contiguous()); - TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); at::DeviceGuard guard(input.device()); @@ -280,15 +281,6 @@ at::Tensor DeformConv2d_forward_cuda( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - int ker_h = dil_h * (weight_h - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1; int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; @@ -711,9 +703,12 @@ static std::tuple deform_conv_backward_input_cuda( at::Tensor weight, at::Tensor offset, at::Tensor grad_out, - std::pair stride, - std::pair pad, - std::pair dilation, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dil_h, + int dil_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { @@ -730,15 +725,6 @@ static std::tuple deform_conv_backward_input_cuda( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; @@ -841,9 +827,12 @@ static at::Tensor deform_conv_backward_parameters_cuda( const at::Tensor& weight, at::Tensor offset, const at::Tensor& grad_out, - std::pair stride, - std::pair pad, - std::pair dilation, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dil_h, + int dil_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { @@ -860,15 +849,6 @@ static at::Tensor deform_conv_backward_parameters_cuda( int weight_h = weight.size(2); int weight_w = weight.size(3); - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - long out_h = grad_out.size(2); long out_w = grad_out.size(3); @@ -946,16 +926,25 @@ static at::Tensor deform_conv_backward_parameters_cuda( std::tuple DeformConv2d_backward_cuda( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int n_weight_grps, - int n_offset_grps) { + const at::Tensor& grad_out_param, + const at::Tensor& input_param, + const at::Tensor& weight_param, + const at::Tensor& offset_param, + const at::Tensor& bias_param, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dil_h, + int64_t dil_w, + int64_t n_weight_grps, + int64_t n_offset_grps) { + at::Tensor grad_out = grad_out_param.contiguous(); + at::Tensor input = input_param.contiguous(); + at::Tensor weight = weight_param.contiguous(); + at::Tensor offset = offset_param.contiguous(); + at::Tensor bias = bias_param.contiguous(); + const int batch_sz = input.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); @@ -965,9 +954,12 @@ DeformConv2d_backward_cuda( weight, offset, grad_out, - stride, - pad, - dilation, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, n_parallel_imgs); @@ -980,9 +972,12 @@ DeformConv2d_backward_cuda( weight, offset, grad_out, - stride, - pad, - dilation, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, n_parallel_imgs); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2c13d0aeed3..2481cfc63c2 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cuda( const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int groups, - int deformable_groups); + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t deformable_groups); VISION_API std::tuple DeformConv2d_backward_cuda( @@ -106,8 +109,11 @@ DeformConv2d_backward_cuda( const at::Tensor& weight, const at::Tensor& offset, const at::Tensor& bias, - std::pair stride, - std::pair pad, - std::pair dilation, - int groups, - int deformable_groups); + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t deformable_groups); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 75e65d67661..f56a671d6e5 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -54,13 +54,18 @@ TORCH_LIBRARY(torchvision, m) { m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("ps_roi_align", &ps_roi_align); m.def("ps_roi_pool", &ps_roi_pool); - m.def("deform_conv2d", &deform_conv2d); + m.def( + "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); + m.def( + "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)"); m.def("_cuda_version", &vision::cuda_version); } TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("roi_align", ROIAlign_forward_cpu); m.impl("_roi_align_backward", ROIAlign_backward_cpu); + m.impl("deform_conv2d", DeformConv2d_forward_cpu); + m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); m.impl("nms", nms_cpu); } @@ -69,6 +74,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("roi_align", ROIAlign_forward_cuda); m.impl("_roi_align_backward", ROIAlign_backward_cuda); + m.impl("deform_conv2d", DeformConv2d_forward_cuda); + m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("nms", nms_cuda); } #endif @@ -77,6 +84,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("roi_align", ROIAlign_autocast); + m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("nms", nms_autocast); } #endif @@ -84,4 +92,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("roi_align", ROIAlign_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd); + m.impl("deform_conv2d", DeformConv2d_autograd); + m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); }