Skip to content

Commit 9026df0

Browse files
committed
Use header files:
- Create header files for kernel implementation and remove definitions from vision_*.h files. - Eliminate unnecessary headers and ensure all cpp include their headers.
1 parent 602acb2 commit 9026df0

9 files changed

+152
-90
lines changed

torchvision/csrc/autocast.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
// TODO: Delete this file once none of the methods use it
4+
35
#if defined(WITH_CUDA) || defined(WITH_HIP)
46
#include <ATen/autocast_mode.h>
57
#endif

torchvision/csrc/cpu/deform_conv2d_cpu.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,7 @@
6666
// modified from
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

69-
#include <ATen/ATen.h>
70-
#include <ATen/TensorUtils.h>
71-
#include <TH/TH.h>
72-
73-
#include <cmath>
74-
#include <iostream>
75-
#include <tuple>
69+
#include "deform_conv2d_cpu.h"
7670

7771
namespace {
7872

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor deform_conv2d_forward_cpu(
7+
const at::Tensor& input_param,
8+
const at::Tensor& weight_param,
9+
const at::Tensor& offset_param,
10+
const at::Tensor& mask_param,
11+
const at::Tensor& bias_param,
12+
int64_t stride_h,
13+
int64_t stride_w,
14+
int64_t pad_h,
15+
int64_t pad_w,
16+
int64_t dil_h,
17+
int64_t dil_w,
18+
int64_t n_weight_grps,
19+
int64_t n_offset_grps,
20+
bool use_mask);
21+
22+
VISION_API std::
23+
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
24+
deform_conv2d_backward_cpu(
25+
const at::Tensor& grad_out_param,
26+
const at::Tensor& input_param,
27+
const at::Tensor& weight_param,
28+
const at::Tensor& offset_param,
29+
const at::Tensor& mask_param,
30+
const at::Tensor& bias_param,
31+
int64_t stride_h,
32+
int64_t stride_w,
33+
int64_t pad_h,
34+
int64_t pad_w,
35+
int64_t dil_h,
36+
int64_t dil_w,
37+
int64_t n_weight_grps,
38+
int64_t n_offset_grps,
39+
bool use_mask);

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,7 @@
22
#include <torch/extension.h>
33
#include "../macros.h"
44

5-
VISION_API at::Tensor deform_conv2d_forward_cpu(
6-
const at::Tensor& input_param,
7-
const at::Tensor& weight_param,
8-
const at::Tensor& offset_param,
9-
const at::Tensor& mask_param,
10-
const at::Tensor& bias_param,
11-
int64_t stride_h,
12-
int64_t stride_w,
13-
int64_t pad_h,
14-
int64_t pad_w,
15-
int64_t dil_h,
16-
int64_t dil_w,
17-
int64_t n_weight_grps,
18-
int64_t n_offset_grps,
19-
bool use_mask);
20-
21-
VISION_API std::
22-
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
23-
deform_conv2d_backward_cpu(
24-
const at::Tensor& grad_out_param,
25-
const at::Tensor& input_param,
26-
const at::Tensor& weight_param,
27-
const at::Tensor& offset_param,
28-
const at::Tensor& mask_param,
29-
const at::Tensor& bias_param,
30-
int64_t stride_h,
31-
int64_t stride_w,
32-
int64_t pad_h,
33-
int64_t pad_w,
34-
int64_t dil_h,
35-
int64_t dil_w,
36-
int64_t n_weight_grps,
37-
int64_t n_offset_grps,
38-
bool use_mask);
5+
// TODO: Delete this file once all the methods are gone
396

407
VISION_API at::Tensor nms_cpu(
418
const at::Tensor& dets,

torchvision/csrc/cuda/deform_conv2d_cuda.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,13 @@
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

6969
#include <ATen/ATen.h>
70-
#include <ATen/TensorUtils.h>
7170
#include <ATen/cuda/CUDAContext.h>
7271
#include <c10/cuda/CUDAGuard.h>
7372
#include <THC/THCAtomics.cuh>
7473

74+
#include "deform_conv2d_cuda.h"
7575
#include "cuda_helpers.h"
7676

77-
#include <cmath>
78-
#include <iostream>
79-
#include <tuple>
80-
8177
namespace {
8278

8379
const int kMaxParallelImgs = 32;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor deform_conv2d_forward_cuda(
7+
const at::Tensor& input_param,
8+
const at::Tensor& weight_param,
9+
const at::Tensor& offset_param,
10+
const at::Tensor& mask_param,
11+
const at::Tensor& bias_param,
12+
int64_t stride_h,
13+
int64_t stride_w,
14+
int64_t pad_h,
15+
int64_t pad_w,
16+
int64_t dil_h,
17+
int64_t dil_w,
18+
int64_t n_weight_grps,
19+
int64_t n_offset_grps,
20+
bool use_mask);
21+
22+
VISION_API std::
23+
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
24+
deform_conv2d_backward_cuda(
25+
const at::Tensor& grad_out_param,
26+
const at::Tensor& input_param,
27+
const at::Tensor& weight_param,
28+
const at::Tensor& offset_param,
29+
const at::Tensor& mask_param,
30+
const at::Tensor& bias_param,
31+
int64_t stride_h,
32+
int64_t stride_w,
33+
int64_t pad_h,
34+
int64_t pad_w,
35+
int64_t dil_h,
36+
int64_t dil_w,
37+
int64_t n_weight_grps,
38+
int64_t n_offset_grps,
39+
bool use_mask);

torchvision/csrc/cuda/vision_cuda.h

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,7 @@
22
#include <torch/extension.h>
33
#include "../macros.h"
44

5-
VISION_API at::Tensor deform_conv2d_forward_cuda(
6-
const at::Tensor& input_param,
7-
const at::Tensor& weight_param,
8-
const at::Tensor& offset_param,
9-
const at::Tensor& mask_param,
10-
const at::Tensor& bias_param,
11-
int64_t stride_h,
12-
int64_t stride_w,
13-
int64_t pad_h,
14-
int64_t pad_w,
15-
int64_t dil_h,
16-
int64_t dil_w,
17-
int64_t n_weight_grps,
18-
int64_t n_offset_grps,
19-
bool use_mask);
20-
21-
VISION_API std::
22-
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
23-
deform_conv2d_backward_cuda(
24-
const at::Tensor& grad_out_param,
25-
const at::Tensor& input_param,
26-
const at::Tensor& weight_param,
27-
const at::Tensor& offset_param,
28-
const at::Tensor& mask_param,
29-
const at::Tensor& bias_param,
30-
int64_t stride_h,
31-
int64_t stride_w,
32-
int64_t pad_h,
33-
int64_t pad_w,
34-
int64_t dil_h,
35-
int64_t dil_w,
36-
int64_t n_weight_grps,
37-
int64_t n_offset_grps,
38-
bool use_mask);
5+
// TODO: Delete this file once all the methods are gone
396

407
VISION_API at::Tensor nms_cuda(
418
const at::Tensor& dets,

torchvision/csrc/deform_conv2d.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
1-
#pragma once
1+
#include "deform_conv2d.h"
2+
#include <torch/extension.h>
23

3-
#include "cpu/vision_cpu.h"
4-
5-
#ifdef WITH_CUDA
6-
#include "autocast.h"
7-
#include "cuda/vision_cuda.h"
8-
#endif
9-
#ifdef WITH_HIP
10-
#include "autocast.h"
11-
#include "hip/vision_cuda.h"
4+
#if defined(WITH_CUDA) || defined(WITH_HIP)
5+
#include <ATen/autocast_mode.h>
126
#endif
137

148
namespace {

torchvision/csrc/deform_conv2d.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#pragma once
2+
3+
#include "cpu/deform_conv2d_cpu.h"
4+
5+
#ifdef WITH_CUDA
6+
#include "cuda/deform_conv2d_cuda.h"
7+
#endif
8+
#ifdef WITH_HIP
9+
#include "hip/deform_conv2d_cuda.h"
10+
#endif
11+
12+
// Autocast Registration
13+
#if defined(WITH_CUDA) || defined(WITH_HIP)
14+
at::Tensor deform_conv2d_autocast(
15+
const at::Tensor& input,
16+
const at::Tensor& weight,
17+
const at::Tensor& offset,
18+
const at::Tensor& mask,
19+
const at::Tensor& bias,
20+
int64_t stride_h,
21+
int64_t stride_w,
22+
int64_t pad_h,
23+
int64_t pad_w,
24+
int64_t dilation_h,
25+
int64_t dilation_w,
26+
int64_t groups,
27+
int64_t offset_groups,
28+
bool use_mask);
29+
#endif
30+
31+
// Autograd Registration
32+
at::Tensor deform_conv2d_autograd(
33+
const at::Tensor& input,
34+
const at::Tensor& weight,
35+
const at::Tensor& offset,
36+
const at::Tensor& mask,
37+
const at::Tensor& bias,
38+
int64_t stride_h,
39+
int64_t stride_w,
40+
int64_t pad_h,
41+
int64_t pad_w,
42+
int64_t dilation_h,
43+
int64_t dilation_w,
44+
int64_t groups,
45+
int64_t offset_groups,
46+
bool use_mask);
47+
48+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
49+
deform_conv2d_backward_autograd(
50+
const at::Tensor& grad,
51+
const at::Tensor& input,
52+
const at::Tensor& weight,
53+
const at::Tensor& offset,
54+
const at::Tensor& mask,
55+
const at::Tensor& bias,
56+
int64_t stride_h,
57+
int64_t stride_w,
58+
int64_t pad_h,
59+
int64_t pad_w,
60+
int64_t dilation_h,
61+
int64_t dilation_w,
62+
int64_t groups,
63+
int64_t offset_groups,
64+
bool use_mask);

0 commit comments

Comments
 (0)