3434#define PTX_MMA_CUH
3535
3636#include < cuda.h>
37- #include < cuda_fp16 .h>
37+ #include < cuda_bf16 .h> // Include BF16 header
3838#include < cuda_runtime.h>
3939
4040#include < assert.h>
4141#include " configs.h"
4242
4343// MODIFICATION NOTE: to support MSVC
4444// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
45- // - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
45+ // - __nv_bfloat16 __restrict__ (*read_SPTR) is changed to __nv_bfloat16 (* __restrict__ read_SPTR)
4646template <typename TilingConfig>
4747__device__ __forceinline__ void B_FromSharedToReg (uint32_t (* __restrict__ Reg)[4],
48- half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
48+ __nv_bfloat16 (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
4949 int slice_id) {
5050 #ifdef DEBUG_MODE
5151 static_assert ( (TilingConfig::WARP_COL_MMA_TENSORS==1 ) || (TilingConfig::WARP_COL_MMA_TENSORS%2 ==0 ) );
@@ -82,19 +82,19 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
8282 asm volatile (" ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n "
8383 : " =r" (Reg[i][0 ]), " =r" (Reg[i][1 ]), " =r" (Reg[i][2 ]), " =r" (Reg[i][3 ])
8484 : " r" (smem_local_ptr));
85- smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof (half );
85+ smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof (__nv_bfloat16 );
8686 }
8787 }
8888}
8989
9090// MODIFICATION NOTE: to support MSVC, the function signature is changed from
9191// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
9292__device__ __forceinline__ void
93- MMA_FP16_M16N8K16 (uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
93+ MMA_BF16_M16N8K16 (uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
9494{
9595 #if __CUDA_ARCH__ == 750
9696 // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
97- asm volatile (" mma.sync.aligned.m16n8k8.row.col.f32.f16.f16 .f32"
97+ asm volatile (" mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16 .f32"
9898 " { %0, %1, %2, %3},"
9999 " { %4, %5},"
100100 " { %6 },"
@@ -103,7 +103,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
103103 : " r" (a[0 ]), " r" (a[1 ]),
104104 " r" (b[0 ]),
105105 " r" (c[0 ]), " r" (c[1 ]), " r" (c[2 ]), " r" (c[3 ]));
106- asm volatile (" mma.sync.aligned.m16n8k8.row.col.f32.f16.f16 .f32"
106+ asm volatile (" mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16 .f32"
107107 " { %0, %1, %2, %3},"
108108 " { %4, %5},"
109109 " { %6 },"
@@ -114,7 +114,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
114114 " r" (c[0 ]), " r" (c[1 ]), " r" (c[2 ]), " r" (c[3 ]));
115115
116116 #else
117- asm volatile (" mma.sync.aligned.m16n8k16.row.col.f32.f16.f16 .f32"
117+ asm volatile (" mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16 .f32"
118118 " { %0, %1, %2, %3},"
119119 " { %4, %5, %6, %7 },"
120120 " { %8, %9 },"
0 commit comments