diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index c6d28c0bd1..25967baa25 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -8,29 +8,42 @@ def benchmark(m: int, k: int, n: int): - float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2)) - fp16_weight = fp6_weight.dequantize(torch.half) - - fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") - fp6_output = F.linear(fp16_act, fp6_weight) + float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda") + float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2)) + fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2)) + fp16_weight = fp6_weight_fp16.dequantize(torch.float16) + bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16) + + fp16_act = torch.randn(m, k, dtype=torch.float16, device="cuda") + bf16_act = fp16_act.to(torch.bfloat16) + fp6_output_fp16 = F.linear(fp16_act, fp6_weight_fp16) + fp6_output_bf16 = F.linear(bf16_act, fp6_weight_bf16) fp16_output = F.linear(fp16_act, fp16_weight) + bf16_output = F.linear(bf16_act, bf16_weight) - fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight) fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) + bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight) + fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16) + fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16) # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness - correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2 return { "m": m, "k": k, "n": n, - "fp6_latency (ms)": fp6_time, - "fp16_latency (ms)": fp16_time, - "speedup (d/s)": fp16_time / fp6_time, - "correct": correct, + "fp6-fp16 latency (ms)": fp6_time_fp16, + "fp16 latency (ms)": fp16_time, + "speedup fp16": fp16_time / fp6_time_fp16, + "correct fp16": correct_fp16, + "fp6-bf16 latency (ms)": fp6_time_bf16, + "bf16 latency (ms)": bf16_time, + "speedup bf16": bf16_time / fp6_time_bf16, + "correct bf16": correct_bf16, } diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 93dc7515d9..875a8c8d5e 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -91,16 +91,17 @@ def test_to_copy_device(self, ebits, mbits): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) + @parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.skipif(is_fbcode(), reason="broken in fbcode") - def test_fpx_weight_only(self, ebits, mbits, bias): + def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 device = "cuda" - linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half) + linear = torch.nn.Linear(IC, OC, bias=bias, device=device, 
dtype=dtype) fpx_linear = copy.deepcopy(linear) quantize_(fpx_linear, fpx_weight_only(ebits, mbits)) - x = torch.randn(N, IC, device=device, dtype=torch.half) + x = torch.randn(N, IC, device=device, dtype=dtype) expected = fpx_linear(x) actual = torch.compile(fpx_linear, fullgraph=True)(x) # somehow compile now changes the result a bit diff --git a/test/test_ops.py b/test/test_ops.py index 31000eafc2..7802fdeaeb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -33,22 +33,23 @@ class TestOps(TestCase): - def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): # Randomly initialize each byte nbits = 1 + ebits + mbits floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) - scale = torch.rand(OC).half() + 0.5 - fp16_act = torch.rand(BS, IC).half() + 0.5 + scale = torch.rand(OC).to(dtype) + 0.5 + fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - def test_quant_llm_linear(self, ebits, mbits): + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear(self, ebits, mbits, dtype): BS = 2 OC = 256 IC = 256 splitK = 1 - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) # smoke test torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) @@ -60,19 +61,21 @@ def test_quant_llm_linear(self, ebits, mbits): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half() + fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) results_fp16 = fp16_act @ fp16_weight.T error = (results_floatx - results_fp16).abs().mean() gt = results_fp16.abs().mean() relative_error = error / gt - assert relative_error < 1e-3 + rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + assert relative_error < rtol instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cuda/fp6_llm/README.md b/torchao/csrc/cuda/fp6_llm/README.md index ff764cc27d..8df1fb1416 100644 --- a/torchao/csrc/cuda/fp6_llm/README.md +++ b/torchao/csrc/cuda/fp6_llm/README.md @@ -1,7 +1,7 @@ # FP6-LLM kernel -This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN). 
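(Editor's note on the relaxed BF16 tolerance in `test_quant_llm_linear_correctness` above: BF16 keeps FP32's 8 exponent bits but only 7 mantissa bits, versus 10 for FP16, so its unit roundoff is roughly 8x coarser and a 1e-2 relative-error bound is the natural counterpart of the 1e-3 bound used for FP16. A quick sanity check, not part of the patch:)

```python
import torch

# BF16 trades mantissa bits for exponent range, so its roundoff is ~8x coarser than FP16's.
print(torch.finfo(torch.float16).eps)   # 2**-10 ~ 9.8e-4 -> rtol = 1e-3 is attainable
print(torch.finfo(torch.bfloat16).eps)  # 2**-7  ~ 7.8e-3 -> hence the relaxed rtol = 1e-2
```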
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN). On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. -See https://github.com/pytorch/ao/pull/223 for some benchmark results. +See https://github.com/pytorch/ao/pull/223 and and https://github.com/pytorch/ao/pull/1147 for some benchmark results. diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 978925a3f7..6141dc3d74 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -18,7 +18,6 @@ // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory // -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -26,22 +25,26 @@ #include #include -inline bool isSM75GPU() { - int device; - cudaError_t err = cudaGetDevice(&device); - if (err != cudaSuccess) return false; +#include +#include +#include +#include - int major, minor; - err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); - if (err != cudaSuccess) return false; - err = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); - if (err != cudaSuccess) return false; +// https://github.com/Dao-AILab/flash-attention/blob/478ee666cccbd1b8f63648633003059a8dc6827d/hopper/utils.h#L25 +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) - return (major == 7) && (minor == 5); -} -template +template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, const half *Scales, @@ -59,7 +62,7 @@ static void Kernel_Ex(cudaStream_t stream, printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); #endif static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); @@ -70,22 +73,24 @@ static void Kernel_Ex(cudaStream_t stream, GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); printf("\n"); #endif - QUANT_GEMM_Kernel<<>> + QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + CHECK_CUDA_KERNEL_LAUNCH(); } -template -cudaError_t fpx_linear_kernel(cudaStream_t stream, +template +void fpx_linear_kernel(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, - half *C, + InputDataType *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K) { + 
static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); assert(N_Global>0); @@ -99,40 +104,49 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { + // Check GPU Compute Capability + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && std::is_same::value) + TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75"); + if ((major < 7) || (major == 7 && minor < 5)) + TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); + + if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { - Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } else { - Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); + Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); } } else { if (Split_K == 1) { switch (N_PowerOf2) { - case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; + TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); } - Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, InputDataType, InputDataType, 
EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { switch (N_PowerOf2) { - case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; + TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); } - Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } } } @@ -141,17 +155,30 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); dim3 BlockDim(WARP_SIZE, 1, 1); - SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + CHECK_CUDA_KERNEL_LAUNCH(); } - - return cudaGetLastError(); } -#include -#include -#include -#include +// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using torch_t = at::Half; \ + using nv_t = half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using torch_t = at::BFloat16; \ + using nv_t = __nv_bfloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 @@ -163,12 +190,12 @@ Standard definition of linear layer: Out = In * trans(W), where In, Out, and After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. 
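(Editor's note: from Python, this interface is reached through `torchao.ops.quant_llm_linear`. A minimal sketch using the shapes from `test_ops.py`; the tensors are illustrative random data, not a real checkpoint:)

```python
import torch
import torchao

ebits, mbits = 3, 2                  # FP6 E3M2
BS, OC, IC, splitK = 2, 256, 256, 1
nbits = 1 + ebits + mbits

act = torch.rand(BS, IC, dtype=torch.bfloat16, device="cuda")     # _in_feats: [B, IC], FP16 or BF16
weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")  # packed FPx
scales = torch.rand(OC, dtype=torch.bfloat16, device="cuda")      # _scales: [OC], same dtype as act

out = torchao.ops.quant_llm_linear(ebits, mbits, act, weight, scales, splitK)  # _out_feats: [B, OC]
```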
[Inputs] - _in_feats: tensor of shape [B, IC]; // half + _in_feats: tensor of shape [B, IC]; // half or bf16 _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contains 8 FPx weights. - _scales: tensor of shape [OC]; // half + _scales: tensor of shape [OC]; // half or bf16 splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] - _out_feats: tensor of shape [B, OC]; // half + _out_feats: tensor of shape [B, OC]; // half or bf16 */ torch::Tensor fp_eXmY_linear_forward_cuda( int64_t EXPONENT, @@ -188,14 +215,8 @@ torch::Tensor fp_eXmY_linear_forward_cuda( int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; - // Input Tensors - auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto scales = reinterpret_cast(_scales.data_ptr()); - // Output Tensors auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); @@ -205,26 +226,33 @@ torch::Tensor fp_eXmY_linear_forward_cuda( // this fixes problem with CUDA graphs when used with torch.compile() auto stream = at::cuda::getCurrentCUDAStream(); - // officially supported in Quant-LLM - if (EXPONENT == 3 && MANTISSA == 2) - fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 2) - fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - // experimental - else if (EXPONENT == 2 && MANTISSA == 3) - fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 3 && MANTISSA == 1) - fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 1) - // fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 3 && MANTISSA == 0) - // fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 0) - // fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - else - TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); + DISPATCH_HALF_AND_BF16(_in_feats.scalar_type(), "fpx_linear_kernel", [&] { + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. 
+ auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + // officially supported in Quant-LLM + if (EXPONENT == 3 && MANTISSA == 2) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 2) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + // experimental + else if (EXPONENT == 2 && MANTISSA == 3) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 1) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 1) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 3 && MANTISSA == 0) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 0) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + else + TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); + }); return _out_feats; } @@ -233,6 +261,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } -} // namespace torchao - -#endif +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 600debd4a0..b008971647 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -46,13 +46,17 @@ * B: col major, FP16 * C: col major, FP16 */ - template + template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required."); + // __trap(); // fails at runtime instead of compile time + #endif #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -153,7 +157,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// - initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); + constexpr bool USE_BF16 = std::is_same::value; + initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. 
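(Editor's note on the (EXPONENT, MANTISSA) dispatch above: the same combinations are reached from the high-level API via `fpx_weight_only`, now for either activation dtype. A sketch in the spirit of `test_fpx_weight_only`; module sizes are illustrative but satisfy the kernel's shape constraints, OC divisible by 256 and IC by 64:)

```python
import torch
from torchao.quantization import quantize_, fpx_weight_only

linear = torch.nn.Linear(64, 256, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(linear, fpx_weight_only(3, 2))   # FP6 E3M2; (2, 2) also official, (2, 3) and (3, 1) experimental

x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
y = linear(x)                              # routes to the BF16 path of the quant_llm kernel
```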
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) @@ -184,15 +189,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, #if __CUDA_ARCH__ >= 800 cp_async_group_commit(); #endif - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations #if __CUDA_ARCH__ >= 800 cp_async_wait_group(); #endif __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); + core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 @@ -212,7 +217,14 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, #pragma unroll for(size_t j=threadIdx.x%WARP_SIZE; j::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); - else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + if constexpr (std::is_same::value) { + BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + } else if constexpr (std::is_same::value) { + #if __CUDA_ARCH__ >= 800 + BlockGlobalPTR[j+i*M_Global] = __float2bfloat16_rn(smem_CFrag[i][j]); + #endif + } else { + BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + } } } diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index c0e7c1918a..d09d9b861d 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -36,16 +36,20 @@ #include #include +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +#include +#endif #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 #define HALF_PER_128BIT 8 -__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +template +__global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { - half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + T* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * 
HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; @@ -58,6 +62,13 @@ __global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_G THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); + } else { // __nv_bfloat16> + #if __CUDA_ARCH__ >= 800 + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); + #endif + } } diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index dededcf19d..57dd8cb53f 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -89,6 +89,7 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ // MODIFICATION NOTE: to support MSVC, the function signature is changed from // MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). +template __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) { @@ -114,15 +115,27 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); #else - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5, %6, %7 }," - "{ %8, %9 }," - "{ %10, %11, %12, %13 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + if constexpr (USE_BF16) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } else { // FP16 + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } #endif } diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 7a6cd36a46..4601edb397 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -35,7 +35,7 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
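(Editor's note, ahead of the `utils_parallel_dequant.cuh` hunks further below: the BF16 dequant path reuses the FP16 trick of sliding the FPx bits under the target format's exponent field (`RIGHT_SHIFT = 8 - EXPONENT` for BF16's 8 exponent bits, vs. `5 - EXPONENT` for FP16) and then correcting the exponent bias with a single multiply by `2**BIAS_OFFSET` inside `MultScale`. A pure-Python check of that identity for FP6 E3M2; the bit pattern is an arbitrary example:)

```python
ebits, mbits = 3, 2                        # FP6 E3M2
fp6_bias  = 2 ** (ebits - 1) - 1           # 3
bf16_bias = 127
bias_offset = bf16_bias - fp6_bias         # 124 == (1 << 7) - (1 << (ebits - 1)), as in MultScale

E, M = 0b101, 0b10                         # a normal FP6 value: exponent field 5, mantissa field 2
true_value    = 2.0 ** (E - fp6_bias)  * (1 + M / 2 ** mbits)   # what the FP6 bits encode
reinterpreted = 2.0 ** (E - bf16_bias) * (1 + M / 2 ** mbits)   # same bits read back as BF16
assert reinterpreted * 2.0 ** bias_offset == true_value         # the 2**BIAS_OFFSET fix-up recovers it
```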
-template +template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read, @@ -57,12 +57,12 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t ( if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time + Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -template +template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], @@ -98,13 +98,13 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG #pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); + MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); } else { #pragma unroll for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 } } } @@ -116,7 +116,7 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 4c8c39603e..7fb77f9f8b 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -20,6 +20,9 @@ #include #include +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +#include +#endif #include /* @@ -27,10 +30,10 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. */ -template +template __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) { // - constexpr int RIGHT_SHIFT = 5 - EXPONENT; + constexpr int RIGHT_SHIFT = USE_BF16 ? 
8 - EXPONENT : 5 - EXPONENT; constexpr int MASK1 = 0x80000000; constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; constexpr int MASK3 = MASK2 & 0x7fffffff; @@ -58,10 +61,29 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal return output; } +constexpr float power_of_two(int n) { + return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1); +} + +template +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { +#if __CUDA_ARCH__ >= 800 + constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); + constexpr float BIAS = power_of_two(BIAS_OFFSET); + __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); + __nv_bfloat16* BF16_2 = BF16_1 + 1; + uint32_t output; + __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); + output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(BIAS)), Scale); + output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(BIAS)), Scale); + return output; +#endif +} + // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. // - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. similarly for read_RPTR_2bit and read_RPTR_4bit -template +template __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4], uint32_t * __restrict__ read_RPTR_1bit, uint32_t * __restrict__ read_RPTR_2bit, @@ -77,7 +99,8 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) uint32_t *Frag_PTR_1bit = read_RPTR_1bit; uint32_t *Frag_PTR_2bit = read_RPTR_2bit; uint32_t *Frag_PTR_4bit = read_RPTR_4bit; - half *Scale_RPTR = reinterpret_cast(Scales); + using scalar_t = typename std::conditional::type; + scalar_t *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for(int i=0; i<8; i++) { @@ -104,15 +127,14 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) if(i%2==1) Frag_PTR_4bit++; else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; } - // uint32_t out1, out2; - FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); + FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // - *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16 scales + *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16/BF16 scales OutputRegs += 1; - *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16/BF16 scales OutputRegs += 1; - // Updating offset for FP16 scales for every two iterations + // Updating offset for FP16/BF16 scales for every two iterations if(i%2==1) Scale_RPTR += 2; } diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 75d178fb50..34156697f2 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1612,13 +1612,13 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y -def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import FloatxTensorCoreLayout return ( # input is native float32 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.is_floating_point() and - input_tensor.dtype == torch.float16 and + input_tensor.dtype in (torch.float16, torch.bfloat16) and # weight is floatx Tensor isinstance(weight_tensor, 
AffineQuantizedTensor) and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and @@ -1636,7 +1636,7 @@ def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): ) ) -def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear @@ -1644,7 +1644,7 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): weight = weight_tensor out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim).half() + act_reshaped = act.view(-1, in_dim) # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py bsize = act_reshaped.shape[0] @@ -1804,7 +1804,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl), + (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index af770cf65c..16aec8362b 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -2,6 +2,8 @@ This is a FP16 x Floatx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx and integration with torchao API. +This kernel was originally designed for FP16, but was extended to work for BF16 by @tobiasvanderwerff. + ## Usage ```python @@ -11,7 +13,7 @@ from torchao.quantization import ( ) model = ... -model.half() # not necessary, but recommeneded to maintain accuracy +# model can have dtype float16 or bfloat16 # for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 @@ -40,10 +42,10 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape ``` **NOTE**: -- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. +- The kernel works for both FP16 and BF16 input activations - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1. -- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results. -- FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details). 
+- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some microbenchmark results. +- The kernel is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details). ## End-to-End benchmarks diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index f862106373..a4745e9315 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -128,11 +128,12 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, exp_bias = _ONES_TABLE[ebits - 1] max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + dtype = tensor.dtype tensor = tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) - return tensor_tc_floatx, scale.half() + return tensor_tc_floatx, scale.to(dtype) # inverse of _pack_tc_floatx() diff --git a/torchao/ops.py b/torchao/ops.py index 79c02dfd85..fa8ad7fe89 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -55,11 +55,11 @@ def _( splitK: int = 1, ) -> Tensor: torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") + torch._check(_in_feats.dtype in (torch.float16, torch.bfloat16), lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}") torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") + torch._check(_scales.dtype in (torch.float16, torch.bfloat16), lambda: f"scale must be FP16 or BF16, got {_scales.dtype}") BS, IC = _in_feats.shape OC, _ = _weights.shape diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index dfd3bcaad8..a8ac533740 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1003,9 +1003,10 @@ def choose_qparams_affine_floatx(tensor: torch.Tensor, ebits: int, mbits: int) - exp_bias = _ONES_TABLE[ebits - 1] max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + dtype = tensor.dtype tensor = tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - return scale.half() + return scale.to(dtype) def quantize_affine_floatx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: """Quantizes the float32 high precision floating point tensor to low precision floating point number and