2 files changed: +7 −1

@@ -59,7 +59,7 @@ def get_extensions():
     if not torch.cuda.is_available():
         print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
-    if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available():
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
         print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
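
The one-line fix above addresses an operator-precedence bug: in Python, "and" binds tighter than "or", so the old condition parsed as CUDA_HOME is None or ((not IS_ROCM) and torch.cuda.is_available()) and could skip the extension build even when a CUDA toolkit was present. A minimal sketch of the two conditions (the variable values are illustrative, not taken from the PR):

# With a CUDA toolkit installed and no ROCm, the old condition is still
# True and wrongly reports the toolkit as missing.
CUDA_HOME = "/usr/local/cuda"  # CUDA toolkit present (illustrative value)
ROCM_HOME = None               # no ROCm install
IS_ROCM = False
cuda_available = True          # stands in for torch.cuda.is_available()

# Old: "and" binds tighter than "or", so this parses as
# CUDA_HOME is None or ((not IS_ROCM) and cuda_available)
old = CUDA_HOME is None or not IS_ROCM and cuda_available
print(old)   # True -> build skipped despite CUDA being available

# New: skip only when neither toolkit is found
new = (CUDA_HOME is None and ROCM_HOME is None) and cuda_available
print(new)   # False -> extensions build as expected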

torchao/csrc/cuda/tensor_core_tiled_layout:

@@ -167,6 +167,7 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
+#if defined(USE_ROCM)
   __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
   __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));

@@ -177,6 +178,11 @@ __global__ void _dequantize_int4_kernel(
     scale2 = __bfloat162bfloat162(pSZ[0]);
     zero2 = __bfloat162bfloat162(pSZ[1]);
   }
+#else
+  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+#endif

 #pragma unroll
   for (int i = 0; i < 4; i++) {
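
This change makes the scale/zero load path conditional: the existing ROCm branch initializes scale2 and zero2 to defaults and overwrites them inside a guarded block, while the new #else branch restores the direct load of the (scale, zero) pair for the CUDA build. As the comment in the first hunk notes, all b values within a 16x16 tile fall in the same quantization group, so only one pair is loaded per loop. A small NumPy sketch of that group-wise affine dequantization (the shapes and the exact affine form used here are assumptions for illustration, not the kernel's code):

# One (scale, zero) pair is shared by every value in a quantization group,
# analogous to the kernel's scales_and_zeros.value()[qgroup][n0] load.
import numpy as np

group_size = 128
n_groups = 2
q = np.random.randint(0, 16, size=(n_groups, group_size)).astype(np.float32)

# One (scale, zero) pair per group (illustrative values).
scales_and_zeros = np.array([[0.05, -0.4], [0.10, -0.8]], dtype=np.float32)
scale = scales_and_zeros[:, 0:1]  # broadcast one scale across each group
zero = scales_and_zeros[:, 1:2]   # broadcast one zero across each group

dequant = q * scale + zero        # same pair applied to all values in a group
print(dequant.shape)              # (2, 128)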