Add ROCm support. #3235

Closed
wants to merge 6 commits into from
62 changes: 62 additions & 0 deletions mmdet/ops/carafe/src/cuda/carafe_cuda_kernel.cu
@@ -1,9 +1,20 @@
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>

#ifdef __NVCC__
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>
#endif

#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPApplyUtils.cuh>
#include <THH/THHAtomics.cuh>
#include <hip/hip_runtime.h>
#endif

#include <cmath>

using namespace at;
@@ -22,6 +33,7 @@ using namespace at;
#define kBlockRows 8
#define FULL_MASK 0xffffffff


inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }

__device__ inline int Loc2Index(const int n, const int c, const int h,
@@ -32,19 +44,34 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
}
/* TODO: move this to a common place */
template <typename scalar_t>
#ifdef __NVCC__
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
#endif
#ifdef __HIP_PLATFORM_HCC__
__device__ inline scalar_t min_min(scalar_t a, scalar_t b) {
#endif
return a < b ? a : b;
}

template <typename scalar_t>
#ifdef __NVCC__
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
#endif
#ifdef __HIP_PLATFORM_HCC__
__device__ inline scalar_t max_max(scalar_t a, scalar_t b) {
#endif
return a > b ? a : b;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t WARP_SHFL_DOWN(scalar_t val, int offset)
{
#ifdef __NVCC__
return __shfl_down_sync(FULL_MASK, val, offset);
#endif
#ifdef __HIP_PLATFORM_HCC__
// HIP's legacy __shfl_down is (var, lane_delta) and takes no mask argument.
return __shfl_down(val, offset);
#endif
}

template<>
@@ -168,7 +195,12 @@ int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
at::Tensor rfeatures, at::Tensor routput,
at::Tensor rmasks, at::Tensor output) {
// one warp per pixel
#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Feature", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
@@ -217,9 +249,16 @@ int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
batch_size, output_height * output_width, channels, dh, dw,
bottom_data, top_data);
}));
#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (hipSuccess != err) {
fprintf(stderr, "hipCheckError() failed : %s\n", hipGetErrorString(err));
#endif
exit(-1);
}

@@ -373,15 +412,26 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
down_x <= down_width - 1) {
const int channels_per_mask = ceilf(channels / (float)group_size);
const int start = channels_per_mask * mask_group;
#ifdef __NVCC__
const int end = min(channels_per_mask * (mask_group + 1), channels);
#endif
#ifdef __HIP_PLATFORM_HCC__
const int end = min_min(channels_per_mask * (mask_group + 1), channels);
#endif
for (int c = start + lane_id; c < end; c += WARP_SIZE) {
int bottom_id =
Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef __NVCC__
__syncwarp();
#endif
#ifdef __HIP_PLATFORM_HCC__
// HIP has no __syncwarp(); __syncthreads() is a heavier block-wide substitute.
__syncthreads();
#endif
output_val = warpReduceSum(output_val);
if (lane_id == 0) {
const int mask_id =
@@ -400,7 +450,12 @@ int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
at::Tensor rmask_grad, at::Tensor bottom_grad,
at::Tensor mask_grad) {
#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] {
const scalar_t *bottom_data = top_grad.data_ptr<scalar_t>();
@@ -479,9 +534,16 @@ int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
batch_size, output_height * output_width, mask_channels, dh, dw,
bottom_data, top_data);
}));
#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (hipSuccess != err) {
fprintf(stderr, "hipCheckError() failed : %s\n", hipGetErrorString(err));
#endif
exit(-1);
}

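A note on the device-side changes above: under HIP the min/max helpers are renamed to `min_min`/`max_max`, presumably to avoid clashing with HIP's built-in device `min`/`max`; HIP's legacy `__shfl_down` takes no participation mask; and HIP has no `__syncwarp()`, so the kernel falls back to a block-wide `__syncthreads()`. A minimal sketch of collecting these differences in one shared helper header (the names `warp_shfl_down` and `warp_sync` are illustrative only, not part of this PR):

```cpp
// Hypothetical portability helpers (sketch only, not introduced by this PR).

template <typename scalar_t>
__device__ __forceinline__ scalar_t warp_shfl_down(scalar_t val, int offset) {
#if defined(__NVCC__)
  // CUDA 9+ warp shuffles require an explicit participation mask.
  return __shfl_down_sync(0xffffffff, val, offset);
#elif defined(__HIP_PLATFORM_HCC__)
  // HIP's legacy __shfl_down is (var, lane_delta) and takes no mask.
  return __shfl_down(val, offset);
#endif
}

__device__ __forceinline__ void warp_sync() {
#if defined(__NVCC__)
  __syncwarp();
#elif defined(__HIP_PLATFORM_HCC__)
  // No __syncwarp() in HIP; __syncthreads() is a heavier block-wide barrier,
  // so this is only safe where every thread in the block reaches the call.
  __syncthreads();
#endif
}
```

Each kernel could then call `warp_shfl_down`/`warp_sync` directly instead of repeating the `#ifdef` pair at every use site.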
19 changes: 19 additions & 0 deletions mmdet/ops/carafe/src/cuda/carafe_naive_cuda_kernel.cu
@@ -1,5 +1,10 @@
#include <ATen/ATen.h>
#ifdef __NVCC__
#include <THC/THCAtomics.cuh>
#endif
#ifdef __HIP_PLATFORM_HCC__
#include <THH/THHAtomics.cuh>
#endif

using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)

@@ -86,9 +91,16 @@ int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
output_size, bottom_data, bottom_masks, kernel_size, group_size,
scale_factor, channels, height, width, top_data);
}));
#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (hipSuccess != err) {
fprintf(stderr, "hipCheckError() failed : %s\n", hipGetErrorString(err));
#endif
exit(-1);
}

@@ -166,9 +178,16 @@ int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
mask_diff);
}));

#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (hipSuccess != err) {
fprintf(stderr, "hipCheckError() failed : %s\n", hipGetErrorString(err));
#endif
exit(-1);
}

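The post-launch error check above is duplicated verbatim, once per backend, after every kernel launch in these files. A hedged sketch of folding it into a single macro (the `gpu*` aliases and `GPU_CHECK_KERNEL` are hypothetical names for this sketch, not something the PR adds):

```cpp
// Hypothetical sketch: one post-launch error check for both backends.
#include <cstdio>
#include <cstdlib>

#if defined(__HIP_PLATFORM_HCC__)
#include <hip/hip_runtime.h>
#define gpuError_t        hipError_t
#define gpuSuccess        hipSuccess
#define gpuGetLastError   hipGetLastError
#define gpuGetErrorString hipGetErrorString
#else
#include <cuda_runtime.h>
#define gpuError_t        cudaError_t
#define gpuSuccess        cudaSuccess
#define gpuGetLastError   cudaGetLastError
#define gpuGetErrorString cudaGetErrorString
#endif

// Call immediately after a kernel launch; aborts on failure,
// matching the behaviour of the inline blocks in this diff.
#define GPU_CHECK_KERNEL()                                          \
  do {                                                              \
    gpuError_t err = gpuGetLastError();                             \
    if (err != gpuSuccess) {                                        \
      fprintf(stderr, "gpuCheckError() failed : %s\n",              \
              gpuGetErrorString(err));                              \
      exit(-1);                                                     \
    }                                                               \
  } while (0)
```

A launch site would then shrink to `kernel<<<blocks, threads, 0, stream>>>(...); GPU_CHECK_KERNEL();`.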
76 changes: 76 additions & 0 deletions mmdet/ops/dcn/src/cuda/deform_conv_cuda_kernel.cu
@@ -61,8 +61,14 @@
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu

#include <ATen/ATen.h>
#ifdef __NVCC__
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#endif
#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#include <THH/THHAtomics.cuh>
#endif
#include <stdio.h>
#include <math.h>
#include <float.h>
@@ -262,17 +268,30 @@ void deformable_im2col(
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

#ifdef __NVCC__
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
height_col, width_col, data_col_);
}));

#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in deformable_im2col: %s\n", hipGetErrorString(err));
#endif
}
}

@@ -356,17 +375,30 @@ void deformable_col2im(
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

#ifdef __NVCC__
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
}));

#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in deformable_col2im: %s\n", hipGetErrorString(err));
#endif
}
}

@@ -455,7 +487,12 @@ void deformable_col2im_coord(
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

#ifdef __NVCC__
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
@@ -785,16 +822,29 @@ void modulated_deformable_im2col_cuda(
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

#ifdef __NVCC__
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));

#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", hipGetErrorString(err));
#endif
}
}

@@ -817,17 +867,30 @@ void modulated_deformable_col2im_cuda(
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

#ifdef __NVCC__
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));

#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", hipGetErrorString(err));
#endif
}
}

@@ -852,16 +915,29 @@ void modulated_deformable_col2im_coord_cuda(
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();

#ifdef __NVCC__
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
#ifdef __NVCC__
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
#endif
#ifdef __HIP_PLATFORM_HCC__
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", hipGetErrorString(err));
#endif
}
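Every launch site in this file repeats the same `#ifdef` pair just to pick the current stream. A minimal sketch of centralizing that choice, reusing the same ATen calls this PR already makes at each launch (the helper name `getCurrentGpuStream` is hypothetical, and whether `at::cuda::getCurrentHIPStream` is available depends on the hipified PyTorch build, so treat this as an assumption to verify):

```cpp
// Hypothetical sketch: one place to select the per-backend stream type and
// the current stream, mirroring the calls used at each launch site above.
#if defined(__HIP_PLATFORM_HCC__)
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
inline gpuStream_t getCurrentGpuStream() {
  return at::cuda::getCurrentHIPStream();  // as used elsewhere in this PR
}
#else
#include <ATen/cuda/CUDAContext.h>
using gpuStream_t = cudaStream_t;
inline gpuStream_t getCurrentGpuStream() {
  return at::cuda::getCurrentCUDAStream();
}
#endif
```

A launch would then read `deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, getCurrentGpuStream()>>>(...)` on both backends, with no per-backend branch at the call site.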
}
Loading