Skip to content

Add ROCm support. #393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
22 changes: 19 additions & 3 deletions mmcv/ops/csrc/carafe_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,37 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
}
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
__device__ inline scalar_t min_min(scalar_t a, scalar_t b) {
return a < b ? a : b;
}

template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
__device__ inline scalar_t max_max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}

/**
 * Adds up `val` across the lanes of a warp with a shuffle-down tree.
 *
 * Each step adds the value held by the lane `offset` positions higher,
 * halving `offset` from 16 down to 1, so the running total accumulates
 * towards the lowest lane. Only the value returned in the lowest lane of the
 * reduced group is the complete sum; other lanes hold partial sums.
 *
 * NOTE(review): the starting offset of 16 reduces 32 lanes. AMD wavefronts
 * may be 64 lanes wide -- confirm this matches WARP_SIZE on the HIP build.
 *
 * @param val this lane's contribution to the sum.
 * @return the accumulated value for this lane after all shuffle steps.
 */
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  // Braces keep both platform branches inside the loop body regardless of
  // which preprocessor macro is defined.
  for (int offset = 16; offset > 0; offset /= 2) {
#ifdef __NVCC__
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
#ifdef __HIP_PLATFORM_HCC__
    // HIP's __shfl_down takes (var, lane_delta[, width]) and has no mask
    // parameter, so FULL_MASK must not be passed as the first argument.
    val += __shfl_down(val, offset);
#endif
  }
  return val;
}

/**
 * Half-precision specialization of warpReduceSum.
 *
 * Performs the same shuffle-down accumulation as the generic template
 * (offsets 16, 8, 4, 2, 1), operating on the value exposed by __PHALF(val).
 *
 * @param val this lane's half-precision contribution to the sum.
 * @return the accumulated value for this lane after all shuffle steps.
 */
template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
  // Each platform branch is a complete statement so the loop body stays
  // well-formed whichever macro is defined.
  for (int offset = 16; offset > 0; offset /= 2) {
#ifdef __NVCC__
    __PHALF(val) +=
        __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
#endif
#ifdef __HIP_PLATFORM_HCC__
    // HIP's __shfl_down takes (var, lane_delta[, width]) and has no mask
    // parameter. The value is shuffled as float rather than unsigned int so
    // the fractional part is not truncated.
    // NOTE(review): assumes __PHALF(val) converts to and from float -- confirm
    // against the __PHALF / phalf definitions.
    __PHALF(val) += __shfl_down(static_cast<float>(__PHALF(val)), offset);
#endif
  }
  return val;
}

Expand Down Expand Up @@ -294,15 +304,21 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
down_x <= down_width - 1) {
const int channels_per_mask = ceilf(channels / (float)group_size);
const int start = channels_per_mask * mask_group;
const int end = min(channels_per_mask * (mask_group + 1), channels);
const int end = min_min(channels_per_mask * (mask_group + 1), channels);
for (int c = start + lane_id; c < end; c += WARP_SIZE) {
int bottom_id =
Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef __NVCC__
__syncwarp();
#endif
#ifdef __HIP_PLATFORM_HCC__
// HIP has no __syncwarp(); unsure whether __syncthreads() is the right substitute here.
Copy link

@wangyingrui wangyingrui Sep 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__syncwarp() is used here to avoid a read-write conflict on output_val when warpReduceSum reads output_val from other threads within the warp.
AMD HIP doesn't support __syncwarp() yet. Using __syncthreads() instead is OK, although it costs a little performance.

__syncthreads();
#endif
output_val = warpReduceSum(output_val);
if (lane_id == 0) {
const int mask_id =
Expand Down
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/common_cuda_helper.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef COMMON_CUDA_HELPER
#define COMMON_CUDA_HELPER

#ifdef __NVCC__
#include <cuda.h>
#endif

#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
Expand Down
11 changes: 11 additions & 0 deletions mmcv/ops/csrc/pytorch/bbox_overlaps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
int num_bbox1 = bboxes1.size(0);
int num_bbox2 = bboxes2.size(0);

#ifdef __NVCC__
at::cuda::CUDAGuard device_guard(bboxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
// at::cuda::HIPGuard device_guard(bboxes1.device());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why comment out HIPGuard?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because PyTorch handles ROCm as CUDA, leaving this line in causes a device-type assertion error.
I'm not familiar with CUDA or ROCm, but I hope ROCm can get official, efficient support.

Copy link

@wangyingrui wangyingrui Sep 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that HIPGuardImplMasqueradingAsCUDA.h and getCurrentHIPStreamMasqueradingAsCUDA should be used here instead?

hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
bboxes1.scalar_type(), "bbox_overlaps_cuda_kernel", ([&] {
bbox_overlaps_cuda_kernel<scalar_t>
Expand All @@ -18,5 +24,10 @@ void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
ious.data_ptr<scalar_t>(), num_bbox1, num_bbox2, mode, aligned,
offset);
}));
#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}
22 changes: 22 additions & 0 deletions mmcv/ops/csrc/pytorch/carafe_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
rmasks.resize_({batch_size, output_height, output_width, mask_channels});

// one warp per pixel
#ifdef __NVCC__
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
// at::cuda::HIPGuard device_guard(features.device());
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Feature", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
Expand Down Expand Up @@ -72,7 +78,12 @@ void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
bottom_data, top_data);
}));

#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}

void CARAFEBackwardCUDAKernelLauncher(
Expand All @@ -95,8 +106,14 @@ void CARAFEBackwardCUDAKernelLauncher(
rbottom_grad_hs.resize_({batch_size, output_height, output_width, channels});
rmask_grad.resize_({batch_size, output_height, output_width, mask_channels});

#ifdef __NVCC__
at::cuda::CUDAGuard device_guard(top_grad.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
// at::cuda::HIPGuard device_guard(top_grad.device());
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] {
const scalar_t *bottom_data = top_grad.data_ptr<scalar_t>();
Expand Down Expand Up @@ -175,5 +192,10 @@ void CARAFEBackwardCUDAKernelLauncher(
bottom_data, top_data);
}));

#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}
22 changes: 22 additions & 0 deletions mmcv/ops/csrc/pytorch/carafe_naive_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@ int CARAFENAIVEForwardCUDAKernelLauncher(const Tensor features,
int height = output.size(2);
int width = output.size(3);

#ifdef __NVCC__
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
// at::cuda::HIPGuard device_guard(features.device());
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "CARAFENAIVEForward", ([&] {
carafe_naive_forward_cuda_kernel<scalar_t>
Expand All @@ -22,7 +28,12 @@ int CARAFENAIVEForwardCUDAKernelLauncher(const Tensor features,
kernel_size, group_size, scale_factor, channels, height, width);
}));

#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
return 0;
}

Expand All @@ -35,8 +46,14 @@ int CARAFENAIVEBackwardCUDAKernelLauncher(
int height = top_grad.size(2);
int width = top_grad.size(3);

#ifdef __NVCC__
at::cuda::CUDAGuard device_guard(top_grad.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
// at::cuda::HIPGuard device_guard(top_grad.device());
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFENAIVEBackward", ([&] {
carafe_naive_backward_cuda_kernel<scalar_t>
Expand All @@ -48,6 +65,11 @@ int CARAFENAIVEBackwardCUDAKernelLauncher(
scale_factor, channels, height, width);
}));

#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
return 0;
}
47 changes: 47 additions & 0 deletions mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
// Modified from
// https://github.com/LikeLy-Journey/SegmenTron/blob/master/segmentron/modules/csrc/criss_cross_attention/ca_cuda.cu

#ifdef __NVCC__
#include <THC/THC.h>

#include <THC/THCDeviceUtils.cuh>
#endif
#ifdef __HIP_PLATFORM_HCC__
#include <THH/THH.h>

#include <THH/THHDeviceUtils.cuh>
#endif

#include "cc_attention_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
Expand All @@ -18,7 +25,12 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
auto h = t.size(2);
auto w = t.size(3);

#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif

// Run kernel
dim3 threads(32, 32);
Expand All @@ -33,7 +45,12 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
f.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
#ifdef __NVCC__
THCudaCheck(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
THCudaCheck(hipGetLastError());
#endif
}

void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
Expand All @@ -47,7 +64,12 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
auto h = t.size(2);
auto w = t.size(3);

#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif

// Run kernel
dim3 threads(32, 32);
Expand All @@ -71,7 +93,12 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
f.contiguous().data_ptr<scalar_t>(),
df.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
#ifdef __NVCC__
THCudaCheck(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
THCudaCheck(hipGetLastError());
#endif
}

void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
Expand All @@ -84,7 +111,12 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
auto h = g.size(2);
auto w = g.size(3);

#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif

// Run kernel
dim3 threads(32, 32);
Expand All @@ -99,7 +131,12 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
g.contiguous().data_ptr<scalar_t>(),
out.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
#ifdef __NVCC__
THCudaCheck(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
THCudaCheck(hipGetLastError());
#endif
}

void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
Expand All @@ -113,7 +150,12 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
auto h = dout.size(2);
auto w = dout.size(3);

#ifdef __NVCC__
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#endif
#ifdef __HIP_PLATFORM_HCC__
hipStream_t stream = at::cuda::getCurrentHIPStream();
#endif

// Run kernel
dim3 threads(32, 32);
Expand All @@ -138,5 +180,10 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr<scalar_t>(),
dg.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
#ifdef __NVCC__
THCudaCheck(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
THCudaCheck(hipGetLastError());
#endif
}
30 changes: 30 additions & 0 deletions mmcv/ops/csrc/pytorch/deform_conv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,23 @@ void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels,

deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
#ifdef __NVCC__
at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_im_, data_offset_, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels,
deformable_group, height_col, width_col, data_col_);
}));
#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}

void deformable_col2im(Tensor data_col, Tensor data_offset, const int channels,
Expand All @@ -58,13 +68,23 @@ void deformable_col2im(Tensor data_col, Tensor data_offset, const int channels,

deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
#ifdef __NVCC__
at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, channel_per_deformable_group, parallel_imgs,
deformable_group, height_col, width_col, grad_im_);
}));
#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}

void deformable_col2im_coord(
Expand All @@ -91,14 +111,24 @@ void deformable_col2im_coord(

deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
#ifdef __NVCC__
at::cuda::getCurrentCUDAStream()>>>(
#endif
#ifdef __HIP_PLATFORM_HCC__
at::cuda::getCurrentHIPStream()>>>(
#endif
num_kernels, data_col_, data_im_, data_offset_, channels, height,
width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
#ifdef __NVCC__
AT_CUDA_CHECK(cudaGetLastError());
#endif
#ifdef __HIP_PLATFORM_HCC__
AT_CUDA_CHECK(hipGetLastError());
#endif
}

void deform_conv_shape_check(Tensor input, Tensor offset, Tensor *gradOutput,
Expand Down
Loading