Skip to content

Commit 19b014c

Browse files
authored
Update ptx_mma.cuh (issue#998)
Issue #998
1 parent 6d4d21d commit 19b014c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

torchao/csrc/cuda/fp6_llm/ptx_mma.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@
3434
#define PTX_MMA_CUH
3535

3636
#include <cuda.h>
37-
#include <cuda_fp16.h>
37+
#include <cuda_bf16.h> // Include BF16 header
3838
#include <cuda_runtime.h>
3939

4040
#include <assert.h>
4141
#include "configs.h"
4242

4343
// MODIFICATION NOTE: to support MSVC
4444
// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
45-
// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
45+
// - __nv_bfloat16 __restrict__ (*read_SPTR) is changed to __nv_bfloat16 (* __restrict__ read_SPTR)
4646
template <typename TilingConfig>
4747
__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4],
48-
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
48+
__nv_bfloat16 (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
4949
int slice_id) {
5050
#ifdef DEBUG_MODE
5151
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
@@ -82,19 +82,19 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
8282
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
8383
: "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
8484
: "r"(smem_local_ptr));
85-
smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
85+
smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(__nv_bfloat16);
8686
}
8787
}
8888
}
8989

9090
// MODIFICATION NOTE: to support MSVC, the function signature is changed from
9191
// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
9292
__device__ __forceinline__ void
93-
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
93+
MMA_BF16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
9494
{
9595
#if __CUDA_ARCH__ == 750
9696
// m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
97-
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
97+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32"
9898
"{ %0, %1, %2, %3},"
9999
"{ %4, %5},"
100100
"{ %6 },"
@@ -103,7 +103,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
103103
: "r"(a[0]), "r"(a[1]),
104104
"r"(b[0]),
105105
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
106-
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
106+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32"
107107
"{ %0, %1, %2, %3},"
108108
"{ %4, %5},"
109109
"{ %6 },"
@@ -114,7 +114,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
114114
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
115115

116116
#else
117-
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
117+
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
118118
"{ %0, %1, %2, %3},"
119119
"{ %4, %5, %6, %7 },"
120120
"{ %8, %9 },"

0 commit comments

Comments
 (0)