Quant-LLM code: https://github.com/pytorch/ao/tree/main/torchao/csrc/cuda/fp6_llm
Currently the Quant-LLM kernel (backing FPx in torchao) only works with FP16. This creates a small divergence from the other quantization methods, which all work with BF16. Since most recent models are trained and released in BF16, adding BF16 support would potentially improve accuracy for FPx models.
I might be over-simplifying, but I think it's just a matter of modifying the dequant logic and the MMA instructions (as well as updating the dtype in the function signatures appropriately).
ao/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh (lines 30 to 45 in 09b8b3c):

```cuda
template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
    //
    constexpr int RIGHT_SHIFT = 5 - EXPONENT;
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK  = MASK3 | MASK3 >> 16;
    //
    *Out1  = *In & 0x80008000;
    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
    //
    *In    = (*In) << 8;
    *Out2  = *In & 0x80008000;
    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
}
```
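For a rough idea of the dequant change: BF16 is laid out as 1 sign / 8 exponent / 7 mantissa bits (vs. 1/5/10 for FP16), so the payload-alignment shift would become `8 - EXPONENT` instead of `5 - EXPONENT`. Below is a hypothetical `FPx_BF16_Cast_4Way` sketch written as plain host C++ so the bit manipulation can be sanity-checked off-device; the function name is my invention, and I'm assuming the exponent-bias difference (127 for BF16 vs. 15 for FP16) keeps being absorbed by the scale applied elsewhere in the kernel, as the bias gap is for FP16 today.

```cpp
#include <cstdint>

// Hypothetical BF16 analogue of FPx_FP16_Cast_4Way (a sketch, not the
// actual kernel code). Each 16-bit half of *In holds an FPx value
// left-aligned under the sign bit; we copy the sign and shift the
// EXPONENT+MANTISSA payload down so it lands in the low bits of BF16's
// 8-bit exponent field and the top of its 7-bit mantissa field.
template<int EXPONENT, int MANTISSA>
void FPx_BF16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
    // BF16 has an 8-bit exponent, so align against 8 instead of 5.
    constexpr int RIGHT_SHIFT = 8 - EXPONENT;
    // Mask covering the EXPONENT+MANTISSA payload bits that sit just
    // below the sign bit in each 16-bit half of the input word.
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> (EXPONENT + MANTISSA);
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK  = MASK3 | (MASK3 >> 16);
    // First pair: copy the sign bits, then shift the payload into place.
    *Out1  = *In & 0x80008000;
    *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT;
    // Second pair: bring the next two packed values to the top and repeat.
    *In    = (*In) << 8;
    *Out2  = *In & 0x80008000;
    *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT;
}
```

As a quick check for FP6 E3M2: packing the value `0b0_110_10` (sign 0, exponent 6, mantissa `10`) twice per half gives `In = 0x68686868`, and each output half comes out as `0x0340`, which decodes as a BF16 with exponent field 6 and mantissa `0b10_00000`, i.e. the payload lands where expected.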
ao/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh (lines 117 to 125 in 09b8b3c):

```cuda
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
             "{ %0, %1, %2, %3},"
             "{ %4, %5, %6, %7 },"
             "{ %8, %9 },"
             "{ %10, %11, %12, %13 };"
             : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
             : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
               "r"(b[0]), "r"(b[1]),
               "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
```
I might try to do it myself, but I think it would also make an interesting good first issue. @tobiasvanderwerff would you be interested?