Skip to content

Commit f3c5a2e

Browse files
committed
Registering operators in their files.
1 parent da80ce1 commit f3c5a2e

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

torchvision/csrc/cpu/deform_conv2d_kernel.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
// modified from
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

69+
#include <torch/script.h>
70+
6971
#include "deform_conv2d_kernel.h"
7072

7173
namespace {
@@ -1137,3 +1139,8 @@ deform_conv2d_backward_cpu(
11371139
return std::make_tuple(
11381140
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
11391141
}
1142+
1143+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
1144+
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
1145+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
1146+
}

torchvision/csrc/cuda/deform_conv2d_kernel.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@
6666
// modified from
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

69-
#include <ATen/ATen.h>
7069
#include <ATen/cuda/CUDAContext.h>
7170
#include <c10/cuda/CUDAGuard.h>
7271
#include <THC/THCAtomics.cuh>
72+
#include <torch/script.h>
7373

7474
#include "cuda_helpers.h"
7575
#include "deform_conv2d_kernel.h"
@@ -1188,3 +1188,8 @@ deform_conv2d_backward_cuda(
11881188
return std::make_tuple(
11891189
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
11901190
}
1191+
1192+
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
1193+
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
1194+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
1195+
}

torchvision/csrc/deform_conv2d.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ at::Tensor deform_conv2d_autocast(
7474
use_mask)
7575
.to(input.scalar_type());
7676
}
77+
78+
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
79+
m.impl("deform_conv2d", deform_conv2d_autocast);
80+
}
7781
#endif
7882

7983
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -361,3 +365,8 @@ deform_conv2d_backward_autograd(
361365

362366
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
363367
}
368+
369+
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
370+
m.impl("deform_conv2d", deform_conv2d_autograd);
371+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
372+
}

torchvision/csrc/vision.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "PSROIPool.h"
1313
#include "ROIAlign.h"
1414
#include "ROIPool.h"
15-
#include "deform_conv2d.h"
1615
#include "empty_tensor_op.h"
1716
#include "nms.h"
1817

@@ -62,8 +61,6 @@ TORCH_LIBRARY(torchvision, m) {
6261
}
6362

6463
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
65-
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
66-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
6764
m.impl("nms", nms_cpu);
6865
m.impl("ps_roi_align", PSROIAlign_forward_cpu);
6966
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
@@ -78,8 +75,6 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
7875
// TODO: Place this in a hypothetical separate torchvision_cuda library
7976
#if defined(WITH_CUDA) || defined(WITH_HIP)
8077
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
81-
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
82-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
8378
m.impl("nms", nms_cuda);
8479
m.impl("ps_roi_align", PSROIAlign_forward_cuda);
8580
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
@@ -95,7 +90,6 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
9590
// Autocast only needs to wrap forward pass ops.
9691
#if defined(WITH_CUDA) || defined(WITH_HIP)
9792
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
98-
m.impl("deform_conv2d", deform_conv2d_autocast);
9993
m.impl("nms", nms_autocast);
10094
m.impl("ps_roi_align", PSROIAlign_autocast);
10195
m.impl("ps_roi_pool", PSROIPool_autocast);
@@ -105,8 +99,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
10599
#endif
106100

107101
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
108-
m.impl("deform_conv2d", deform_conv2d_autograd);
109-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
110102
m.impl("ps_roi_align", PSROIAlign_autograd);
111103
m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
112104
m.impl("ps_roi_pool", PSROIPool_autograd);

0 commit comments

Comments
 (0)