diff --git a/torchvision/csrc/autocast.h b/torchvision/csrc/autocast.h index 1f954464b72..584ef13f389 100644 --- a/torchvision/csrc/autocast.h +++ b/torchvision/csrc/autocast.h @@ -1,5 +1,7 @@ #pragma once +// TODO: Delete this file once none of the methods use it + #if defined(WITH_CUDA) || defined(WITH_HIP) #include <ATen/autocast_mode.h> #endif diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp similarity index 84% rename from torchvision/csrc/cpu/DeformConv_cpu.cpp rename to torchvision/csrc/cpu/deform_conv2d_kernel.cpp index 0212be55aa4..f593e880b3b 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp @@ -66,18 +66,14 @@ // modified from // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp -#include -#include -#include +#include "deform_conv2d_kernel.h" -#include -#include -#include +namespace { const int kMaxParallelImgs = 32; template <typename scalar_t> -static scalar_t bilinear_interpolate( +scalar_t bilinear_interpolate( const scalar_t* in, int height, int width, @@ -116,7 +112,7 @@ static scalar_t bilinear_interpolate( } template <typename scalar_t> -static void deformable_im2col_kernel( +void deformable_im2col_kernel( int n, const scalar_t* input, const scalar_t* offset, @@ -129,8 +125,8 @@ static void deformable_im2col_kernel( int pad_w, int stride_h, int stride_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int batch_sz, int n_in_channels, int n_offset_grps, @@ -180,8 +176,8 @@ static void deformable_im2col_kernel( offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t offset_w = offset_ptr [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; - const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; @@ -190,7 +188,7 @@ static void deformable_im2col_kernel( } } -static void deformable_im2col( +void deformable_im2col( const at::Tensor& input, const at::Tensor& data_offset, const at::Tensor& data_mask, @@ -203,8 +201,8 @@ static void deformable_im2col( int pad_w, int stride_h, int stride_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int out_h, int out_w, int parallel_imgs, @@ -228,8 +226,8 @@ static void deformable_im2col( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, parallel_imgs, n_in_channels, deformable_group, @@ -240,7 +238,7 @@ static void deformable_im2col( })); } -static int get_greatest_divisor_below_bound(int n, int bound) { +int get_greatest_divisor_below_bound(int n, int bound) { for (int k = bound; k > 1; --k) { if (n % k == 0) { return k; @@ -249,216 +247,8 @@ static int get_greatest_divisor_below_bound(int n, int bound) { return 1; } -at::Tensor DeformConv2d_forward_cpu( - const at::Tensor& input_param, - const at::Tensor& weight_param, - const at::Tensor& offset_param, - const at::Tensor& mask_param, - const at::Tensor& bias_param, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dil_h, - int64_t dil_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor input = input_param.contiguous(); - at::Tensor offset = offset_param.contiguous(); -
at::Tensor weight = weight_param.contiguous(); - at::Tensor mask = mask_param.contiguous(); - at::Tensor bias = bias_param.contiguous(); - - TORCH_CHECK(input.ndimension() == 4); - TORCH_CHECK(offset.ndimension() == 4); - TORCH_CHECK(!use_mask || mask.ndimension() == 4); - TORCH_CHECK(weight.ndimension() == 4); - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - - int batch_sz = input.size(0); - int n_in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - // Unpack shapes and args - int out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - int ker_h = dil_h * (weight_h - 1) + 1; - int ker_w = dil_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK( - weight_h > 0 && weight_w > 0, - "weight_h: ", - weight_h, - " weight_w: ", - weight_w); - TORCH_CHECK( - stride_h > 0 && stride_w > 0, - "stride_h: ", - stride_h, - " stride_w: ", - stride_w); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); - TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); - - TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); - TORCH_CHECK(weight.size(0) % n_weight_grps == 0); - TORCH_CHECK( - (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "offset.shape[1] is not valid: got: ", - offset.size(1), - " expected: ", - n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK( - (!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w), - "mask.shape[1] is not valid: got: ", - mask.size(1), - " expected: ", - n_offset_grps * weight_h * weight_w); - TORCH_CHECK(input.size(1) % n_offset_grps == 0); - - TORCH_CHECK( - (offset.size(0) == input.size(0)), "invalid batch size of offset"); - TORCH_CHECK( - (offset.size(2) == out_h && offset.size(3) == out_w), - "offset output dims: (", - offset.size(2), - ", ", - offset.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask"); - TORCH_CHECK( - (!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)), - "offset output dims: (", - mask.size(2), - ", ", - mask.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", - out_h, - " out_w: ", - out_w); - - auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); - if (batch_sz == 0) { - return out; - } - - // Separate batches into blocks - out = out.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - out_channels, - out_h, - out_w}); - input = input.view( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - offset = offset.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask = mask.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - at::Tensor out_buf = at::zeros( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs * out_h, - out_w}, - out.options()); - - // Separate channels into convolution groups - out_buf = out_buf.view({out_buf.size(0), - n_weight_grps, - out_buf.size(1) / n_weight_grps, - out_buf.size(2), - 
out_buf.size(3)}); - weight = weight.view({n_weight_grps, - weight.size(0) / n_weight_grps, - weight.size(1), - weight.size(2), - weight.size(3)}); - - // Sample points and perform convolution - auto columns = at::zeros( - {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, - input.options()); - for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col( - input[b], - offset[b], - mask[b], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dil_h, - dil_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - out_buf = out_buf.view({batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs, - out_h, - out_w}); - out_buf.transpose_(1, 2); - out.copy_(out_buf); - out = out.view({batch_sz, out_channels, out_h, out_w}); - - return out + bias.view({1, out_channels, 1, 1}); -} - template -static void deformable_col2im_kernel( +void deformable_col2im_kernel( int n, const scalar_t* col, const scalar_t* offset, @@ -533,7 +323,7 @@ static void deformable_col2im_kernel( } } -static void compute_grad_input( +void compute_grad_input( const at::Tensor& columns, const at::Tensor& offset, const at::Tensor& mask, @@ -587,7 +377,7 @@ static void compute_grad_input( } template -static scalar_t get_coordinate_weight( +scalar_t get_coordinate_weight( const scalar_t* im_data, int height, int width, @@ -620,7 +410,7 @@ static scalar_t get_coordinate_weight( } template -static void deformable_col2im_coord_kernel( +void deformable_col2im_coord_kernel( int n, const scalar_t* col, const scalar_t* im, @@ -732,7 +522,7 @@ static void deformable_col2im_coord_kernel( } } -static void compute_grad_offset_and_mask( +void compute_grad_offset_and_mask( const at::Tensor& columns, const at::Tensor& input, const at::Tensor& offset, @@ -790,8 +580,7 @@ static void compute_grad_offset_and_mask( })); } -static std::tuple -deform_conv2d_backward_input_cpu( +std::tuple backward_gradient_inputs( at::Tensor input, at::Tensor weight, at::Tensor offset, @@ -801,8 +590,8 @@ deform_conv2d_backward_input_cpu( int stride_w, int pad_h, int pad_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs, @@ -818,8 +607,10 @@ deform_conv2d_backward_input_cpu( int weight_h = weight.size(2); int weight_w = weight.size(3); - long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; - long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; + long out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + long out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; auto grad_input = at::zeros_like(input); auto grad_offset = at::zeros_like(offset); @@ -903,8 +694,8 @@ deform_conv2d_backward_input_cpu( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_parallel_imgs, n_offset_grps, use_mask, @@ -924,8 +715,8 @@ deform_conv2d_backward_input_cpu( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_parallel_imgs, n_offset_grps, use_mask, @@ -944,7 +735,7 @@ deform_conv2d_backward_input_cpu( 
return std::make_tuple(grad_input, grad_offset, grad_mask); } -static at::Tensor deform_conv2d_backward_parameters_cpu( +at::Tensor backward_gradient_parameters( at::Tensor input, const at::Tensor& weight, at::Tensor offset, @@ -954,8 +745,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( int stride_w, int pad_h, int pad_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs, @@ -1032,8 +823,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, out_h, out_w, n_parallel_imgs, @@ -1058,46 +849,263 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( return grad_weight; } +} // namespace + +at::Tensor deform_conv2d_forward_cpu( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor"); + + int batch_sz = input_c.size(0); + int n_in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + // Unpack shapes and args + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of 
mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "offset output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset_c = offset_c.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view({out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view({n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view({batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + std::tuple -DeformConv2d_backward_cpu( - const at::Tensor& grad_out_param, - const at::Tensor& input_param, - const at::Tensor& weight_param, - const at::Tensor& offset_param, - const at::Tensor& mask_param, - const at::Tensor& bias_param, +deform_conv2d_backward_cpu( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, - int64_t dil_h, - int64_t dil_w, + int64_t dilation_h, + int64_t dilation_w, int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) { - at::Tensor grad_out = grad_out_param.contiguous(); - at::Tensor input = input_param.contiguous(); - at::Tensor weight = weight_param.contiguous(); - at::Tensor offset = offset_param.contiguous(); - at::Tensor mask = mask_param.contiguous(); - at::Tensor bias = bias_param.contiguous(); - - const int 
batch_sz = input.size(0); + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int batch_sz = input_c.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cpu( - input, - weight, - offset, - mask, - grad_out, + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, stride_h, stride_w, pad_h, pad_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_weight_grps, n_offset_grps, n_parallel_imgs, @@ -1107,24 +1115,24 @@ DeformConv2d_backward_cpu( auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); - auto grad_weight = deform_conv2d_backward_parameters_cpu( - input, - weight, - offset, - mask, - grad_out, + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, stride_h, stride_w, pad_h, pad_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_weight_grps, n_offset_grps, n_parallel_imgs, use_mask); - auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias_c) * grad_out_c.sum({0, 2, 3}); return std::make_tuple( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); diff --git a/torchvision/csrc/cpu/deform_conv2d_kernel.h b/torchvision/csrc/cpu/deform_conv2d_kernel.h new file mode 100644 index 00000000000..2eb5ab37c6e --- /dev/null +++ b/torchvision/csrc/cpu/deform_conv2d_kernel.h @@ -0,0 +1,39 @@ +#pragma once + +#include <ATen/ATen.h> +#include "../macros.h" + +VISION_API at::Tensor deform_conv2d_forward_cpu( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask); + +VISION_API std:: + tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> + deform_conv2d_backward_cpu( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index d5bfcc0de24..6f85d9c0256 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -2,40 +2,7 @@ #include #include "../macros.h" -VISION_API at::Tensor DeformConv2d_forward_cpu( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t deformable_groups, - bool use_mask); - -VISION_API std:: - tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> - DeformConv2d_backward_cpu( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h,
- int64_t dilation_w, - int64_t groups, - int64_t deformable_groups, - bool use_mask); +// TODO: Delete this file once all the methods are gone VISION_API at::Tensor nms_cpu( const at::Tensor& dets, diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/deform_conv2d_kernel.cu similarity index 85% rename from torchvision/csrc/cuda/DeformConv_cuda.cu rename to torchvision/csrc/cuda/deform_conv2d_kernel.cu index 507532e7184..6edaa9c73af 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/deform_conv2d_kernel.cu @@ -67,16 +67,14 @@ // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp #include -#include #include #include #include #include "cuda_helpers.h" +#include "deform_conv2d_kernel.h" -#include -#include -#include +namespace { const int kMaxParallelImgs = 32; @@ -136,7 +134,7 @@ __device__ scalar_t bilinear_interpolate( } template -__global__ void deformable_im2col_gpu_kernel( +__global__ void deformable_im2col_kernel( int n, const scalar_t* input_ptr, const scalar_t* offset_ptr, @@ -149,8 +147,8 @@ __global__ void deformable_im2col_gpu_kernel( int pad_w, int stride_h, int stride_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int batch_sz, int n_in_channels, int n_offset_grps, @@ -198,8 +196,10 @@ __global__ void deformable_im2col_gpu_kernel( offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t offset_w = offset_ptr [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; - const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; @@ -208,7 +208,7 @@ __global__ void deformable_im2col_gpu_kernel( } } -static void deformable_im2col( +void deformable_im2col( const at::Tensor& input, const at::Tensor& data_offset, const at::Tensor& data_mask, @@ -221,8 +221,8 @@ static void deformable_im2col( int pad_w, int stride_h, int stride_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int out_h, int out_w, int parallel_imgs, @@ -236,7 +236,7 @@ static void deformable_im2col( AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "deformable_im2col_gpu", ([&] { - deformable_im2col_gpu_kernel<<< + deformable_im2col_kernel<<< blocks, threads>>>( num_kernels, @@ -251,8 +251,8 @@ static void deformable_im2col( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, parallel_imgs, n_in_channels, deformable_group, @@ -268,7 +268,7 @@ static void deformable_im2col( } } -static int get_greatest_divisor_below_bound(int n, int bound) { +int get_greatest_divisor_below_bound(int n, int bound) { for (int k = bound; k > 1; --k) { if (n % k == 0) { return k; @@ -277,217 +277,8 @@ static int get_greatest_divisor_below_bound(int n, int bound) { return 1; } -at::Tensor DeformConv2d_forward_cuda( - const at::Tensor& input_param, - const at::Tensor& weight_param, - const at::Tensor& offset_param, - const at::Tensor& mask_param, - const at::Tensor& bias_param, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dil_h, - int64_t dil_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor input = 
input_param.contiguous(); - at::Tensor offset = offset_param.contiguous(); - at::Tensor weight = weight_param.contiguous(); - at::Tensor mask = mask_param.contiguous(); - at::Tensor bias = bias_param.contiguous(); - - TORCH_CHECK(input.ndimension() == 4); - TORCH_CHECK(offset.ndimension() == 4); - TORCH_CHECK(!use_mask || mask.ndimension() == 4); - TORCH_CHECK(weight.ndimension() == 4); - TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); - - at::DeviceGuard guard(input.device()); - - int batch_sz = input.size(0); - int in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - int out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - int ker_h = dil_h * (weight_h - 1) + 1; - int ker_w = dil_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK( - weight_h > 0 && weight_w > 0, - "weight_h: ", - weight_h, - " weight_w: ", - weight_w); - TORCH_CHECK( - stride_h > 0 && stride_w > 0, - "stride_h: ", - stride_h, - " stride_w: ", - stride_w); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); - TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); - - TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); - TORCH_CHECK(weight.size(0) % n_weight_grps == 0); - TORCH_CHECK( - (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "offset.shape[1] is not valid: got: ", - offset.size(1), - " expected: ", - n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK( - (!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w), - "mask.shape[1] is not valid: got: ", - mask.size(1), - " expected: ", - n_offset_grps * weight_h * weight_w); - TORCH_CHECK(input.size(1) % n_offset_grps == 0); - - TORCH_CHECK( - (offset.size(0) == input.size(0)), "invalid batch size of offset"); - TORCH_CHECK( - (offset.size(2) == out_h && offset.size(3) == out_w), - "offset output dims: (", - offset.size(2), - ", ", - offset.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask"); - TORCH_CHECK( - (!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)), - "mask output dims: (", - mask.size(2), - ", ", - mask.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", - out_h, - " out_w: ", - out_w); - - auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); - if (batch_sz == 0) { - return out; - } - - // Separate batches into blocks - out = out.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - out_channels, - out_h, - out_w}); - input = input.view( - {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); - - offset = offset.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask = mask.view({batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - at::Tensor out_buf = at::zeros( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs * out_h, - out_w}, - out.options()); - - // Separate channels into convolution groups - out_buf = out_buf.view({out_buf.size(0), - 
n_weight_grps, - out_buf.size(1) / n_weight_grps, - out_buf.size(2), - out_buf.size(3)}); - weight = weight.view({n_weight_grps, - weight.size(0) / n_weight_grps, - weight.size(1), - weight.size(2), - weight.size(3)}); - - // Sample points and perform convolution - auto columns = at::zeros( - {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, - input.options()); - for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col( - input[b], - offset[b], - mask[b], - in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dil_h, - dil_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - out_buf = out_buf.view({batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs, - out_h, - out_w}); - out_buf.transpose_(1, 2); - out.copy_(out_buf); - out = out.view({batch_sz, out_channels, out_h, out_w}); - - return out + bias.view({1, out_channels, 1, 1}); -} - template -__global__ void deformable_col2im_gpu_kernel( +__global__ void deformable_col2im_kernel( int n, const scalar_t* col, const scalar_t* offset_ptr, @@ -560,7 +351,7 @@ __global__ void deformable_col2im_gpu_kernel( } } -static void compute_grad_input( +void compute_grad_input( const at::Tensor& columns, const at::Tensor& offset, const at::Tensor& mask, @@ -591,7 +382,7 @@ static void compute_grad_input( AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_gpu", ([&] { - deformable_col2im_gpu_kernel<<< + deformable_col2im_kernel<<< blocks, threads>>>( num_kernels, @@ -657,7 +448,7 @@ __device__ scalar_t get_coordinate_weight( } template -__global__ void deformable_col2im_coord_gpu_kernel( +__global__ void deformable_col2im_coord_kernel( int n, const scalar_t* col_ptr, const scalar_t* im_ptr, @@ -766,7 +557,7 @@ __global__ void deformable_col2im_coord_gpu_kernel( } } -static void compute_grad_offset_and_mask( +void compute_grad_offset_and_mask( const at::Tensor& columns, const at::Tensor& input, const at::Tensor& offset, @@ -799,7 +590,7 @@ static void compute_grad_offset_and_mask( AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] { - deformable_col2im_coord_gpu_kernel<<< + deformable_col2im_coord_kernel<<< blocks, threads>>>( num_kernels, @@ -835,7 +626,7 @@ static void compute_grad_offset_and_mask( } } -static std::tuple deform_conv2d_backward_input_cuda( +std::tuple backward_gradient_inputs( at::Tensor input, at::Tensor weight, at::Tensor offset, @@ -845,8 +636,8 @@ static std::tuple deform_conv2d_backward_inp int stride_w, int pad_h, int pad_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs, @@ -864,8 +655,10 @@ static std::tuple deform_conv2d_backward_inp int weight_h = weight.size(2); int weight_w = weight.size(3); - long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; - long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; + long out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + long out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 
auto grad_input = at::zeros_like(input); auto grad_offset = at::zeros_like(offset); @@ -948,8 +741,8 @@ static std::tuple deform_conv2d_backward_inp pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_parallel_imgs, n_offset_grps, use_mask, @@ -969,8 +762,8 @@ static std::tuple deform_conv2d_backward_inp pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_parallel_imgs, n_offset_grps, use_mask, @@ -989,7 +782,7 @@ static std::tuple deform_conv2d_backward_inp return std::make_tuple(grad_input, grad_offset, grad_mask); } -static at::Tensor deform_conv2d_backward_parameters_cuda( +at::Tensor backward_gradient_parameters( at::Tensor input, const at::Tensor& weight, at::Tensor offset, @@ -999,8 +792,8 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( int stride_w, int pad_h, int pad_w, - int dil_h, - int dil_w, + int dilation_h, + int dilation_w, int n_weight_grps, int n_offset_grps, int n_parallel_imgs, @@ -1079,8 +872,8 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( pad_w, stride_h, stride_w, - dil_h, - dil_w, + dilation_h, + dilation_w, out_h, out_w, n_parallel_imgs, @@ -1105,46 +898,264 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( return grad_weight; } +} // namespace + +at::Tensor deform_conv2d_forward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.is_cuda(), "input must be a CUDA tensor"); + + at::DeviceGuard guard(input_c.device()); + + int batch_sz = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + 
mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + + offset_c = offset_c.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view({out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view({n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view({batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + std::tuple -DeformConv2d_backward_cuda( - const at::Tensor& grad_out_param, - const at::Tensor& input_param, - const at::Tensor& weight_param, - const at::Tensor& offset_param, - const at::Tensor& mask_param, - const at::Tensor& bias_param, +deform_conv2d_backward_cuda( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, 
int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, - int64_t dil_h, - int64_t dil_w, + int64_t dilation_h, + int64_t dilation_w, int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) { - at::Tensor grad_out = grad_out_param.contiguous(); - at::Tensor input = input_param.contiguous(); - at::Tensor weight = weight_param.contiguous(); - at::Tensor offset = offset_param.contiguous(); - at::Tensor mask = mask_param.contiguous(); - at::Tensor bias = bias_param.contiguous(); - - const int batch_sz = input.size(0); + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int batch_sz = input_c.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cuda( - input, - weight, - offset, - mask, - grad_out, + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, stride_h, stride_w, pad_h, pad_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_weight_grps, n_offset_grps, n_parallel_imgs, @@ -1154,25 +1165,25 @@ DeformConv2d_backward_cuda( auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); - auto grad_weight = deform_conv2d_backward_parameters_cuda( - input, - weight, - offset, - mask, - grad_out, + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, stride_h, stride_w, pad_h, pad_w, - dil_h, - dil_w, + dilation_h, + dilation_w, n_weight_grps, n_offset_grps, n_parallel_imgs, use_mask); - auto value = grad_out.sum({0, 2, 3}); - auto grad_bias = at::ones_like(bias) * value; + auto value = grad_out_c.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias_c) * value; return std::make_tuple( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); diff --git a/torchvision/csrc/cuda/deform_conv2d_kernel.h b/torchvision/csrc/cuda/deform_conv2d_kernel.h new file mode 100644 index 00000000000..00f3f3dc15d --- /dev/null +++ b/torchvision/csrc/cuda/deform_conv2d_kernel.h @@ -0,0 +1,39 @@ +#pragma once + +#include <ATen/ATen.h> +#include "../macros.h" + +VISION_API at::Tensor deform_conv2d_forward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask); + +VISION_API std:: + tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> + deform_conv2d_backward_cuda( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index bf57f1c7967..834973c5327 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -2,40 +2,7 @@ #include #include "../macros.h" -VISION_API at::Tensor DeformConv2d_forward_cuda( - const at::Tensor& input, - const at::Tensor& weight, - const
at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t deformable_groups, - bool use_mask); - -VISION_API std:: - tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> - DeformConv2d_backward_cuda( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t deformable_groups, - bool use_mask); +// TODO: Delete this file once all the methods are gone VISION_API at::Tensor nms_cuda( const at::Tensor& dets, diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/deform_conv2d.cpp similarity index 96% rename from torchvision/csrc/DeformConv.h rename to torchvision/csrc/deform_conv2d.cpp index f8a8dba60e6..74ba630537a 100644 --- a/torchvision/csrc/DeformConv.h +++ b/torchvision/csrc/deform_conv2d.cpp @@ -1,18 +1,10 @@ -#pragma once +#include "deform_conv2d.h" +#include -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include <ATen/autocast_mode.h> #endif -// TODO: put this stuff in torchvision namespace - at::Tensor deform_conv2d( const at::Tensor& input, const at::Tensor& weight, @@ -49,7 +41,7 @@ at::Tensor deform_conv2d( } #if defined(WITH_CUDA) || defined(WITH_HIP) -at::Tensor DeformConv2d_autocast( +at::Tensor deform_conv2d_autocast( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, @@ -123,6 +115,8 @@ _deform_conv2d_backward( use_mask); } +namespace { + class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFunction> { public: @@ -297,7 +291,9 @@ class DeformConv2dBackwardFunction } }; -at::Tensor DeformConv2d_autograd( +at::Tensor deform_conv2d_autograd( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, @@ -330,7 +326,7 @@ at::Tensor DeformConv2d_autograd( } std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> -DeformConv2d_backward_autograd( +deform_conv2d_backward_autograd( const at::Tensor& grad, const at::Tensor& input, const at::Tensor& weight, diff --git a/torchvision/csrc/deform_conv2d.h b/torchvision/csrc/deform_conv2d.h new file mode 100644 index 00000000000..6adc77fb888 --- /dev/null +++ b/torchvision/csrc/deform_conv2d.h @@ -0,0 +1,100 @@ +#pragma once + +#include "cpu/deform_conv2d_kernel.h" + +#ifdef WITH_CUDA +#include "cuda/deform_conv2d_kernel.h" +#endif +#ifdef WITH_HIP +#include "hip/deform_conv2d_kernel.h" +#endif + +// C++ Forward +at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +// Autocast Forward +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor deform_conv2d_autocast( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); +#endif + +// C++ Backward +std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +// Autograd Forward and Backward +at::Tensor deform_conv2d_autograd( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> +deform_conv2d_backward_autograd( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 44c8346ff7b..2d4e2af0f53 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -8,11 +8,11 @@ #include #endif -#include "DeformConv.h" #include "PSROIAlign.h" #include "PSROIPool.h" #include "ROIAlign.h" #include "ROIPool.h" +#include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" @@ -62,8 +62,8 @@ TORCH_LIBRARY(torchvision, m) { } TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl("deform_conv2d", DeformConv2d_forward_cpu); - m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); + m.impl("deform_conv2d", deform_conv2d_forward_cpu); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu); m.impl("nms", nms_cpu); m.impl("ps_roi_align", PSROIAlign_forward_cpu); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); @@ -78,8 +78,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { // TODO: Place this in a hypothetical separate torchvision_cuda library #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl("deform_conv2d", DeformConv2d_forward_cuda); - m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); + m.impl("deform_conv2d", deform_conv2d_forward_cuda); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda); m.impl("nms", nms_cuda); m.impl("ps_roi_align", PSROIAlign_forward_cuda); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); @@ -95,7 +95,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { // Autocast only needs to wrap forward pass ops. #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl("deform_conv2d", DeformConv2d_autocast); + m.impl("deform_conv2d", deform_conv2d_autocast); m.impl("nms", nms_autocast); m.impl("ps_roi_align", PSROIAlign_autocast); m.impl("ps_roi_pool", PSROIPool_autocast); @@ -105,8 +105,8 @@ #endif TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl("deform_conv2d", DeformConv2d_autograd); - m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); + m.impl("deform_conv2d", deform_conv2d_autograd); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); m.impl("ps_roi_align", PSROIAlign_autograd); m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); m.impl("ps_roi_pool", PSROIPool_autograd);