3
3
#include " cpu/vision_cpu.h"
4
4
5
5
#ifdef WITH_CUDA
6
+ #include " autocast.h"
6
7
#include " cuda/vision_cuda.h"
7
8
#endif
8
9
#ifdef WITH_HIP
10
+ #include " autocast.h"
9
11
#include " hip/vision_cuda.h"
10
12
#endif
11
13
12
- at::Tensor DeformConv2d_forward (
14
+ // TODO: put this stuff in torchvision namespace
15
+
16
+ at::Tensor deform_conv2d (
13
17
const at::Tensor& input,
14
18
const at::Tensor& weight,
15
19
const at::Tensor& offset,
@@ -22,26 +26,10 @@ at::Tensor DeformConv2d_forward(
22
26
const int64_t dilation_w,
23
27
const int64_t groups,
24
28
const int64_t offset_groups) {
25
- if (input.is_cuda ()) {
26
- #if defined(WITH_CUDA) || defined(WITH_HIP)
27
- return DeformConv2d_forward_cuda (
28
- input.contiguous (),
29
- weight.contiguous (),
30
- offset.contiguous (),
31
- bias.contiguous (),
32
- stride_h,
33
- stride_w,
34
- pad_h,
35
- pad_w,
36
- dilation_h,
37
- dilation_w,
38
- groups,
39
- offset_groups);
40
- #else
41
- TORCH_CHECK (false , " Not compiled with GPU support" );
42
- #endif
43
- }
44
- return DeformConv2d_forward_cpu (
29
+ static auto op = c10::Dispatcher::singleton ()
30
+ .findSchemaOrThrow (" torchvision::deform_conv2d" , " " )
31
+ .typed <decltype (deform_conv2d)>();
32
+ return op.call (
45
33
input.contiguous (),
46
34
weight.contiguous (),
47
35
offset.contiguous (),
@@ -56,8 +44,8 @@ at::Tensor DeformConv2d_forward(
56
44
offset_groups);
57
45
}
58
46
59
- std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward (
60
- const at::Tensor& grad,
47
+ # if defined(WITH_CUDA) || defined(WITH_HIP)
48
+ at::Tensor DeformConv2d_autocast (
61
49
const at::Tensor& input,
62
50
const at::Tensor& weight,
63
51
const at::Tensor& offset,
@@ -70,27 +58,44 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward
70
58
const int64_t dilation_w,
71
59
const int64_t groups,
72
60
const int64_t offset_groups) {
73
- if (grad.is_cuda ()) {
74
- #if defined(WITH_CUDA) || defined(WITH_HIP)
75
- return DeformConv2d_backward_cuda (
76
- grad.contiguous (),
77
- input.contiguous (),
78
- weight.contiguous (),
79
- offset.contiguous (),
80
- bias.contiguous (),
81
- stride_h,
82
- stride_w,
83
- pad_h,
84
- pad_w,
85
- dilation_h,
86
- dilation_w,
87
- groups,
88
- offset_groups);
89
- #else
90
- TORCH_CHECK (false , " Not compiled with GPU support" );
61
+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::Autocast);
62
+ return deform_conv2d (
63
+ at::autocast::cached_cast (at::kFloat , input),
64
+ at::autocast::cached_cast (at::kFloat , weight),
65
+ at::autocast::cached_cast (at::kFloat , offset),
66
+ at::autocast::cached_cast (at::kFloat , bias),
67
+ stride_h,
68
+ stride_w,
69
+ pad_h,
70
+ pad_w,
71
+ dilation_h,
72
+ dilation_w,
73
+ groups,
74
+ offset_groups)
75
+ .to (input.scalar_type ());
76
+ }
91
77
#endif
92
- }
93
- return DeformConv2d_backward_cpu (
78
+
79
+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
80
+ _deform_conv2d_backward (
81
+ const at::Tensor& grad,
82
+ const at::Tensor& input,
83
+ const at::Tensor& weight,
84
+ const at::Tensor& offset,
85
+ const at::Tensor& bias,
86
+ const int64_t stride_h,
87
+ const int64_t stride_w,
88
+ const int64_t pad_h,
89
+ const int64_t pad_w,
90
+ const int64_t dilation_h,
91
+ const int64_t dilation_w,
92
+ const int64_t groups,
93
+ const int64_t offset_groups) {
94
+ static auto op =
95
+ c10::Dispatcher::singleton ()
96
+ .findSchemaOrThrow (" torchvision::_deform_conv2d_backward" , " " )
97
+ .typed <decltype (_deform_conv2d_backward)>();
98
+ return op.call (
94
99
grad.contiguous (),
95
100
input.contiguous (),
96
101
weight.contiguous (),
@@ -123,7 +128,8 @@ class DeformConv2dFunction
123
128
int64_t dilation_w,
124
129
int64_t groups,
125
130
int64_t offset_groups) {
126
- auto output = DeformConv2d_forward (
131
+ at::AutoNonVariableTypeMode g; // TODO: check if necessary
132
+ auto output = deform_conv2d (
127
133
input,
128
134
weight,
129
135
offset,
@@ -170,7 +176,7 @@ class DeformConv2dFunction
170
176
auto groups = ctx->saved_data [" groups" ].toInt ();
171
177
auto offset_groups = ctx->saved_data [" offset_groups" ].toInt ();
172
178
173
- auto grads = DeformConv2d_backward (
179
+ auto grads = _deform_conv2d_backward (
174
180
grad_output[0 ],
175
181
input,
176
182
weight,
@@ -205,32 +211,3 @@ class DeformConv2dFunction
205
211
};
206
212
}
207
213
};
208
-
209
- at::Tensor deform_conv2d (
210
- const at::Tensor& input,
211
- const at::Tensor& weight,
212
- const at::Tensor& offset,
213
- const at::Tensor& bias,
214
- int64_t stride_h,
215
- int64_t stride_w,
216
- int64_t pad_h,
217
- int64_t pad_w,
218
- int64_t dilation_h,
219
- int64_t dilation_w,
220
- int64_t groups,
221
- int64_t offset_groups) {
222
- auto result = DeformConv2dFunction::apply (
223
- input,
224
- weight,
225
- offset,
226
- bias,
227
- stride_h,
228
- stride_w,
229
- pad_h,
230
- pad_w,
231
- dilation_h,
232
- dilation_w,
233
- groups,
234
- offset_groups);
235
- return result[0 ];
236
- }
0 commit comments