Issue #998 Fix #1074

Closed (wants to merge 2 commits)
16 changes: 8 additions & 8 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
@@ -34,18 +34,18 @@
#define PTX_MMA_CUH

#include <cuda.h>
-#include <cuda_fp16.h>
+#include <cuda_bf16.h> // Include BF16 header
#include <cuda_runtime.h>

#include <assert.h>
#include "configs.h"

// MODIFICATION NOTE: to support MSVC
// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
-// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
+// - __nv_bfloat16 __restrict__ (*read_SPTR) is changed to __nv_bfloat16 (* __restrict__ read_SPTR)
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4],
-half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
+__nv_bfloat16 (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
#ifdef DEBUG_MODE
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
@@ -82,19 +82,19 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
: "r"(smem_local_ptr));
-smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
+smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(__nv_bfloat16);
}
}
}

// MODIFICATION NOTE: to support MSVC, the function signature is changed from
// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
__device__ __forceinline__ void
-MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
+MMA_BF16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
#if __CUDA_ARCH__ == 750
// m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
@@ -103,7 +103,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
: "r"(a[0]), "r"(a[1]),
"r"(b[0]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
@@ -114,7 +114,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));

#else
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5, %6, %7 },"
"{ %8, %9 },"
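
A note on the register layout these changes preserve: every uint32_t operand that ldmatrix fills and mma.sync consumes packs two 16-bit values, and swapping half for __nv_bfloat16 keeps the size and layout identical. A minimal host-side sketch of that packing (my own illustration, not part of this PR; compile with nvcc):

#include <cuda_bf16.h>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    // Two bf16 values occupy one 32-bit MMA register operand:
    // .x lands in bits [15:0], .y in bits [31:16] (little-endian).
    __nv_bfloat162 pair;
    pair.x = __float2bfloat16(1.5f);   // 0x3FC0
    pair.y = __float2bfloat16(-2.0f);  // 0xC000
    uint32_t packed;
    std::memcpy(&packed, &pair, sizeof(packed));
    printf("packed = 0x%08x\n", packed);  // prints 0xc0003fc0
    return 0;
}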
52 changes: 17 additions & 35 deletions torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh
@@ -1,25 +1,8 @@
-// Copyright 2024 FP6-LLM authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh
-// To support MSVC, all instances of u_int32_t are changed to uint32_t.

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH

#include <cuda.h>
-#include <cuda_fp16.h>
+#include <cuda_bf16.h> // Include BF16 header
#include <cuda_runtime.h>

/*
@@ -28,33 +11,32 @@
* Note: Simplified Exponent calculation is applied.
*/
template<int EXPONENT, int MANTISSA>
-__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
-//
+__device__ __forceinline__ void FPx_BF16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
constexpr int RIGHT_SHIFT = 5 - EXPONENT;
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | MASK3 >> 16;
-//
+
*Out1 = *In & 0x80008000;
*Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
-//
+
*In = (*In) << 8;
*Out2 = *In & 0x80008000;
*Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
}

template<int EXPONENT, int MANTISSA>
-__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
+__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) {
constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));
constexpr int BIAS = int(1) << BIAS_OFFSET;
-//
-half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
-half* FP16_2 = FP16_1 + 1;
+
+__nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair);
+__nv_bfloat16* BF16_2 = BF16_1 + 1;
uint32_t output;
-half* output_half_ptr = reinterpret_cast<half*>(&output);
-output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);
-output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale);
+__nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output);
+output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(1.0f*BIAS)), Scale);
+output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(1.0f*BIAS)), Scale);
return output;
}
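
MultScale folds the FP6-to-BF16 exponent-bias correction (the 1.0f*BIAS factor) and the per-channel scale into two bf16 multiplies. One caveat worth keeping in mind: __hmul on __nv_bfloat16 is a device intrinsic that requires compute capability 8.0 or newer. A standalone kernel exercising the same scale-multiply (my own sketch, not from this PR):

#include <cuda_bf16.h>
#include <cstdio>

__global__ void scale_pair(__nv_bfloat16* buf, __nv_bfloat16 scale) {
    buf[threadIdx.x] = __hmul(buf[threadIdx.x], scale);  // bf16 * bf16, sm_80+
}

int main() {
    __nv_bfloat16* buf;
    cudaMallocManaged(&buf, 2 * sizeof(__nv_bfloat16));
    buf[0] = __float2bfloat16(1.5f);
    buf[1] = __float2bfloat16(-2.0f);
    scale_pair<<<1, 2>>>(buf, __float2bfloat16(4.0f));
    cudaDeviceSynchronize();
    printf("%f %f\n", __bfloat162float(buf[0]), __bfloat162float(buf[1]));  // 6.0 -8.0
    cudaFree(buf);
    return 0;
}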

@@ -77,7 +59,7 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)
uint32_t *Frag_PTR_1bit = read_RPTR_1bit;
uint32_t *Frag_PTR_2bit = read_RPTR_2bit;
uint32_t *Frag_PTR_4bit = read_RPTR_4bit;
-half *Scale_RPTR = reinterpret_cast<half*>(Scales);
+__nv_bfloat16 *Scale_RPTR = reinterpret_cast<__nv_bfloat16*>(Scales);
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
#pragma unroll(8)
for(int i=0; i<8; i++) {
@@ -106,13 +88,13 @@
}
//
uint32_t out1, out2;
-FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
+FPx_BF16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
//
-*OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0] ); // Multiply FP16 scales
+*OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0] ); // Multiply BF16 scales
OutputRegs += 1;
-*OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]); // Multiply FP16 scales
+*OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]); // Multiply BF16 scales
OutputRegs += 1;
-// Updating offset for FP16 scales for every two iterations
+// Updating offset for BF16 scales for every two iterations
if(i%2==1) Scale_RPTR += 2;
}

@@ -121,7 +103,7 @@
/*
*
*/
-__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) {
+__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, __nv_bfloat16* WARP_SPTR_Scales) {
int lane_id = threadIdx.x % WARP_SIZE;
uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);
uint32_t tmpReg = SPTR_uint[lane_id];
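
For reference, the dequant path above unpacks FP6 words into bf16 pairs and scales them before the MMA consumes the registers. A scalar host-side decoder of what one FP6 value represents (my own illustration, assuming an IEEE-style layout with sign/EXPONENT/MANTISSA fields and normal numbers only; the kernel's simplified-exponent trick and subnormals are out of scope):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode one FPx value (1 sign bit, EXPONENT exponent bits, MANTISSA mantissa
// bits) to float, assuming normal numbers and a standard bias of 2^(E-1)-1.
template <int EXPONENT, int MANTISSA>
float fpx_to_float(uint8_t bits) {
    const int total = 1 + EXPONENT + MANTISSA;
    const int sign = (bits >> (total - 1)) & 1;
    const int exp  = (bits >> MANTISSA) & ((1 << EXPONENT) - 1);
    const int man  = bits & ((1 << MANTISSA) - 1);
    const int bias = (1 << (EXPONENT - 1)) - 1;
    float value = std::ldexp(1.0f + man / float(1 << MANTISSA), exp - bias);
    return sign ? -value : value;
}

int main() {
    // FP6 E3M2: bits 0'011'01 -> (1 + 1/4) * 2^(3-3) = 1.25
    printf("%f\n", fpx_to_float<3, 2>(0b001101));  // prints 1.250000
    return 0;
}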