
Commit 7734f79

Add FP16Act-FP6Weight Linear (#223)
* add files from fp6_llm
* try to port weight packing first
* rename
* rename fp6 weight packing
* add fp16act_fp6weight_linear
* fix function def
* delete duplicate file
* move weight quant file
* rename
* add pytorch interface for fp6 weight dequant
* add fake_fp6 to fp6
* move weight_quant to csrc/cuda due to cuda_fp16.h dependency
* add fake_fp6_to_fp6 test
* add test for fp16act_fp6weight_linear
* add test for fp6_weight_dequant
* Fp6WeightOnlyQuantizedLinearWeight (not working yet)
* skip some tests, since the functions are not built w/o CUDA
* add the original test
* implement transpose and clone so that F.linear will work
* remove print
* remove dequantize
* add notes and some rename
* typo
* small cleanup
* improve tensor subclass and add test (which is failing for torch-compile)
* add note
* add note
* add qtorch as dev requirement
* update error message
* add __repr__ and fix transposed issue
* add fp6 perplexity test
* rename variables
* remove subclass
* add correctness test
* remove unwanted changes
* add apache 2.0 notice
* add benchmark script
* add note about FP6 kernel
* relax tolerance

---------

Co-authored-by: Mark Saroufim <[email protected]>
1 parent ad12663 commit 7734f79

17 files changed: +1882 −2 lines changed
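For orientation, here is a minimal end-to-end sketch of how the ops added in this commit fit together. It mirrors the call pattern of the benchmark and tests below (weights stored as packed int32 words, a per-output-channel FP16 scale, and a splitK factor); the shapes are illustrative only and a CUDA build of torchao at this commit is assumed.

import torch
import torchao

OC, IC, BS, splitK = 256, 256, 2, 1

# FP6 weights are kept as raw int32 words: 16 six-bit values occupy 96 bits = 3 x int32,
# hence the (OC, IC // 16 * 3) shape used throughout this commit.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5

packed = torchao.ops.prepack_fp6_weight(fp6_weight)  # rearrange the words into the kernel's layout
out = torchao.ops.fp16act_fp6weight_linear(fp16_act.cuda(), packed.cuda(), fp16_scale.cuda(), splitK)

# reference path: dequantize the FP6 weight back to FP16 and use a plain matmul
ref = fp16_act.cuda() @ torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda().T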

benchmarks/benchmark_fp6.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import torch
import torchao
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm


def benchmark(m, k, n, splitK):
    # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
    fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int)
    fp16_scale = torch.rand(n).half() + 0.5
    fp16_activation = torch.rand(m, k).half() + 0.5

    fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
    act_cuda = fp16_activation.cuda()
    weight_cuda = fp6_weight_packed.cuda()
    scale_cuda = fp16_scale.cuda()

    # need to do this since Timer cannot see torchao
    def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK):
        return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_measurement = Timer(
        stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)",
        globals=locals(),
    ).blocked_autorange()

    fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
    fp16_output = act_cuda @ fp16_weight.T

    fp16_measurement = Timer(
        stmt="act_cuda @ fp16_weight.T",
        globals=locals(),
    ).blocked_autorange()

    # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
    # doesn't seem to be the right way to check for correctness
    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp6_latency (ms)": fp6_measurement.median * 1000,
        "fp16_latency (ms)": fp16_measurement.median * 1000,
        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
        "correct": correct,
    }


if __name__ == "__main__":
    # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (10240, 8192, 57344, 8192)

    results = []

    # splitK can be tuned based on m, k, n
    for m, splitK_vals in tqdm([
        (1, (5, 6, 7, 6)),
        (2, (5, 6, 7, 6)),
        (4, (5, 6, 7, 6)),
        (8, (5, 6, 7, 6)),
        # (16, (5, 6, 7, 6)),
        # (64, (5, 6, 7, 6)),
        # (128, (5, 3, 3, 3)),
        # (256, (4, 3, 2, 3)),
        # (512, (2, 5, 2, 4)),
        (1024, (1, 2, 1, 2)),
        (2048, (1, 1, 1, 1)),
        (4096, (1, 1, 1, 1)),
        # (8192, (1, 1, 1, 1)),
        # (16384, (1, 1, 1, 1)),
    ]):
        for n, k, splitK in zip(n_vals, k_vals, splitK_vals):
            # n is the output dim (OC) and k the reduction dim (IC); keywords keep them aligned with benchmark(m, k, n, splitK)
            results.append(benchmark(m=m, k=k, n=n, splitK=splitK))

    df = pd.DataFrame(results)
    df.to_csv("fp6_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
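As an aside on the (n, k // 16 * 3) weight shape used above (a worked check, not part of the commit): each row holds k FP6 values at 6 bits apiece, and every 16 of them fill exactly three 32-bit words.

# illustrative check of the FP6 packing arithmetic (not part of the commit)
k = 8192
bits_per_row = k * 6                       # 6 bits per FP6 value
int32_words_per_row = bits_per_row // 32   # 32 bits per int word
assert int32_words_per_row == k // 16 * 3  # 16 FP6 values = 96 bits = 3 int32 words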

setup.py

Lines changed: 2 additions & 2 deletions
@@ -63,10 +63,10 @@ def get_extensions():
 
     this_dir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))
 
     if use_cuda:
         sources += cuda_sources
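The switch to a recursive glob is what lets the new nested sources under torchao/csrc/cuda/fp6_llm/ be picked up without listing each file; a small illustration of the pattern (the paths here are just examples):

import glob
import os

# '**' together with recursive=True descends into subdirectories such as csrc/cuda/fp6_llm/
cuda_sources = glob.glob(os.path.join("torchao", "csrc", "cuda", "**/*.cu"), recursive=True)
print(cuda_sources)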

test/test_ops.py

Lines changed: 93 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torchao
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 import unittest
+from parameterized import parameterized
 
 
 # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
@@ -42,6 +43,98 @@ def test_nms(self):
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
         opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils)
 
+    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
+        # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
+        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
+        fp16_scale = torch.rand(OC).half() + 0.5
+        fp16_activation = torch.rand(BS, IC).half() + 0.5
+        return fp6_weight, fp16_scale, fp16_activation
+
+    def test_prepack_fp6_weight(self):
+        OC = 256
+        IC = 256
+        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)
+
+        # smoke test
+        torchao.ops.prepack_fp6_weight(fp6_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp16_to_fp6(self):
+        OC = 256
+        IC = 256
+
+        # in this fp6, we use 3 bits for exponent and 2 bits for mantissa
+        # also, we don't have nan/inf
+        fp6_absmax = 28.0  # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
+        fp6_absmin = 0.0625  # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
+        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
+        fp16_weight.clip_(-fp6_absmax, fp6_absmax)
+        fp16_weight[fp16_weight.abs() < fp6_absmin] = 0
+
+        # smoke test
+        torchao.ops.fp16_to_fp6(fp16_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp16act_fp6weight_linear(self):
+        BS = 2
+        OC = 256
+        IC = 256
+        splitK = 1
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
+
+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
+        # smoke test
+        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp6_weight_dequant(self):
+        OC = 256
+        IC = 256
+        fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)
+
+        # smoke test
+        torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)
+
+    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
+    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
+
+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
+        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+
+        fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
+        results_fp16 = act_cuda @ fp16_weight.T
+
+        error = (results_fp6 - results_fp16).abs()
+        relative_error = error / results_fp16.abs()
+        assert relative_error.mean() < 1e-2
+
 
 if __name__ == "__main__":
     unittest.main()
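The fp6_absmax and fp6_absmin constants in test_fp16_to_fp6 follow from the E3M2 layout described in its comments (3 exponent bits with bias 3, 2 mantissa bits, subnormals, no inf/nan): the largest value is 2**4 * 1.75 = 28 and the smallest positive subnormal is 2**-4 = 0.0625. The snippet below is only a hedged pure-PyTorch illustration of that rounding grid; it is not the commit's fp16_to_fp6 kernel nor its qtorch-based reference.

import torch

def fake_fp6_e3m2(x: torch.Tensor) -> torch.Tensor:
    # Round values to the nearest E3M2-representable number (bias 3, subnormals, no inf/nan).
    x = x.float()
    sign = torch.sign(x)
    mag = x.abs().clamp(max=28.0)                          # saturate at the E3M2 maximum
    exp = torch.floor(torch.log2(mag.clamp(min=2 ** -2)))  # clamp to the minimum normal exponent (-2)
    step = torch.exp2(exp - 2)                             # 2 mantissa bits -> spacing of 2**(exp - 2)
    return sign * torch.round(mag / step) * step           # subnormal spacing is 2**-4 = 0.0625

print(fake_fp6_e3m2(torch.tensor([0.03, 0.07, 1.3, 27.0, 100.0])))
# tensor([ 0.0000,  0.0625,  1.2500, 28.0000, 28.0000])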

torchao/csrc/cuda/fp6_llm/configs.h

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h

#ifndef CONFIGS_H
#define CONFIGS_H

//#define DEBUG_MODE
#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2       // only support 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE                           32
#define REG_BIT_WIDTH                       32
// mma: M=16 K=16 N=8
#define MMA_8                               8
#define MMA_16                              16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128     128 // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF                  16  // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16       8   // 8 for FP32 Accumulation
/********************** Memory Padding Parameters **********************/
// Eliminating bank-conflict
#define PADDING_BYTES_16                    16  // Padding 16 bytes each column
#define PADDING_SHARED_MEM_FOR_B_8          8   // Padding 8 half each column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_C_4          4   // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C
/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS                4
#define WARP_M                              (WARP_ROW_MMA_TENSORS * MMA_16)    // 64
#define WARP_K_MMA_TENSORS                  4
#define WARP_K                              (WARP_K_MMA_TENSORS * MMA_16)      // 64
template<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
    // Depending on "n" dimension of the GEMM
    static constexpr int BLOCK_ROW_WARPS      = BLOCK_ROW_WARPS_;
    static constexpr int BLOCK_COL_WARPS      = BLOCK_COL_WARPS_;
    static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
    /************************* WARP Tiling part-2 *************************/
    static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
    /************************* Thread Block Tiling ************************/
    static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
    static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
    static constexpr int TILE_K = WARP_K;
    /********************** #Thread per Thread Block **********************/
    static constexpr int BLOCK_WARPS   = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
    static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
    /******************************* Others *******************************/
    static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM;  // sizeof(half)=2, doubleBuffer=2
    static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4;                        // sizeof(float)=4
};

/************************ General Config for FP6-LLM **********************/
#define WEIGHT_FRAG1_BIT_WIDTH              2
#define WEIGHT_FRAG2_BIT_WIDTH              4
#define WEIGHT_BIT_WIDTH                    (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH)   // 6
//#define QUANT_GROUP_SIZE_DIVIDED_BY_64    4   // QuantGroupSize: 4*64 = 256
/*************************** 64*64 Weights of A WARP *************************/
#define WEIGHT_PER_UNIT                     (WARP_M * WARP_K)                                   // 64*64
#define SMEM_SIZE_IN_BYTES_PER_WARP_A1      (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 8)      // 1024 Bytes; double buffer not taken into consideration
#define SMEM_SIZE_IN_BYTES_PER_WARP_A2      (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 8)      // 2048 Bytes; double buffer not taken into consideration
#define SMEM_SIZE_A1_TILE                   (SMEM_SIZE_IN_BYTES_PER_WARP_A1 * 4 * PIPELINE_LEVEL_GMEM)  // #WARP=4; triple buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A = 8 KB
#define SMEM_SIZE_A2_TILE                   (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * PIPELINE_LEVEL_GMEM)  // #WARP=4; triple buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A = 16 KB
/******************** Global Memory Layout For QUANTIZED DATA ******************/
#define NUM_INT4_PER_UNIT_2BIT_FRAG         (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128)    // 64
#define NUM_INT4_PER_UNIT_4BIT_FRAG         (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128)    // 128
/******************** Register Allocation For QUANTIZED DATA ******************/
#define WEIGHT_PER_THREAD                   (WEIGHT_PER_UNIT / WARP_SIZE)                       // 128
#define REG_PER_THREAD_2BIT_FRAG            (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 2)             // 8
#define REG_PER_THREAD_4BIT_FRAG            (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 4)             // 16
/******************** Register Allocation For QUANT Scales ******************/
#define WARP_REG_QUANT_SCALE                4   // 8 rows per thread -> 8 FP16 scales -> 4 registers
#define WARP_REG_QUANT_SCALE_DISTRIBUTED    1   // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread


#endif  // CONFIGS_H
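To make the tiling arithmetic above concrete, here is a small Python recomputation of the derived constants for a hypothetical TilingConfig<1, 2, 4> instantiation (the template arguments are illustrative; the kernel sources in this commit pick their own configurations):

# illustrative recomputation of configs.h constants for TilingConfig<1, 2, 4> (not part of the commit)
WARP_SIZE, MMA_8, MMA_16, PADDING_BYTES_16, PIPELINE_LEVEL_GMEM = 32, 8, 16, 16, 2
WARP_M = 4 * MMA_16                                                 # WARP_ROW_MMA_TENSORS * MMA_16 = 64
WARP_K = 4 * MMA_16                                                 # WARP_K_MMA_TENSORS * MMA_16 = 64

BLOCK_ROW_WARPS, BLOCK_COL_WARPS, WARP_COL_MMA_TENSORS = 1, 2, 4    # hypothetical template arguments
TILE_M = WARP_M * BLOCK_ROW_WARPS                                   # 64
TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS             # 64
TILE_K = WARP_K                                                     # 64
BLOCK_THREADS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS * WARP_SIZE       # 64 threads per block

WEIGHT_PER_UNIT = WARP_M * WARP_K                                   # 4096 FP6 weights per warp tile
SMEM_A1 = WEIGHT_PER_UNIT * 2 // 8                                  # 1024 bytes for the 2-bit fragments
SMEM_A2 = WEIGHT_PER_UNIT * 4 // 8                                  # 2048 bytes for the 4-bit fragments
SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM  # 20480 bytes
print(TILE_M, TILE_N, TILE_K, BLOCK_THREADS, SMEM_A1, SMEM_A2, SMEM_SIZE_B_TILE)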
