Skip to content

Commit e8b6e3f

Browse files
authored
Port DeformConv to use the Dispatcher and support Autocast (#2898)
Changes: split the stride, padding, and dilation tuple parameters of DeformConv into separate scalar arguments; fixed parameter types; routed calls through the Dispatcher and added Autocast support; added Autograd registration; moved the contiguous() conversions out of the dispatcher and into the implementations; removed rvalue references.
1 parent f9e31a6 commit e8b6e3f

File tree

6 files changed

+353
-232
lines changed

6 files changed

+353
-232
lines changed

torchvision/csrc/DeformConv.h

Lines changed: 198 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,105 @@
11
#pragma once
22

3-
#include "cpu/vision_cpu.h"
4-
5-
#ifdef WITH_CUDA
6-
#include "cuda/vision_cuda.h"
7-
#endif
8-
#ifdef WITH_HIP
9-
#include "hip/vision_cuda.h"
3+
#if defined(WITH_CUDA) || defined(WITH_HIP)
4+
#include "autocast.h"
105
#endif
116

12-
// TODO: put this stuff in torchvision namespace

// Public entry point for deformable convolution.
//
// Looks up the "torchvision::deform_conv2d" operator in the PyTorch
// dispatcher and forwards all arguments to it, so the actual kernel
// (CPU/CUDA/autograd/autocast) is selected by dispatch keys rather than
// by manual is_cuda() branching.
//
// @param input    Input feature map.
// @param weight   Convolution weights.
// @param offset   Per-position sampling offsets for the deformable kernel.
// @param bias     Bias tensor.
// @param stride_h / stride_w     Convolution stride per dimension.
// @param pad_h / pad_w           Zero-padding per dimension.
// @param dilation_h / dilation_w Kernel dilation per dimension.
// @param groups        Number of weight groups.
// @param offset_groups Number of offset groups.
// @return The convolved output tensor.
at::Tensor deform_conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& bias,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t dilation_h,
    const int64_t dilation_w,
    const int64_t groups,
    const int64_t offset_groups) {
  // static: the schema lookup is done once and cached for later calls.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
                       .typed<decltype(deform_conv2d)>();
  return op.call(
      input,
      weight,
      offset,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups);
}
4939

50-
#if defined(WITH_CUDA) || defined(WITH_HIP)
// Autocast wrapper for deform_conv2d.
//
// Under autocast (mixed precision), deformable convolution is executed in
// float32: every tensor argument is cast to kFloat via the autocast cast
// cache, the Autocast dispatch key is excluded so the inner call re-dispatches
// to the real kernel instead of looping back here, and the result is cast
// back to the input's original scalar type.
at::Tensor DeformConv2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& bias,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t dilation_h,
    const int64_t dilation_w,
    const int64_t groups,
    const int64_t offset_groups) {
  // Prevent infinite recursion through the Autocast key on the inner call.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return deform_conv2d(
             at::autocast::cached_cast(at::kFloat, input),
             at::autocast::cached_cast(at::kFloat, weight),
             at::autocast::cached_cast(at::kFloat, offset),
             at::autocast::cached_cast(at::kFloat, bias),
             stride_h,
             stride_w,
             pad_h,
             pad_w,
             dilation_h,
             dilation_w,
             groups,
             offset_groups)
      // Restore the caller's dtype (e.g. half) on the way out.
      .to(input.scalar_type());
}
#endif
77-
}
78-
return DeformConv2d_backward_cpu(
79-
grad.contiguous(),
80-
input.contiguous(),
81-
weight.contiguous(),
82-
offset.contiguous(),
83-
bias.contiguous(),
84-
stride,
85-
padding,
86-
dilation,
71+
72+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
73+
_deform_conv2d_backward(
74+
const at::Tensor& grad,
75+
const at::Tensor& input,
76+
const at::Tensor& weight,
77+
const at::Tensor& offset,
78+
const at::Tensor& bias,
79+
const int64_t stride_h,
80+
const int64_t stride_w,
81+
const int64_t pad_h,
82+
const int64_t pad_w,
83+
const int64_t dilation_h,
84+
const int64_t dilation_w,
85+
const int64_t groups,
86+
const int64_t offset_groups) {
87+
static auto op =
88+
c10::Dispatcher::singleton()
89+
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
90+
.typed<decltype(_deform_conv2d_backward)>();
91+
return op.call(
92+
grad,
93+
input,
94+
weight,
95+
offset,
96+
bias,
97+
stride_h,
98+
stride_w,
99+
pad_h,
100+
pad_w,
101+
dilation_h,
102+
dilation_w,
87103
groups,
88104
offset_groups);
89105
}
@@ -105,14 +121,18 @@ class DeformConv2dFunction
105121
int64_t dilation_w,
106122
int64_t groups,
107123
int64_t offset_groups) {
108-
auto output = DeformConv2d_forward(
124+
at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
125+
auto output = deform_conv2d(
109126
input,
110127
weight,
111128
offset,
112129
bias,
113-
{stride_h, stride_w},
114-
{pad_h, pad_w},
115-
{dilation_h, dilation_w},
130+
stride_h,
131+
stride_w,
132+
pad_h,
133+
pad_w,
134+
dilation_h,
135+
dilation_w,
116136
groups,
117137
offset_groups);
118138

@@ -149,15 +169,18 @@ class DeformConv2dFunction
149169
auto groups = ctx->saved_data["groups"].toInt();
150170
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
151171

152-
auto grads = DeformConv2d_backward(
172+
auto grads = _deform_conv2d_backward(
153173
grad_output[0],
154174
input,
155175
weight,
156176
offset,
157177
bias,
158-
{stride_h, stride_w},
159-
{pad_h, pad_w},
160-
{dilation_h, dilation_w},
178+
stride_h,
179+
stride_w,
180+
pad_h,
181+
pad_w,
182+
dilation_h,
183+
dilation_w,
161184
groups,
162185
offset_groups);
163186
auto grad_input = std::get<0>(grads);
@@ -182,20 +205,75 @@ class DeformConv2dFunction
182205
}
183206
};
184207

185-
at::Tensor deform_conv2d(
208+
// TODO: There should be an easier way to do this
209+
class DeformConv2dBackwardFunction
210+
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
211+
public:
212+
static torch::autograd::variable_list forward(
213+
torch::autograd::AutogradContext* ctx,
214+
torch::autograd::Variable grad,
215+
torch::autograd::Variable input,
216+
torch::autograd::Variable weight,
217+
torch::autograd::Variable offset,
218+
torch::autograd::Variable bias,
219+
const int64_t stride_h,
220+
const int64_t stride_w,
221+
const int64_t pad_h,
222+
const int64_t pad_w,
223+
const int64_t dilation_h,
224+
const int64_t dilation_w,
225+
const int64_t groups,
226+
const int64_t offset_groups) {
227+
at::AutoNonVariableTypeMode g;
228+
auto result = _deform_conv2d_backward(
229+
grad,
230+
input,
231+
weight,
232+
offset,
233+
bias,
234+
stride_h,
235+
stride_w,
236+
pad_h,
237+
pad_w,
238+
dilation_h,
239+
dilation_w,
240+
groups,
241+
offset_groups);
242+
243+
auto grad_input = std::get<0>(result);
244+
auto grad_weight = std::get<1>(result);
245+
auto grad_offset = std::get<2>(result);
246+
auto grad_bias = std::get<3>(result);
247+
248+
return {
249+
grad_input,
250+
grad_weight,
251+
grad_offset,
252+
grad_bias,
253+
};
254+
}
255+
256+
static torch::autograd::variable_list backward(
257+
torch::autograd::AutogradContext* ctx,
258+
torch::autograd::variable_list grad_output) {
259+
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
260+
}
261+
};
262+
263+
at::Tensor DeformConv2d_autograd(
186264
const at::Tensor& input,
187265
const at::Tensor& weight,
188266
const at::Tensor& offset,
189267
const at::Tensor& bias,
190-
int64_t stride_h,
191-
int64_t stride_w,
192-
int64_t pad_h,
193-
int64_t pad_w,
194-
int64_t dilation_h,
195-
int64_t dilation_w,
196-
int64_t groups,
197-
int64_t offset_groups) {
198-
auto result = DeformConv2dFunction::apply(
268+
const int64_t stride_h,
269+
const int64_t stride_w,
270+
const int64_t pad_h,
271+
const int64_t pad_w,
272+
const int64_t dilation_h,
273+
const int64_t dilation_w,
274+
const int64_t groups,
275+
const int64_t offset_groups) {
276+
return DeformConv2dFunction::apply(
199277
input,
200278
weight,
201279
offset,
@@ -207,6 +285,37 @@ at::Tensor deform_conv2d(
207285
dilation_h,
208286
dilation_w,
209287
groups,
210-
offset_groups);
211-
return result[0];
288+
offset_groups)[0];
212289
}
290+
291+
// Autograd-key implementation of _deform_conv2d_backward.
//
// Routes the call through DeformConv2dBackwardFunction::apply so the
// backward computation participates in the autograd graph (and raises on
// double backward), then repacks the resulting variable_list into the
// (grad_input, grad_weight, grad_offset, grad_bias) tuple expected by the
// operator schema.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& bias,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t dilation_h,
    const int64_t dilation_w,
    const int64_t groups,
    const int64_t offset_groups) {
  auto result = DeformConv2dBackwardFunction::apply(
      grad,
      input,
      weight,
      offset,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups);

  return std::make_tuple(result[0], result[1], result[2], result[3]);
}

0 commit comments

Comments
 (0)