@@ -1,89 +1,105 @@
 #pragma once

-#include "cpu/vision_cpu.h"
-
-#ifdef WITH_CUDA
-#include "cuda/vision_cuda.h"
-#endif
-#ifdef WITH_HIP
-#include "hip/vision_cuda.h"
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+#include "autocast.h"
 #endif

-at::Tensor DeformConv2d_forward(
+// TODO: put this stuff in torchvision namespace
+
+at::Tensor deform_conv2d(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    const std::pair<int, int>& stride,
-    const std::pair<int, int>& padding,
-    const std::pair<int, int>& dilation,
-    const int groups,
-    const int offset_groups) {
-  if (input.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return DeformConv2d_forward_cuda(
-        input.contiguous(),
-        weight.contiguous(),
-        offset.contiguous(),
-        bias.contiguous(),
-        stride,
-        padding,
-        dilation,
-        groups,
-        offset_groups);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return DeformConv2d_forward_cpu(
-      input.contiguous(),
-      weight.contiguous(),
-      offset.contiguous(),
-      bias.contiguous(),
-      stride,
-      padding,
-      dilation,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
+                       .typed<decltype(deform_conv2d)>();
+  return op.call(
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
       groups,
       offset_groups);
 }

-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
-    const at::Tensor& grad,
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+at::Tensor DeformConv2d_autocast(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    const std::pair<int, int>& stride,
-    const std::pair<int, int>& padding,
-    const std::pair<int, int>& dilation,
-    const int groups,
-    const int offset_groups) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return DeformConv2d_backward_cuda(
-        grad.contiguous(),
-        input.contiguous(),
-        weight.contiguous(),
-        offset.contiguous(),
-        bias.contiguous(),
-        stride,
-        padding,
-        dilation,
-        groups,
-        offset_groups);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  return deform_conv2d(
+             at::autocast::cached_cast(at::kFloat, input),
+             at::autocast::cached_cast(at::kFloat, weight),
+             at::autocast::cached_cast(at::kFloat, offset),
+             at::autocast::cached_cast(at::kFloat, bias),
+             stride_h,
+             stride_w,
+             pad_h,
+             pad_w,
+             dilation_h,
+             dilation_w,
+             groups,
+             offset_groups)
+      .to(input.scalar_type());
+}
 #endif
-  }
-  return DeformConv2d_backward_cpu(
-      grad.contiguous(),
-      input.contiguous(),
-      weight.contiguous(),
-      offset.contiguous(),
-      bias.contiguous(),
-      stride,
-      padding,
-      dilation,
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
+          .typed<decltype(_deform_conv2d_backward)>();
+  return op.call(
+      grad,
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
      groups,
      offset_groups);
 }
@@ -105,14 +121,18 @@ class DeformConv2dFunction
       int64_t dilation_w,
       int64_t groups,
       int64_t offset_groups) {
-    auto output = DeformConv2d_forward(
+    at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
+    auto output = deform_conv2d(
         input,
         weight,
         offset,
         bias,
-        {stride_h, stride_w},
-        {pad_h, pad_w},
-        {dilation_h, dilation_w},
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
         groups,
         offset_groups);

@@ -149,15 +169,18 @@ class DeformConv2dFunction
     auto groups = ctx->saved_data["groups"].toInt();
     auto offset_groups = ctx->saved_data["offset_groups"].toInt();

-    auto grads = DeformConv2d_backward(
+    auto grads = _deform_conv2d_backward(
         grad_output[0],
         input,
         weight,
         offset,
         bias,
-        {stride_h, stride_w},
-        {pad_h, pad_w},
-        {dilation_h, dilation_w},
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
         groups,
         offset_groups);
     auto grad_input = std::get<0>(grads);
@@ -182,20 +205,75 @@ class DeformConv2dFunction
   }
 };

-at::Tensor deform_conv2d(
+// TODO: There should be an easier way to do this
+class DeformConv2dBackwardFunction
+    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable input,
+      torch::autograd::Variable weight,
+      torch::autograd::Variable offset,
+      torch::autograd::Variable bias,
+      const int64_t stride_h,
+      const int64_t stride_w,
+      const int64_t pad_h,
+      const int64_t pad_w,
+      const int64_t dilation_h,
+      const int64_t dilation_w,
+      const int64_t groups,
+      const int64_t offset_groups) {
+    at::AutoNonVariableTypeMode g;
+    auto result = _deform_conv2d_backward(
+        grad,
+        input,
+        weight,
+        offset,
+        bias,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        groups,
+        offset_groups);
+
+    auto grad_input = std::get<0>(result);
+    auto grad_weight = std::get<1>(result);
+    auto grad_offset = std::get<2>(result);
+    auto grad_bias = std::get<3>(result);
+
+    return {
+        grad_input,
+        grad_weight,
+        grad_offset,
+        grad_bias,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
+  }
+};
+
+at::Tensor DeformConv2d_autograd(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
     const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t offset_groups) {
-  auto result = DeformConv2dFunction::apply(
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  return DeformConv2dFunction::apply(
       input,
       weight,
       offset,
@@ -207,6 +285,37 @@ at::Tensor deform_conv2d(
       dilation_h,
       dilation_w,
       groups,
-      offset_groups);
-  return result[0];
+      offset_groups)[0];
 }
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    const int64_t groups,
+    const int64_t offset_groups) {
+  auto result = DeformConv2dBackwardFunction::apply(
+      grad,
+      input,
+      weight,
+      offset,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups);
+  return std::make_tuple(result[0], result[1], result[2], result[3]);
+}
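The wrappers in this header only forward to ops resolved through the PyTorch dispatcher, so they rely on the `torchvision::deform_conv2d` and `torchvision::_deform_conv2d_backward` schemas being defined, and on the CPU/CUDA, autocast, and autograd kernels being registered, in a companion .cpp file. Below is a minimal sketch of what such a registration could look like. It assumes `TORCH_LIBRARY`/`TORCH_LIBRARY_IMPL` blocks, a header name of "DeformConv.h", and that the per-backend kernels (`DeformConv2d_forward_cpu`, `DeformConv2d_backward_cuda`, etc.) declared in the vision_cpu/vision_cuda headers have been ported to the same flattened `int64_t` argument list; none of these details are confirmed by this diff.

```cpp
// Hypothetical registration sketch, not part of the diff above. The schema
// strings mirror the names looked up via findSchemaOrThrow; the kernel names,
// signatures, and file layout are assumptions.
#include <torch/library.h>

#include "DeformConv.h"        // assumed name of the header shown in the diff
#include "cpu/vision_cpu.h"    // assumed to declare DeformConv2d_*_cpu
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include "cuda/vision_cuda.h"  // assumed to declare DeformConv2d_*_cuda
#endif

TORCH_LIBRARY(torchvision, m) {
  // Schemas registered once; every per-key implementation below must match.
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, "
      "int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, "
      "int dilation_w, int groups, int offset_groups) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, "
      "Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, "
      "int pad_w, int dilation_h, int dilation_w, int groups, "
      "int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  // Assumed CPU kernels with the flattened int64_t signature.
  m.impl("deform_conv2d", DeformConv2d_forward_cpu);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  // Assumed CUDA/HIP kernels.
  m.impl("deform_conv2d", DeformConv2d_forward_cuda);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
}

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  // Autocast wrapper from the header: casts inputs to float, then re-dispatches.
  m.impl("deform_conv2d", DeformConv2d_autocast);
}
#endif

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  // Autograd wrappers from the header, backed by the custom Function classes.
  m.impl("deform_conv2d", DeformConv2d_autograd);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
}
```

With registrations along these lines in place, the header's `deform_conv2d` wrapper resolves to the right kernel for the tensor's device, and autocast and autograd behavior come from their dedicated dispatch keys instead of the old `if (input.is_cuda())` branching.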