Skip to content

Commit c80ea15

Browse files
committed
Add ROCm support
1 parent 7d8ca43 commit c80ea15

11 files changed

+83
-3
lines changed

cuda_ext.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
library_dir = "../exllama/"
1111
extension_name = "exllama_ext"
1212

13+
if torch.version.hip:
14+
# FIXME: To build, I had to comment "flags += ['-fno-gpu-rdc']" in torch/utils/cpp_extension.py.
15+
# I am not sure if it's possible to find a way to build without editing that file.
16+
# If building without gpu-rdc, build will error with "lld: error: undefined hidden symbol: __llvm_amdgcn_rcp_f16".
17+
extra_cuda_cflags= ["-U__HIP_NO_HALF_CONVERSIONS__", "-fgpu-rdc"]
18+
else:
19+
extra_cuda_cflags = []
20+
1321
exllama_ext = load(
1422
name = extension_name,
1523
sources = [
@@ -21,6 +29,7 @@
2129
os.path.join(library_dir, "exllama_ext/q4v2_sequential.cu"),
2230
os.path.join(library_dir, "exllama_ext/rms_norm.cu")
2331
],
32+
extra_cuda_cflags = extra_cuda_cflags
2433
# verbose = True,
2534
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
2635
)

exllama_ext/column_remap.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _column_remap_h
22
#define _column_remap_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t column_remap_cuda

exllama_ext/cuda_compat.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ __device__ __forceinline__ void atomicAdd_half(half* address, half val)
2222
while (assumed != old);
2323
}
2424

25-
#ifdef __CUDA_ARCH__
26-
#if __CUDA_ARCH__ < 700
25+
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
26+
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
2727

2828
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
2929

exllama_ext/exllama_ext.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
// Python is super tricky, so in place of proper exceptions, CUDA functions return with a cudaError_t which we can
2020
// parse and dump to the console.
2121

22+
#if defined(USE_ROCM)
23+
// FIXME: Get aborted, here the stacktrace:
24+
// #0 in ?? () from /usr/lib/libc.so.6
25+
// #1 in raise () from /usr/lib/libc.so.6
26+
// #2 in abort () from /usr/lib/libc.so.6
27+
// #3 in amd::report_fatal(char const*, int, char const*) () from torch/lib/libamdhip64.so
28+
// #4 in hip::DeviceFunc::DeviceFunc(std::basic_string<char, std::char_traits<char>, std::allocator<char> >, ihipModule_t*) ()
29+
// from torch/lib/libamdhip64.so
30+
// #5 in hip::Function::getStatFunc(ihipModuleSymbol_t**, int) () from torch/lib/libamdhip64.so
31+
// #6 in hip::StatCO::getStatFunc(ihipModuleSymbol_t**, void const*, int) () from torch/lib/libamdhip64.so
32+
// #7 in ihipLaunchKernel(void const*, dim3, dim3, void**, unsigned long, ihipStream_t*, ihipEvent_t*, ihipEvent_t*, int) ()
33+
// from torch/lib/libamdhip64.so
34+
// #8 in hipLaunchKernel_common () from torch/lib/libamdhip64.so
35+
// #9 in hipLaunchKernel () from torch/lib/libamdhip64.so
36+
#define _cuda_raise(fn)
37+
#else
2238
#define _cuda_raise(fn) \
2339
do { \
2440
cudaError_t _cuda_err_temp; \
@@ -40,6 +56,7 @@ do { \
4056
} \
4157
} \
4258
} while(false)
59+
#endif
4360

4461

4562
void q4v2_matmul

exllama_ext/matrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
#ifndef _matrix_h
22
#define _matrix_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#else
48
#include <cuda_runtime.h>
59
#include <cuda_fp16.h>
10+
#endif
611

712
class MatrixView_half
813
{

exllama_ext/q4v2_matmul.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_matmul_h
22
#define _q4v2_matmul_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713
#include <cstdio>
814

exllama_ext/q4v2_mlp.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_mlp_h
22
#define _q4v2_mlp_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t q4v2_mlp_cuda

exllama_ext/q4v2_recons.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_recons_h
22
#define _q4v2_recons_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t q4v2_recons_cuda

exllama_ext/q4v2_sequential.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_sequential_h
22
#define _q4v2_sequential_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713
#include <cstdio>
814

exllama_ext/rms_norm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _rms_norm_h
22
#define _rms_norm_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t rms_norm_cuda

exllama_ext/util.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
#ifndef _util_h
22
#define _util_h
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaDeviceSynchronize hipDeviceSynchronize
8+
#define cudaError_t hipError_t
9+
#define cudaMalloc hipMalloc
10+
#define cudaMemcpy hipMemcpy
11+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
12+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
13+
#define cudaSuccess hipSuccess
14+
#define cudaUnspecified hipErrorUnknown
15+
#else
416
#include <cuda_runtime.h>
517
#include <cuda_fp16.h>
18+
#define cudaUnspecified cudaErrorApiFailureBase
19+
#endif
620
#include <cstdint>
721
#include <cstdio>
822

9-
#define cudaUnspecified cudaErrorApiFailureBase
1023

1124
// React to failure on return code != cudaSuccess
1225

0 commit comments

Comments
 (0)