Port DeformConv to use the Dispatcher and support Autocast #2898

Merged: 6 commits, Oct 27, 2020. Changes shown are from 3 of the 6 commits.
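
The central change: instead of branching on input.is_cuda() inside DeformConv2d_forward / DeformConv2d_backward, the header now resolves torchvision::deform_conv2d and torchvision::_deform_conv2d_backward through the c10 dispatcher and adds an Autocast wrapper that runs the kernels in fp32. The dispatcher lookup only works if the operators are registered elsewhere; that registration is not part of the diff below. A minimal sketch of what it could look like, assuming the schema and kernel names used in this diff (the real registration in the PR may differ):

#include <torch/library.h>  // TORCH_LIBRARY / TORCH_LIBRARY_IMPL

// Hypothetical registration sketch; assumes the CPU kernel declarations and
// DeformConv2d_autocast from the headers in this PR are in scope.
TORCH_LIBRARY(torchvision, m) {
  // Schemas mirror the scalar (int) argument lists used in the new header.
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, "
      "int stride_h, int stride_w, int pad_h, int pad_w, "
      "int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, "
      "Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, "
      "int pad_w, int dilation_h, int dilation_w, int groups, "
      "int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("deform_conv2d", DeformConv2d_forward_cpu);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
}

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  // The fp32 wrapper defined in DeformConv.h below.
  m.impl("deform_conv2d", DeformConv2d_autocast);
}
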
188 changes: 91 additions & 97 deletions torchvision/csrc/DeformConv.h
@@ -1,89 +1,105 @@
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include "autocast.h"
#endif

at::Tensor DeformConv2d_forward(
// TODO: put this stuff in torchvision namespace

at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return DeformConv2d_forward_cuda(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return DeformConv2d_forward_cpu(
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>();
return op.call(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
const at::Tensor& grad,
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor DeformConv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return DeformConv2d_backward_cuda(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups)
.to(input.scalar_type());
}
#endif
}
return DeformConv2d_backward_cpu(

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward)>();
return op.call(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
}
@@ -105,14 +121,18 @@ class DeformConv2dFunction
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto output = DeformConv2d_forward(
at::AutoNonVariableTypeMode g; // TODO: check if necessary
auto output = deform_conv2d(
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);

@@ -149,15 +169,18 @@ class DeformConv2dFunction
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();

auto grads = DeformConv2d_backward(
auto grads = _deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
auto grad_input = std::get<0>(grads);
@@ -181,32 +204,3 @@ class DeformConv2dFunction
};
}
};

at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto result = DeformConv2dFunction::apply(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
return result[0];
}
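
With DeformConv.h now exposing deform_conv2d with plain scalar arguments instead of std::pair<int, int>, a call site goes straight through the dispatcher. A minimal usage sketch with made-up shapes, assuming the header above is included and the torchvision operators are registered (otherwise findSchemaOrThrow throws):

#include <torch/torch.h>

void deform_conv2d_example() {
  auto input  = torch::randn({1, 4, 8, 8});
  auto weight = torch::randn({6, 4, 3, 3});  // out_channels=6, in_channels/groups=4, 3x3 kernel
  // The offset tensor needs 2 * offset_groups * kernel_h * kernel_w = 18
  // channels and the output spatial size (8x8 for 3x3, stride 1, padding 1).
  auto offset = torch::randn({1, 18, 8, 8});
  auto bias   = torch::zeros({6});

  auto out = deform_conv2d(
      input, weight, offset, bias,
      /*stride_h=*/1, /*stride_w=*/1,
      /*pad_h=*/1, /*pad_w=*/1,
      /*dilation_h=*/1, /*dilation_w=*/1,
      /*groups=*/1, /*offset_groups=*/1);
  // out has shape [1, 6, 8, 8]; under autocast on CUDA builds the
  // DeformConv2d_autocast wrapper runs the kernel in fp32 and casts the
  // result back to input.scalar_type().
}
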
89 changes: 40 additions & 49 deletions torchvision/csrc/cpu/DeformConv_cpu.cpp
@@ -233,11 +233,14 @@ at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor input = input_param;
at::Tensor offset = offset_param;
at::Tensor weight = weight_param;
@@ -263,15 +266,6 @@ at::Tensor DeformConv2d_forward_cpu(
int weight_h = weight.size(2);
int weight_w = weight.size(3);

int stride_h = stride.first;
int stride_w = stride.second;

int pad_h = pad.first;
int pad_w = pad.second;

int dil_h = dilation.first;
int dil_w = dilation.second;

int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
@@ -683,9 +677,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
@@ -700,15 +697,6 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
int weight_h = weight.size(2);
int weight_w = weight.size(3);

int stride_h = stride.first;
int stride_w = stride.second;

int pad_h = pad.first;
int pad_w = pad.second;

int dil_h = dilation.first;
int dil_w = dilation.second;

long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;

@@ -813,9 +801,12 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
const at::Tensor& weight,
at::Tensor offset,
const at::Tensor& grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
@@ -830,15 +821,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
int weight_h = weight.size(2);
int weight_w = weight.size(3);

int stride_h = stride.first;
int stride_w = stride.second;

int pad_h = pad.first;
int pad_w = pad.second;

int dil_h = dilation.first;
int dil_w = dilation.second;

long out_h = grad_out.size(2);
long out_w = grad_out.size(3);

@@ -922,11 +904,14 @@ DeformConv2d_backward_cpu(
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
@@ -936,9 +921,12 @@
weight,
offset,
grad_out,
stride,
pad,
dilation,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
@@ -951,9 +939,12 @@
weight,
offset,
grad_out,
stride,
pad,
dilation,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);