Commit 735570e

Merge pull request #4 from lcskrishna/rocm_enablement

Fixes builds for non-rocm.

2 parents: 612ad14 + bbf5a72

File tree: 2 files changed (+7 −1 lines)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def get_extensions():
 
     if not torch.cuda.is_available():
         print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
-    if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available():
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
         print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
 
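Why the old guard broke non-ROCm builds: Python binds `and` tighter than `or`, so the removed condition parses as `(CUDA_HOME is None) or ((not IS_ROCM) and torch.cuda.is_available())`, which is true on any CUDA machine that is not a ROCm build, and the extension compile gets skipped. A minimal standalone sketch of the boolean logic, not the real setup.py — the names are stubbed here, and `IS_ROCM` is assumed to be a ROCm-build flag computed elsewhere in the file:

# sketch.py -- hypothetical demo of the guard's logic with stubbed values.
# Scenario: NVIDIA box with the CUDA toolkit installed, no ROCm.
CUDA_HOME = "/usr/local/cuda"   # toolkit found
ROCM_HOME = None                # no ROCm install
IS_ROCM = False                 # assumed flag: "this is a ROCm build"
cuda_available = True           # stand-in for torch.cuda.is_available()

# Old condition: `and` binds tighter than `or`, so this parses as
# (CUDA_HOME is None) or ((not IS_ROCM) and cuda_available)
old = CUDA_HOME is None or not IS_ROCM and cuda_available
print(old)   # True -> extensions wrongly skipped on a healthy CUDA setup

# Fixed condition: only skip when *neither* toolkit is present.
new = (CUDA_HOME is None and ROCM_HOME is None) and cuda_available
print(new)   # False -> CUDA extensions build as expected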

torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu

Lines changed: 6 additions & 0 deletions
@@ -167,6 +167,7 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
+#if defined(USE_ROCM)
   __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
   __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
 
@@ -177,6 +178,11 @@ __global__ void _dequantize_int4_kernel(
     scale2 = __bfloat162bfloat162(pSZ[0]);
     zero2 = __bfloat162bfloat162(pSZ[1]);
   }
+#else
+  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+#endif
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
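Both branches of the restored #if/#else build the per-group scale2/zero2 pair with the cuda_bf16.h intrinsic __bfloat162bfloat162, which broadcasts a single __nv_bfloat16 into both halves of a __nv_bfloat162 so one scale and zero can be applied to packed values two at a time. A standalone sketch of that broadcast — a hypothetical demo file, not part of this patch (compile with nvcc -arch=sm_80):

// splat_demo.cu -- hypothetical demo, not from the patch: shows how
// __bfloat162bfloat162 broadcasts a scalar bf16 (like pSZ[0]) into both
// lanes of a __nv_bfloat162, the way scale2/zero2 are built in the kernel.
#include <cuda_bf16.h>
#include <cstdio>

__global__ void splat_demo(float* out) {
    __nv_bfloat16 s = __float2bfloat16(0.5f);     // stand-in for pSZ[0]
    __nv_bfloat162 s2 = __bfloat162bfloat162(s);  // both halves now hold s
    out[0] = __bfloat162float(__low2bfloat16(s2));
    out[1] = __bfloat162float(__high2bfloat16(s2));
}

int main() {
    float* d = nullptr;
    cudaMalloc(&d, 2 * sizeof(float));
    splat_demo<<<1, 1>>>(d);
    float h[2];
    cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    printf("%f %f\n", h[0], h[1]);                // 0.500000 0.500000
    cudaFree(d);
    return 0;
}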
