Skip to content

Commit 61ec4f1

Browse files
committed
[ROCm] re-add support for ROCm builds
Previously #6086 added ROCm support but after numerous rebases it lost critical changes. This PR restores the ROCm build. There are many source file changes but most were automated using the following: ```bash for f in `grep -rl '#ifdef USE_CUDA'` do sed -i 's@#ifdef USE_CUDA@#if defined(USE_CUDA) || defined(USE_ROCM)@g' $f done for f in `grep -rl '#endif // USE_CUDA'` do sed -i 's@#endif // USE_CUDA@#endif // USE_CUDA || USE_ROCM@g' $f done ```
1 parent 336a77d commit 61ec4f1

File tree

82 files changed

+316
-256
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+316
-256
lines changed

CMakeLists.txt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ endif()
3636

3737
project(lightgbm LANGUAGES C CXX)
3838

39-
if(USE_CUDA)
39+
if(USE_CUDA OR USE_ROCM)
4040
set(CMAKE_CXX_STANDARD 17)
4141
elseif(BUILD_CPP_TEST)
4242
set(CMAKE_CXX_STANDARD 14)
@@ -480,10 +480,21 @@ set(
480480
src/cuda/cuda_algorithms.cu
481481
)
482482

483-
if(USE_CUDA)
483+
if(USE_CUDA OR USE_ROCM)
484484
list(APPEND LGBM_SOURCES ${LGBM_CUDA_SOURCES})
485485
endif()
486486

487+
if(USE_ROCM)
488+
set(CU_FILES "")
489+
foreach(file IN LISTS LGBM_CUDA_SOURCES)
490+
string(REGEX MATCH "\\.cu$" is_cu_file "${file}")
491+
if(is_cu_file)
492+
list(APPEND CU_FILES "${file}")
493+
endif()
494+
endforeach()
495+
set_source_files_properties(${CU_FILES} PROPERTIES LANGUAGE HIP)
496+
endif()
497+
487498
add_library(lightgbm_objs OBJECT ${LGBM_SOURCES})
488499

489500
if(BUILD_CLI)
@@ -632,6 +643,10 @@ if(USE_CUDA)
632643
endif()
633644
endif()
634645

646+
if(USE_ROCM)
647+
target_link_libraries(lightgbm_objs PUBLIC hip::host)
648+
endif()
649+
635650
if(WIN32)
636651
if(MINGW OR CYGWIN)
637652
target_link_libraries(lightgbm_objs PUBLIC ws2_32 iphlpapi)

include/LightGBM/bin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,13 +600,13 @@ class MultiValBin {
600600

601601
virtual MultiValBin* Clone() = 0;
602602

603-
#ifdef USE_CUDA
603+
#if defined(USE_CUDA) || defined(USE_ROCM)
604604
virtual const void* GetRowWiseData(uint8_t* bit_type,
605605
size_t* total_size,
606606
bool* is_sparse,
607607
const void** out_data_ptr,
608608
uint8_t* data_ptr_bit_type) const = 0;
609-
#endif // USE_CUDA
609+
#endif // USE_CUDA || USE_ROCM
610610
};
611611

612612
inline uint32_t BinMapper::ValueToBin(double value) const {

include/LightGBM/cuda/cuda_algorithms.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#ifndef LIGHTGBM_CUDA_CUDA_ALGORITHMS_HPP_
88
#define LIGHTGBM_CUDA_CUDA_ALGORITHMS_HPP_
99

10-
#ifdef USE_CUDA
10+
#if defined(USE_CUDA) || defined(USE_ROCM)
1111

12+
#if defined(USE_CUDA)
1213
#include <cuda.h>
1314
#include <cuda_runtime.h>
15+
#endif
1416
#include <stdio.h>
1517

1618
#include <LightGBM/bin.h>
@@ -619,5 +621,5 @@ __device__ VAL_T PercentileDevice(const VAL_T* values,
619621

620622
} // namespace LightGBM
621623

622-
#endif // USE_CUDA
624+
#endif // USE_CUDA || USE_ROCM
623625
#endif // LIGHTGBM_CUDA_CUDA_ALGORITHMS_HPP_

include/LightGBM/cuda/cuda_column_data.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Licensed under the MIT License. See LICENSE file in the project root for license information.
44
*/
55

6-
#ifdef USE_CUDA
6+
#if defined(USE_CUDA) || defined(USE_ROCM)
77

88
#ifndef LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_
99
#define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_
@@ -139,4 +139,4 @@ class CUDAColumnData {
139139

140140
#endif // LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_
141141

142-
#endif // USE_CUDA
142+
#endif // USE_CUDA || USE_ROCM

include/LightGBM/cuda/cuda_metadata.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Licensed under the MIT License. See LICENSE file in the project root for license information.
44
*/
55

6-
#ifdef USE_CUDA
6+
#if defined(USE_CUDA) || defined(USE_ROCM)
77

88
#ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_
99
#define LIGHTGBM_CUDA_CUDA_METADATA_HPP_
@@ -55,4 +55,4 @@ class CUDAMetadata {
5555

5656
#endif // LIGHTGBM_CUDA_CUDA_METADATA_HPP_
5757

58-
#endif // USE_CUDA
58+
#endif // USE_CUDA || USE_ROCM

include/LightGBM/cuda/cuda_metric.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#ifndef LIGHTGBM_CUDA_CUDA_METRIC_HPP_
88
#define LIGHTGBM_CUDA_CUDA_METRIC_HPP_
99

10-
#ifdef USE_CUDA
10+
#if defined(USE_CUDA) || defined(USE_ROCM)
1111

1212
#include <LightGBM/cuda/cuda_utils.hu>
1313
#include <LightGBM/metric.h>
@@ -39,6 +39,6 @@ class CUDAMetricInterface: public HOST_METRIC {
3939

4040
} // namespace LightGBM
4141

42-
#endif // USE_CUDA
42+
#endif // USE_CUDA || USE_ROCM
4343

4444
#endif // LIGHTGBM_CUDA_CUDA_METRIC_HPP_

include/LightGBM/cuda/cuda_objective_function.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#ifndef LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_
88
#define LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_
99

10-
#ifdef USE_CUDA
10+
#if defined(USE_CUDA) || defined(USE_ROCM)
1111

1212
#include <LightGBM/cuda/cuda_utils.hu>
1313
#include <LightGBM/objective_function.h>
@@ -81,6 +81,6 @@ class CUDAObjectiveInterface: public HOST_OBJECTIVE {
8181

8282
} // namespace LightGBM
8383

84-
#endif // USE_CUDA
84+
#endif // USE_CUDA || USE_ROCM
8585

8686
#endif // LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_

include/LightGBM/cuda/cuda_random.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
#ifndef LIGHTGBM_CUDA_CUDA_RANDOM_HPP_
66
#define LIGHTGBM_CUDA_CUDA_RANDOM_HPP_
77

8-
#ifdef USE_CUDA
8+
#if defined(USE_CUDA) || defined(USE_ROCM)
99

10+
#if defined(USE_CUDA)
1011
#include <cuda.h>
1112
#include <cuda_runtime.h>
13+
#endif
1214

1315
namespace LightGBM {
1416

@@ -69,6 +71,6 @@ class CUDARandom {
6971

7072
} // namespace LightGBM
7173

72-
#endif // USE_CUDA
74+
#endif // USE_CUDA || USE_ROCM
7375

7476
#endif // LIGHTGBM_CUDA_CUDA_RANDOM_HPP_
Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*!
22
* Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
33
*/
4-
#ifdef USE_CUDA
4+
#if defined(USE_CUDA) || defined(USE_ROCM)
55

66
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
77
// ROCm doesn't have __shfl_down_sync, only __shfl_down without mask.
@@ -12,9 +12,38 @@
1212
#define WARPSIZE warpSize
1313
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
1414
#define atomicAdd_block atomicAdd
15-
#else
15+
// hipify
16+
#include <hip/hip_runtime.h>
17+
#define cudaDeviceProp hipDeviceProp_t
18+
#define cudaDeviceSynchronize hipDeviceSynchronize
19+
#define cudaError_t hipError_t
20+
#define cudaFree hipFree
21+
#define cudaFreeHost hipFreeHost
22+
#define cudaGetDevice hipGetDevice
23+
#define cudaGetDeviceProperties hipGetDeviceProperties
24+
#define cudaGetErrorName hipGetErrorName
25+
#define cudaGetErrorString hipGetErrorString
26+
#define cudaGetLastError hipGetLastError
27+
#define cudaHostAlloc hipHostAlloc
28+
#define cudaHostAllocPortable hipHostAllocPortable
29+
#define cudaMalloc hipMalloc
30+
#define cudaMemcpy hipMemcpy
31+
#define cudaMemcpyAsync hipMemcpyAsync
32+
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
33+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
34+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
35+
#define cudaMemoryTypeHost hipMemoryTypeHost
36+
#define cudaMemset hipMemset
37+
#define cudaPointerAttributes hipPointerAttribute_t
38+
#define cudaPointerGetAttributes hipPointerGetAttributes
39+
#define cudaSetDevice hipSetDevice
40+
#define cudaStreamCreate hipStreamCreate
41+
#define cudaStreamDestroy hipStreamDestroy
42+
#define cudaStream_t hipStream_t
43+
#define cudaSuccess hipSuccess
44+
#else // __HIP_PLATFORM_AMD__ || __HIP__
1645
// CUDA warpSize is not a constexpr, but always 32
1746
#define WARPSIZE 32
1847
#endif
1948

20-
#endif
49+
#endif // USE_CUDA || USE_ROCM

include/LightGBM/cuda/cuda_row_data.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Licensed under the MIT License. See LICENSE file in the project root for license information.
44
*/
55

6-
#ifdef USE_CUDA
6+
#if defined(USE_CUDA) || defined(USE_ROCM)
77

88
#ifndef LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_
99
#define LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_
@@ -177,4 +177,4 @@ class CUDARowData {
177177
} // namespace LightGBM
178178
#endif // LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_
179179

180-
#endif // USE_CUDA
180+
#endif // USE_CUDA || USE_ROCM

0 commit comments

Comments
 (0)