
Make Quant-LLM compatible with BF16 #998

@gau-nernst

Description

Quant-LLM code: https://github.com/pytorch/ao/tree/main/torchao/csrc/cuda/fp6_llm

Currently the Quant-LLM kernel (backing FPx in torchao) only works with FP16. This creates a small divergence from other quantization methods, which all work with BF16. Since all recent models are trained and released in BF16, adding BF16 support could potentially improve accuracy for FPx models.

I might be over-simplifying, but I think it's just a matter of modifying the dequant logic and the MMA instructions (as well as updating the dtypes in function signatures appropriately). The relevant dequant logic:

```cuda
template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
    // FP16 has a 5-bit exponent field; align the FPx exponent+mantissa to it.
    constexpr int RIGHT_SHIFT = 5 - EXPONENT;
    // Build a mask selecting the exponent+mantissa bits of two packed 16-bit lanes.
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> (EXPONENT + MANTISSA);
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK  = MASK3 | MASK3 >> 16;
    // Keep the sign bits of both 16-bit lanes in place, then merge in the
    // shifted exponent+mantissa bits.
    *Out1 = *In & 0x80008000;
    *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT;
    // Advance to the next pair of packed values and repeat.
    *In = (*In) << 8;
    *Out2 = *In & 0x80008000;
    *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT;
}
```

And the corresponding FP16 MMA:

```cuda
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
             "{ %0, %1, %2, %3 },"
             "{ %4, %5, %6, %7 },"
             "{ %8, %9 },"
             "{ %10, %11, %12, %13 };"
             : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
             : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
               "r"(b[0]), "r"(b[1]),
               "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
```
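For the MMA side, PTX has a BF16 variant of this instruction with the same m16n8k16 shape and FP32 accumulators (sm_80+), so a sketch of the change, assuming the surrounding register setup stays the same, is just the dtype in the instruction string:

```cuda
// Hedged sketch: BF16 variant of the same m16n8k16 MMA (requires sm_80+).
// Fragment layouts match the f16 version; only the input dtype differs.
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
             "{ %0, %1, %2, %3 },"
             "{ %4, %5, %6, %7 },"
             "{ %8, %9 },"
             "{ %10, %11, %12, %13 };"
             : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
             : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
               "r"(b[0]), "r"(b[1]),
               "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
```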

cc @msaroufim @HDCharles

I might try to do it myself, but I think it would also make an interesting good-first-issue task. @tobiasvanderwerff Would you be interested?
