Port DeformConv to use the Dispatcher and support Autocast #2898

Merged · 6 commits · Oct 27, 2020

Changes from all commits
torchvision/csrc/DeformConv.h (287 changes: 198 additions & 89 deletions)

@@ -1,89 +1,105 @@
 #pragma once
 
 #include "cpu/vision_cpu.h"
 
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
 #ifdef WITH_HIP
 #include "hip/vision_cuda.h"
 #endif
 
-at::Tensor DeformConv2d_forward(
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+#include "autocast.h"
+#endif
+
+// TODO: put this stuff in torchvision namespace
+
+at::Tensor deform_conv2d(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    const std::pair<int, int>& stride,
-    const std::pair<int, int>& padding,
-    const std::pair<int, int>& dilation,
-    const int groups,
-    const int offset_groups) {
-  if (input.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return DeformConv2d_forward_cuda(
-        input.contiguous(),
-        weight.contiguous(),
-        offset.contiguous(),
-        bias.contiguous(),
-        stride,
-        padding,
-        dilation,
-        groups,
-        offset_groups);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return DeformConv2d_forward_cpu(
-      input.contiguous(),
-      weight.contiguous(),
-      offset.contiguous(),
-      bias.contiguous(),
-      stride,
-      padding,
-      dilation,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
+                       .typed<decltype(deform_conv2d)>();
+  return op.call(
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
       groups,
       offset_groups);
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
-    const at::Tensor& grad,
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+at::Tensor DeformConv2d_autocast(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    const std::pair<int, int>& stride,
-    const std::pair<int, int>& padding,
-    const std::pair<int, int>& dilation,
-    const int groups,
-    const int offset_groups) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return DeformConv2d_backward_cuda(
-        grad.contiguous(),
-        input.contiguous(),
-        weight.contiguous(),
-        offset.contiguous(),
-        bias.contiguous(),
-        stride,
-        padding,
-        dilation,
-        groups,
-        offset_groups);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  return deform_conv2d(
+             at::autocast::cached_cast(at::kFloat, input),
+             at::autocast::cached_cast(at::kFloat, weight),
+             at::autocast::cached_cast(at::kFloat, offset),
+             at::autocast::cached_cast(at::kFloat, bias),
+             stride_h,
+             stride_w,
+             pad_h,
+             pad_w,
+             dilation_h,
+             dilation_w,
+             groups,
+             offset_groups)
+      .to(input.scalar_type());
+}
 #endif
-  }
-  return DeformConv2d_backward_cpu(
-      grad.contiguous(),
-      input.contiguous(),
-      weight.contiguous(),
-      offset.contiguous(),
-      bias.contiguous(),
-      stride,
-      padding,
-      dilation,
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
+          .typed<decltype(_deform_conv2d_backward)>();
+  return op.call(
+      grad,
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups);
 }
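
A note on where the dispatcher plumbing lives: the `deform_conv2d` and `_deform_conv2d_backward` trampolines above only look the operators up by schema and re-dispatch, so the schema definition and the per-backend kernels must be registered somewhere else (this header contains no registrations). Below is a minimal sketch of such a registration with the standard torch/library.h macros; the exact schema string and the file it would live in (e.g. vision.cpp) are assumptions, not lines from this diff:

#include <torch/library.h>

// Sketch only: define the schema the trampolines resolve against, then
// attach the raw CPU/CUDA kernels (declared in cpu/vision_cpu.h and
// cuda/vision_cuda.h) to their backend dispatch keys.
TORCH_LIBRARY(torchvision, m) {
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, "
      "int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, "
      "int dilation_w, int groups, int offset_groups) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("deform_conv2d", DeformConv2d_forward_cpu);
}

#ifdef WITH_CUDA
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl("deform_conv2d", DeformConv2d_forward_cuda);
}
#endif

With the kernels behind the dispatcher, the old `if (input.is_cuda())` branching disappears: the backend is chosen from the tensor's dispatch key set instead.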
@@ -105,14 +121,18 @@ class DeformConv2dFunction
       int64_t dilation_w,
       int64_t groups,
       int64_t offset_groups) {
-    auto output = DeformConv2d_forward(
+    at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
+    auto output = deform_conv2d(
         input,
         weight,
         offset,
         bias,
-        {stride_h, stride_w},
-        {pad_h, pad_w},
-        {dilation_h, dilation_w},
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
         groups,
         offset_groups);
 
@@ -149,15 +169,18 @@ class DeformConv2dFunction
     auto groups = ctx->saved_data["groups"].toInt();
     auto offset_groups = ctx->saved_data["offset_groups"].toInt();
 
-    auto grads = DeformConv2d_backward(
+    auto grads = _deform_conv2d_backward(
         grad_output[0],
         input,
         weight,
         offset,
         bias,
-        {stride_h, stride_w},
-        {pad_h, pad_w},
-        {dilation_h, dilation_w},
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
         groups,
         offset_groups);
     auto grad_input = std::get<0>(grads);
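
The forward/backward rewiring above can be exercised end to end with a short smoke test; the shapes, the `main` framing, and the include path are illustrative assumptions rather than anything in the PR:

#include <torch/torch.h>

#include "DeformConv.h" // assumed include path for the trampolines

int main() {
  // 5x5 input, 3x3 kernel, stride 1, no padding -> 3x3 output; the offset
  // tensor needs 2 * offset_groups * kernel_h * kernel_w = 18 channels.
  auto input = torch::randn({1, 1, 5, 5}, torch::requires_grad());
  auto weight = torch::randn({1, 1, 3, 3}, torch::requires_grad());
  auto offset = torch::randn({1, 18, 3, 3}, torch::requires_grad());
  auto bias = torch::zeros({1}, torch::requires_grad());

  auto out = deform_conv2d(
      input, weight, offset, bias,
      /*stride_h=*/1, /*stride_w=*/1, /*pad_h=*/0, /*pad_w=*/0,
      /*dilation_h=*/1, /*dilation_w=*/1, /*groups=*/1, /*offset_groups=*/1);

  // backward() lands in DeformConv2dFunction::backward, which re-dispatches
  // through torchvision::_deform_conv2d_backward.
  out.sum().backward();
  return 0;
}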
@@ -182,20 +205,75 @@ class DeformConv2dFunction
   }
 };
 
-at::Tensor deform_conv2d(
+// TODO: There should be an easier way to do this
+class DeformConv2dBackwardFunction
+    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable input,
+      torch::autograd::Variable weight,
+      torch::autograd::Variable offset,
+      torch::autograd::Variable bias,
+      const int64_t stride_h,
+      const int64_t stride_w,
+      const int64_t pad_h,
+      const int64_t pad_w,
+      const int64_t dilation_h,
+      const int64_t dilation_w,
+      const int64_t groups,
+      const int64_t offset_groups) {
+    at::AutoNonVariableTypeMode g;
+    auto result = _deform_conv2d_backward(
+        grad,
+        input,
+        weight,
+        offset,
+        bias,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        groups,
+        offset_groups);
+
+    auto grad_input = std::get<0>(result);
+    auto grad_weight = std::get<1>(result);
+    auto grad_offset = std::get<2>(result);
+    auto grad_bias = std::get<3>(result);
+
+    return {
+        grad_input,
+        grad_weight,
+        grad_offset,
+        grad_bias,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
+  }
+};
+
+at::Tensor DeformConv2d_autograd(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t offset_groups) {
-  auto result = DeformConv2dFunction::apply(
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  return DeformConv2dFunction::apply(
       input,
       weight,
       offset,
@@ -207,6 +285,37 @@ at::Tensor deform_conv2d(
       dilation_h,
       dilation_w,
       groups,
-      offset_groups);
-  return result[0];
+      offset_groups)[0];
 }
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  auto result = DeformConv2dBackwardFunction::apply(
+      grad,
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups);
+  return std::make_tuple(result[0], result[1], result[2], result[3]);
+}
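
`DeformConv2d_autograd`, `DeformConv2d_backward_autograd`, and `DeformConv2d_autocast` are the wrappers intended for the Autograd and Autocast dispatch keys; as with the schema definition, the registrations themselves sit outside this header. A sketch under that assumption:

#include <torch/library.h>

// Sketch only: route the Autograd and Autocast keys to the wrappers above.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("deform_conv2d", DeformConv2d_autograd);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  // Under autocast, DeformConv2d_autocast upcasts every tensor argument to
  // fp32 via cached_cast, masks the Autocast key so the inner call reaches
  // the real kernel, and casts the result back to the input dtype.
  m.impl("deform_conv2d", DeformConv2d_autocast);
}
#endif

Note that double backward is deliberately unsupported: DeformConv2dBackwardFunction::backward fails loudly through TORCH_CHECK rather than returning silently wrong higher-order gradients.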