Skip to content

Commit e5bd90d

Browse files
jeffdaily, jameslamb, shiyu1994, and StrikerRUS
authored
[ROCm] add ROCm support (pt. 2) (#7039)
* [ROCm] re-add support for ROCm builds

  Previously #6086 added ROCm support, but after numerous rebases it lost critical changes. This PR restores the ROCm build. There are many source file changes, but most were automated using the following:

  ```bash
  for f in `grep -rl '#ifdef USE_CUDA'`
  do
      sed -i 's@#ifdef USE_CUDA@#if defined(USE_CUDA) || defined(USE_ROCM)@g' $f
  done

  for f in `grep -rl '#endif // USE_CUDA'`
  do
      sed -i 's@#endif // USE_CUDA@#endif // USE_CUDA || USE_ROCM@g' $f
  done
  ```

* Fix error in cpp_tests/test_arrow.cpp. (error: explicit specialization in non-namespace scope 'class ArrowChunkedArrayTest')
* update for ROCm 7 BC-breaking change to warpSize
* lint
* Revert "Fix error in cpp_tests/test_arrow.cpp." (this reverts commit e461e86)
* partial revert of 61ec4f1 — instead of replacing all #ifdef USE_CUDA, just add the USE_CUDA define to the ROCm build
* add --use-rocm option to build-python.sh
* fix cuda build missing CUDASUCCESS_OR_FATAL in vector_cudahost.h
* add rocm docs
* fix doc using pre-commit
* apply reviewer suggestions
* fix build-python.sh doc
* fix build for rocm 7.0

---------

Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
1 parent 569f89a commit e5bd90d

File tree

14 files changed

+162
-18
lines changed

14 files changed

+162
-18
lines changed

CMakeLists.txt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ if(USE_ROCM)
286286
endif()
287287
message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")
288288

289+
# Building for ROCm almost always means USE_CUDA.
290+
# Exceptions to this will be guarded by USE_ROCM.
291+
add_definitions(-DUSE_CUDA)
289292
add_definitions(-DUSE_ROCM)
290293
endif()
291294

@@ -473,10 +476,21 @@ set(
473476
src/cuda/cuda_algorithms.cu
474477
)
475478

476-
if(USE_CUDA)
479+
if(USE_CUDA OR USE_ROCM)
477480
list(APPEND LGBM_SOURCES ${LGBM_CUDA_SOURCES})
478481
endif()
479482

483+
if(USE_ROCM)
484+
set(CU_FILES "")
485+
foreach(file IN LISTS LGBM_CUDA_SOURCES)
486+
string(REGEX MATCH "\\.cu$" is_cu_file "${file}")
487+
if(is_cu_file)
488+
list(APPEND CU_FILES "${file}")
489+
endif()
490+
endforeach()
491+
set_source_files_properties(${CU_FILES} PROPERTIES LANGUAGE HIP)
492+
endif()
493+
480494
add_library(lightgbm_objs OBJECT ${LGBM_SOURCES})
481495

482496
if(BUILD_CLI)
@@ -629,6 +643,10 @@ if(USE_CUDA)
629643
endif()
630644
endif()
631645

646+
if(USE_ROCM)
647+
target_link_libraries(lightgbm_objs PUBLIC hip::host)
648+
endif()
649+
632650
if(WIN32)
633651
if(MINGW OR CYGWIN)
634652
target_link_libraries(lightgbm_objs PUBLIC ws2_32 iphlpapi)

build-python.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
# --precompile
5555
# Use precompiled library.
5656
# Only used with 'install' command.
57+
# --rocm
58+
# Compile ROCm version.
5759
# --time-costs
5860
# Compile version that outputs time costs for different internal routines.
5961
# --user
@@ -142,6 +144,9 @@ while [ $# -gt 0 ]; do
142144
--cuda)
143145
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_CUDA=ON"
144146
;;
147+
--rocm)
148+
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_ROCM=ON"
149+
;;
145150
--gpu)
146151
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_GPU=ON"
147152
;;

docs/Installation-Guide.rst

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,65 @@ macOS
749749

750750
The CUDA version is not supported on macOS.
751751

752+
Build ROCm Version
753+
~~~~~~~~~~~~~~~~~~
754+
755+
The `original GPU version <#build-gpu-version>`__ of LightGBM (``device_type=gpu``) is based on OpenCL.
756+
757+
The ROCm-based version (``device_type=cuda``) is a separate implementation; it reuses ``device_type=cuda`` as a convenience for users. Use this version in Linux environments with an AMD GPU.
758+
759+
Windows
760+
^^^^^^^
761+
762+
The ROCm version is not supported on Windows.
763+
Use the `GPU version <#build-gpu-version>`__ (``device_type=gpu``) for GPU acceleration on Windows.
764+
765+
Linux
766+
^^^^^
767+
768+
On Linux, a ROCm version of LightGBM can be built using
769+
770+
- **CMake**, **gcc** and **ROCm**;
771+
- **CMake**, **Clang** and **ROCm**.
772+
773+
Please refer to `the ROCm docs`_ for **ROCm** libraries installation.
774+
775+
After compilation the executable and ``.so`` files will be in ``LightGBM/`` folder.
776+
777+
gcc
778+
***
779+
780+
1. Install `CMake`_, **gcc** and **ROCm**.
781+
782+
2. Run the following commands:
783+
784+
.. code:: sh
785+
786+
git clone --recursive https://github.com/microsoft/LightGBM
787+
cd LightGBM
788+
cmake -B build -S . -DUSE_ROCM=ON
789+
cmake --build build -j4
790+
791+
Clang
792+
*****
793+
794+
1. Install `CMake`_, **Clang**, **OpenMP** and **ROCm**.
795+
796+
2. Run the following commands:
797+
798+
.. code:: sh
799+
800+
git clone --recursive https://github.com/microsoft/LightGBM
801+
cd LightGBM
802+
export CXX=clang++-14 CC=clang-14 # replace "14" with version of Clang installed on your machine
803+
cmake -B build -S . -DUSE_ROCM=ON
804+
cmake --build build -j4
805+
806+
macOS
807+
^^^^^
808+
809+
The ROCm version is not supported on macOS.
810+
752811
Build Java Wrapper
753812
~~~~~~~~~~~~~~~~~~
754813

@@ -1054,6 +1113,8 @@ gcc
10541113

10551114
.. _this detailed guide: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html
10561115

1116+
.. _the ROCm docs: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/
1117+
10571118
.. _following docs: https://github.com/google/sanitizers/wiki
10581119

10591120
.. _Ninja: https://ninja-build.org

docs/Parameters.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,15 @@ Core Parameters
264264

265265
- ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
266266

267-
- ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA
267+
- ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
268268

269269
- ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
270270

271271
- **Note**: it is recommended to use the smaller ``max_bin`` (e.g. 63) to get the better speed up
272272

273273
- **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
274274

275-
- **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support
275+
- **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
276276

277277
- ``seed`` :raw-html:`<a id="seed" title="Permalink to this parameter" href="#seed">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = int, aliases: ``random_seed``, ``random_state``
278278

docs/_static/js/script.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ $(() => {
2222
"#build-mpi-version",
2323
"#build-gpu-version",
2424
"#build-cuda-version",
25+
"#build-rocm-version",
2526
"#build-java-wrapper",
2627
"#build-python-package",
2728
"#build-r-package",

include/LightGBM/config.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,11 @@ struct Config {
246246
// alias = device
247247
// desc = device for the tree learning
248248
// desc = ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
249-
// desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA
249+
// desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
250250
// desc = ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
251251
// desc = **Note**: it is recommended to use the smaller ``max_bin`` (e.g. 63) to get the better speed up
252252
// desc = **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
253-
// desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support
253+
// desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
254254
std::string device_type = "cpu";
255255

256256
// [no-automatically-extract]

include/LightGBM/cuda/cuda_algorithms.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
#ifdef USE_CUDA
1111

12+
#ifndef USE_ROCM
1213
#include <cuda.h>
1314
#include <cuda_runtime.h>
15+
#endif
1416
#include <stdio.h>
1517

1618
#include <LightGBM/bin.h>

include/LightGBM/cuda/cuda_random.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
#ifdef USE_CUDA
99

10+
#ifndef USE_ROCM
1011
#include <cuda.h>
1112
#include <cuda_runtime.h>
13+
#endif
1214

1315
namespace LightGBM {
1416

include/LightGBM/cuda/cuda_rocm_interop.h

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,59 @@
77

88
#ifdef USE_CUDA
99

10-
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
11-
// ROCm doesn't have __shfl_down_sync, only __shfl_down without mask.
10+
#if defined(__HIP_PLATFORM_AMD__)
11+
12+
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
13+
#define atomicAdd_block atomicAdd
14+
15+
// hipify
16+
#include <hip/hip_runtime.h>
17+
#define cudaDeviceProp hipDeviceProp_t
18+
#define cudaDeviceSynchronize hipDeviceSynchronize
19+
#define cudaError_t hipError_t
20+
#define cudaFree hipFree
21+
#define cudaFreeHost hipFreeHost
22+
#define cudaGetDevice hipGetDevice
23+
#define cudaGetDeviceProperties hipGetDeviceProperties
24+
#define cudaGetErrorName hipGetErrorName
25+
#define cudaGetErrorString hipGetErrorString
26+
#define cudaGetLastError hipGetLastError
27+
#define cudaHostAlloc hipHostAlloc
28+
#define cudaHostAllocPortable hipHostAllocPortable
29+
#define cudaMalloc hipMalloc
30+
#define cudaMemcpy hipMemcpy
31+
#define cudaMemcpyAsync hipMemcpyAsync
32+
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
33+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
34+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
35+
#define cudaMemoryTypeHost hipMemoryTypeHost
36+
#define cudaMemset hipMemset
37+
#define cudaPointerAttributes hipPointerAttribute_t
38+
#define cudaPointerGetAttributes hipPointerGetAttributes
39+
#define cudaSetDevice hipSetDevice
40+
#define cudaStreamCreate hipStreamCreate
41+
#define cudaStreamDestroy hipStreamDestroy
42+
#define cudaStream_t hipStream_t
43+
#define cudaSuccess hipSuccess
44+
45+
// ROCm 7.0 did add __shfl_down_sync et al, but the following hack still works.
1246
// Since mask is full 0xffffffff, we can use __shfl_down instead.
1347
#define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset)
1448
#define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset)
15-
// ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch.
16-
#define WARPSIZE warpSize
17-
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
18-
#define atomicAdd_block atomicAdd
19-
#else
49+
50+
// warpSize is only allowed for device code.
51+
// HIP header used to define warpSize as a constexpr that was either 32 or 64
52+
// depending on the target device, and then always set it to 64 for host code.
53+
static inline constexpr int WARP_SIZE_INTERNAL() {
54+
#if defined(__GFX9__)
55+
return 64;
56+
#else // __GFX9__
57+
return 32;
58+
#endif // __GFX9__
59+
}
60+
#define WARPSIZE (WARP_SIZE_INTERNAL())
61+
62+
#else // __HIP_PLATFORM_AMD__
2063
// CUDA warpSize is not a constexpr, but always 32
2164
#define WARPSIZE 32
2265
#endif // defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)

include/LightGBM/cuda/cuda_utils.hu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88

99
#ifdef USE_CUDA
1010

11+
#if defined(USE_ROCM)
12+
#include <LightGBM/cuda/cuda_rocm_interop.h>
13+
#else
1114
#include <cuda.h>
1215
#include <cuda_runtime.h>
16+
#endif
1317
#include <stdio.h>
1418
#include <nccl.h>
1519

0 commit comments

Comments
 (0)