Skip to content

Commit f544463

Browse files
committed
Dispatcher + Autocast.
1 parent 098d0ef commit f544463

File tree

2 files changed

+60
-75
lines changed

2 files changed

+60
-75
lines changed

torchvision/csrc/DeformConv.h

Lines changed: 51 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
#include "cpu/vision_cpu.h"
44

55
#ifdef WITH_CUDA
6+
#include "autocast.h"
67
#include "cuda/vision_cuda.h"
78
#endif
89
#ifdef WITH_HIP
10+
#include "autocast.h"
911
#include "hip/vision_cuda.h"
1012
#endif
1113

12-
at::Tensor DeformConv2d_forward(
14+
// TODO: put this stuff in torchvision namespace
15+
16+
at::Tensor deform_conv2d(
1317
const at::Tensor& input,
1418
const at::Tensor& weight,
1519
const at::Tensor& offset,
@@ -22,26 +26,10 @@ at::Tensor DeformConv2d_forward(
2226
const int64_t dilation_w,
2327
const int64_t groups,
2428
const int64_t offset_groups) {
25-
if (input.is_cuda()) {
26-
#if defined(WITH_CUDA) || defined(WITH_HIP)
27-
return DeformConv2d_forward_cuda(
28-
input.contiguous(),
29-
weight.contiguous(),
30-
offset.contiguous(),
31-
bias.contiguous(),
32-
stride_h,
33-
stride_w,
34-
pad_h,
35-
pad_w,
36-
dilation_h,
37-
dilation_w,
38-
groups,
39-
offset_groups);
40-
#else
41-
TORCH_CHECK(false, "Not compiled with GPU support");
42-
#endif
43-
}
44-
return DeformConv2d_forward_cpu(
29+
static auto op = c10::Dispatcher::singleton()
30+
.findSchemaOrThrow("torchvision::deform_conv2d", "")
31+
.typed<decltype(deform_conv2d)>();
32+
return op.call(
4533
input.contiguous(),
4634
weight.contiguous(),
4735
offset.contiguous(),
@@ -56,8 +44,8 @@ at::Tensor DeformConv2d_forward(
5644
offset_groups);
5745
}
5846

59-
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
60-
const at::Tensor& grad,
47+
#if defined(WITH_CUDA) || defined(WITH_HIP)
48+
at::Tensor DeformConv2d_autocast(
6149
const at::Tensor& input,
6250
const at::Tensor& weight,
6351
const at::Tensor& offset,
@@ -70,27 +58,44 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward
7058
const int64_t dilation_w,
7159
const int64_t groups,
7260
const int64_t offset_groups) {
73-
if (grad.is_cuda()) {
74-
#if defined(WITH_CUDA) || defined(WITH_HIP)
75-
return DeformConv2d_backward_cuda(
76-
grad.contiguous(),
77-
input.contiguous(),
78-
weight.contiguous(),
79-
offset.contiguous(),
80-
bias.contiguous(),
81-
stride_h,
82-
stride_w,
83-
pad_h,
84-
pad_w,
85-
dilation_h,
86-
dilation_w,
87-
groups,
88-
offset_groups);
89-
#else
90-
TORCH_CHECK(false, "Not compiled with GPU support");
61+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
62+
return deform_conv2d(
63+
at::autocast::cached_cast(at::kFloat, input),
64+
at::autocast::cached_cast(at::kFloat, weight),
65+
at::autocast::cached_cast(at::kFloat, offset),
66+
at::autocast::cached_cast(at::kFloat, bias),
67+
stride_h,
68+
stride_w,
69+
pad_h,
70+
pad_w,
71+
dilation_h,
72+
dilation_w,
73+
groups,
74+
offset_groups)
75+
.to(input.scalar_type());
76+
}
9177
#endif
92-
}
93-
return DeformConv2d_backward_cpu(
78+
79+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
80+
_deform_conv2d_backward(
81+
const at::Tensor& grad,
82+
const at::Tensor& input,
83+
const at::Tensor& weight,
84+
const at::Tensor& offset,
85+
const at::Tensor& bias,
86+
const int64_t stride_h,
87+
const int64_t stride_w,
88+
const int64_t pad_h,
89+
const int64_t pad_w,
90+
const int64_t dilation_h,
91+
const int64_t dilation_w,
92+
const int64_t groups,
93+
const int64_t offset_groups) {
94+
static auto op =
95+
c10::Dispatcher::singleton()
96+
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
97+
.typed<decltype(_deform_conv2d_backward)>();
98+
return op.call(
9499
grad.contiguous(),
95100
input.contiguous(),
96101
weight.contiguous(),
@@ -123,7 +128,8 @@ class DeformConv2dFunction
123128
int64_t dilation_w,
124129
int64_t groups,
125130
int64_t offset_groups) {
126-
auto output = DeformConv2d_forward(
131+
at::AutoNonVariableTypeMode g; // TODO: check if necessary
132+
auto output = deform_conv2d(
127133
input,
128134
weight,
129135
offset,
@@ -170,7 +176,7 @@ class DeformConv2dFunction
170176
auto groups = ctx->saved_data["groups"].toInt();
171177
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
172178

173-
auto grads = DeformConv2d_backward(
179+
auto grads = _deform_conv2d_backward(
174180
grad_output[0],
175181
input,
176182
weight,
@@ -205,32 +211,3 @@ class DeformConv2dFunction
205211
};
206212
}
207213
};
208-
209-
at::Tensor deform_conv2d(
210-
const at::Tensor& input,
211-
const at::Tensor& weight,
212-
const at::Tensor& offset,
213-
const at::Tensor& bias,
214-
int64_t stride_h,
215-
int64_t stride_w,
216-
int64_t pad_h,
217-
int64_t pad_w,
218-
int64_t dilation_h,
219-
int64_t dilation_w,
220-
int64_t groups,
221-
int64_t offset_groups) {
222-
auto result = DeformConv2dFunction::apply(
223-
input,
224-
weight,
225-
offset,
226-
bias,
227-
stride_h,
228-
stride_w,
229-
pad_h,
230-
pad_w,
231-
dilation_h,
232-
dilation_w,
233-
groups,
234-
offset_groups);
235-
return result[0];
236-
}

torchvision/csrc/vision.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,18 @@ TORCH_LIBRARY(torchvision, m) {
5454
m.def("_new_empty_tensor_op", &new_empty_tensor);
5555
m.def("ps_roi_align", &ps_roi_align);
5656
m.def("ps_roi_pool", &ps_roi_pool);
57-
m.def("deform_conv2d", &deform_conv2d);
57+
m.def(
58+
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
59+
m.def(
60+
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
5861
m.def("_cuda_version", &vision::cuda_version);
5962
}
6063

6164
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
6265
m.impl("roi_align", ROIAlign_forward_cpu);
6366
m.impl("_roi_align_backward", ROIAlign_backward_cpu);
67+
m.impl("deform_conv2d", DeformConv2d_forward_cpu);
68+
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
6469
m.impl("nms", nms_cpu);
6570
}
6671

@@ -69,6 +74,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
6974
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
7075
m.impl("roi_align", ROIAlign_forward_cuda);
7176
m.impl("_roi_align_backward", ROIAlign_backward_cuda);
77+
m.impl("deform_conv2d", DeformConv2d_forward_cuda);
78+
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
7279
m.impl("nms", nms_cuda);
7380
}
7481
#endif
@@ -77,6 +84,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
7784
#if defined(WITH_CUDA) || defined(WITH_HIP)
7885
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
7986
m.impl("roi_align", ROIAlign_autocast);
87+
m.impl("deform_conv2d", DeformConv2d_autocast);
8088
m.impl("nms", nms_autocast);
8189
}
8290
#endif

0 commit comments

Comments
 (0)