From 36add71be2ae3c30b003e26a20f486e6de2c7bad Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 16:32:35 +0800 Subject: [PATCH 01/22] add annotation --- .../fp6_llm/weight_prepacking_annotated.cpp | 300 ++++++++++++++++++ 1 file changed, 300 insertions(+) create mode 100644 torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp diff --git a/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp b/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp new file mode 100644 index 0000000000..3081629b39 --- /dev/null +++ b/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp @@ -0,0 +1,300 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h + +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); + Padded_FP6[3] = FP6_Array[2]<<2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); + Padded_FP6[7] = FP6_Array[5]<<2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) +{ + unsigned char out; + out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) +{ + // unpack 8 elemnts in a row of the 8x8 block + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); // row of block1 + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); // row of block2 + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); // row of block3 + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); // row of block4 + + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + + // what is this 4? 
-> 2 elem of each row + for(int t=0; t<4; t++) + { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], // 2 elem of row1 and 2 elem of row2 + Padded_8_FP8[0][1+t*2], + Padded_8_FP8[1][0+t*2], + Padded_8_FP8[1][1+t*2]); + + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], // 2 elem of row3 and 2 elem of row4 + Padded_8_FP8[2][1+t*2], + Padded_8_FP8[3][0+t*2], + Padded_8_FP8[3][1+t*2]); + + Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); // 2 elem of row1 + Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); // 2 elem of row2 + Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); // 2 elem of row3 + Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); // 2 elem of row4 + } + + for(int t=0; t<4; t++) + { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // transpose and swap rows? + // 1, 2, 3, 4, + // 5, 6, 7, 8, + // 9,10,11,12, + //13,14,15,16 + + // 2, 6,10,14, + // 4, 8,12,16 + // 1, 5, 9,13, + // 3, 7,11,15 + + // 4 bytes -> 4x4 values, since each byte has 4 values + + //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM + int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + + // each 2-bit is an FP6 value + for(int i=0; i<16; i++) + Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; + + unsigned int output = 0x00000000; + for(int i=0; i<16; i++) + output |= ( Frags_2bit[i] >> (i*2) ); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + + // transpose and swap rows + // 1, 2, 3, 4 + // 5, 6, 7, 8 + + // 2, 6, + // 4, 8 + // 1, 5, + // 3, 7 + + //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM + int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + + // each 4-bit is an FP6 value + for(int i=0; i<8; i++) + Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<8; i++) + output |= ( Frags_4bit[i] >> (i*4) ); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. 
+ * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = reinterpret_cast(packed_weights); + unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; + // + vector A_Segment_2bit[32]; // 32 vector objects + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K*6/8; + + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + + // (M, N) + // -> (M/64, 64, N/16, 16) - (64x16) blocks - slice + // -> (M/64, N/16, 64, 16) + + // -> (M/64, N/16, n_slice, 16, 16) - (16x16) blocks - n_slice=4 + + // -> (M/64, N/16, n_slice, subblock_row, 8, subblock_col, 8) - (8x8) sub-block + // -> (M/64, N/16, n_slice, subblock_col, subblock_row, 8, 8) - note subblock order is swapped + + // -> (M/64, N/16, n_slice, subblock_col, subblock_row, 32, 2) - 2 elements + // 2-bit and 4-bit weights + + // -> (32, M/64, N/16, n_slice, subblock_col, subblock_row, 2) + + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) + { + for(size_t k=0; k<64/16; k++) + { + // a 16x16 block + size_t row = i*64 + k*16; + size_t col = j*16; + + // sub-divie 16x16 blocks into 4 8x8 blocks + // |--------|--------| + // |ptr1 |ptr3 | + // | | | + // |--------|--------| + // |ptr2 |ptr4 | + // | | | + // |--------|--------| + unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; // start of the block + unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; // 8 rows below + unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; // 8 cols to the right + unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; // 8 cols to the right + + // Dealing with each 16*16 blocks then... + // do 1 row (of 8x8 block) at a time + // in an 8x8 block, split into 2 elems at a time (32 groups) + for(int l=0; l<8; l++) + Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], // 0, 4, 8, 16, ..., 28 + &A_Segment_4bit[l*4], + StartPTR_1+l*BytesPerRow, + StartPTR_2+l*BytesPerRow, + StartPTR_3+l*BytesPerRow, + StartPTR_4+l*BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M*K*2/8/32; + size_t BytesPerThread_4bit = M*K*4/8/32; + for(int i=0; i<32; i++) + { + assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); + assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); + } + + // Pass-2: Optimizing coleasced global memory access + // weight transpose + // (32, I, 4) -> (I, 32, 4) -> reverse last dim + for(size_t i=0; i (I, 32, 4) -> reverse last dim + for(size_t i=0; i +#include + +namespace torchao { + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
+ * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) +{ + size_t OC = fp6_tensor.size(0); + size_t IC = fp6_tensor.size(1); + TORCH_CHECK(IC % 3 == 0, "Expect packed input dim % 3 == 0, but receive ", IC, " instead."); + IC = IC * 16 / 3; + TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0), "Expect output dim % 256 == 0 and input dim % 64 == 0, but receive ", OC, " and ", IC, " instead."); + auto packed_tensor = at::empty_like(fp6_tensor); + auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); + return packed_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::prepack_fp6_weight", &weight_matrix_prepacking_cpu); +} + +} From 58bcf2fb0719a51905f6a53a98167766c1031443 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 20:27:38 +0800 Subject: [PATCH 02/22] add weight splitting logic --- ...pp => weight_prepacking_annotated.cpp.bak} | 27 ++---- torchao/dtypes/tc_float6_e3m2.py | 83 +++++++++++++++++++ 2 files changed, 91 insertions(+), 19 deletions(-) rename torchao/csrc/fp6_llm/{weight_prepacking_annotated.cpp => weight_prepacking_annotated.cpp.bak} (97%) create mode 100644 torchao/dtypes/tc_float6_e3m2.py diff --git a/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp b/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak similarity index 97% rename from torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp rename to torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak index 3081629b39..d73c361b50 100644 --- a/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp +++ b/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak @@ -100,18 +100,13 @@ void BitInterleaving_2bit(unsigned char* PTR_4Bytes) { unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); unsigned int input = *PTR_UINT; - // transpose and swap rows? - // 1, 2, 3, 4, - // 5, 6, 7, 8, - // 9,10,11,12, - //13,14,15,16 - - // 2, 6,10,14, - // 4, 8,12,16 - // 1, 5, 9,13, - // 3, 7,11,15 // 4 bytes -> 4x4 values, since each byte has 4 values + // LAYOUT + // [12,13,14,15] high address + // [ 8, 9,10,11] + // [ 4, 5, 6, 7] + // [ 0, 1, 2, 3] low address //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM @@ -133,15 +128,6 @@ void BitInterleaving_4bit(unsigned char* PTR_4Bytes) unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); unsigned int input = *PTR_UINT; - // transpose and swap rows - // 1, 2, 3, 4 - // 5, 6, 7, 8 - - // 2, 6, - // 4, 8 - // 1, 5, - // 3, 7 - //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. 
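// [Editor's note, not part of the upstream header] Worked example for the interleaving above.
// Numbering the nibbles of the 32-bit input from the most significant (nibble 1, selected by
// order value 1) down to the least significant (nibble 8), the output is assembled by taking
// input nibbles in the order {2,6,4,8,1,5,3,7}:
//     input  = 0x12345678   (nibbles 1,2,3,4,5,6,7,8)
//     output = 0x26481537
// BitInterleaving_2bit() works the same way on sixteen 2-bit fields, using the order
// {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}.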
@@ -199,6 +185,9 @@ void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, si // -> (32, M/64, N/16, n_slice, subblock_col, subblock_row, 2) + // 2-bit weight: 4 values are packed into 1 uint8 + // 4-bit weight: 2 values are packed into 1 uint8 + for (size_t i = 0; i < M / 64; i++) // { for (size_t j = 0; j < K / 16; j++) diff --git a/torchao/dtypes/tc_float6_e3m2.py b/torchao/dtypes/tc_float6_e3m2.py new file mode 100644 index 0000000000..93e8763238 --- /dev/null +++ b/torchao/dtypes/tc_float6_e3m2.py @@ -0,0 +1,83 @@ +# https://arxiv.org/abs/2401.14112 + +import torch +from torch import Tensor + +# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. +# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). +def _to_float6_e3m2_pt(tensor: Tensor, packed: bool = False) -> Tensor: + tensor = tensor.float() + + # correct exponent bias. this also handles subnormal numbers correctly + tensor = tensor * 2.0 ** (-127 + 3) + bits = tensor.view(torch.int32) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + # round to nearest even + remainder = bits & 0x1F_FFFF # truncated mantissa bits + do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) + result = torch.where(do_round_up, result + 1, result) + result = result.to(torch.uint8) + + if not packed: + return result + + # bit packing + val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) + + +def pack_2bit(x: Tensor) -> Tensor: + return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] + + +def unpack_2bit(x: Tensor) -> Tensor: + return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) + + +def pack_4bit(x: Tensor) -> Tensor: + return (x[..., ::2] << 4) | x[..., 1::2] + + +def unpack_4bit(x: Tensor) -> Tensor: + return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) + + +def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor_fp6 = _to_float6_e3m2_pt(tensor) + + tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) + tensor_fp6 = tensor_fp6.permute(1, 0, 2) + tensor_fp6 = tensor_fp6.flatten() + + tensor_2bit = pack_2bit((tensor_fp6 >> 4) & 0b11) + tensor_4bit = pack_4bit(tensor_fp6 & 0b1111) + + tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + + tensor_2bit = unpack_2bit(tensor_2bit).view(-1, 16) + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + tensor_2bit = pack_2bit(tensor_2bit) + + tensor_4bit = unpack_4bit(tensor_4bit).view(-1, 8) + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = pack_4bit(tensor_4bit) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0) 
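[Editor's note] The helpers added in tc_float6_e3m2.py above define the byte-level conventions used for the rest of this series: four 2-bit fields or two 4-bit fields are packed into one uint8 with the first field in the highest bits, and each unpacked FP6 value is split into its top two bits (the 2-bit plane) and bottom four bits (the 4-bit plane). The snippet below is a minimal, self-contained sketch, not part of any patch, that exercises those helpers to make the convention concrete; the later passes in the series only reorder these two planes (weight transpose and bit interleaving).

import torch
from torch import Tensor

def pack_2bit(x: Tensor) -> Tensor:
    # four 2-bit values -> one uint8, first value in the highest bits (same as the patch above)
    return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]

def unpack_2bit(x: Tensor) -> Tensor:
    # exact inverse of pack_2bit for values in [0, 3]
    return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2)

fp6 = torch.randint(64, size=(64,), dtype=torch.uint8)  # 64 unpacked FP6 (e3m2) values, 6 bits each
top2 = (fp6 >> 4) & 0b11    # sign + highest exponent bit, goes into the 2-bit plane
bot4 = fp6 & 0b1111         # remaining exponent + mantissa bits, goes into the 4-bit plane via pack_4bit
packed = pack_2bit(top2)    # 16 bytes for 64 values
assert torch.equal(unpack_2bit(packed), top2)  # packing is lossless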
From 5e44f1ca59e859ee2ac345e255faf60dd4993862 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 07:38:43 +0800 Subject: [PATCH 03/22] update from fp6_quant --- test/dtypes/test_float6_e3m2.py | 21 ++++++++++ torchao/dtypes/tc_float6_e3m2.py | 68 ++++++++++++++------------------ 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index b821504731..d94f2574c1 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -1,4 +1,5 @@ import torch +import torchao from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -6,6 +7,7 @@ run_tests, ) from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 +from torchao.dtypes.tc_float6_e3m2 import to_tc_float6_e3m2 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -120,7 +122,26 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing): torch.testing.assert_close(actual, expected) +class TestWeightPrepacking(TestCase): + @parametrize("device", _DEVICES) + def test_weight_prepacking_correctness(self, device): + x = torch.randn(256, 64, device=device) + + expected = torchao.ops.prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8) + actual = to_tc_float6_e3m2(x) + torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1)) + + @parametrize("device", _DEVICES) + def test_weight_prepacking_compile(self, device): + x = torch.randn(256, 64, device=device) + + expected = to_tc_float6_e3m2(x) + actual = torch.compile(to_tc_float6_e3m2)(x) + torch.testing.assert_close(actual, expected) + + instantiate_parametrized_tests(TestFp6) +instantiate_parametrized_tests(TestWeightPrepacking) if __name__ == "__main__": diff --git a/torchao/dtypes/tc_float6_e3m2.py b/torchao/dtypes/tc_float6_e3m2.py index 93e8763238..7ca969b1e8 100644 --- a/torchao/dtypes/tc_float6_e3m2.py +++ b/torchao/dtypes/tc_float6_e3m2.py @@ -2,35 +2,7 @@ import torch from torch import Tensor - -# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. -# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). -def _to_float6_e3m2_pt(tensor: Tensor, packed: bool = False) -> Tensor: - tensor = tensor.float() - - # correct exponent bias. 
this also handles subnormal numbers correctly - tensor = tensor * 2.0 ** (-127 + 3) - bits = tensor.view(torch.int32) - - sign = ((bits >> 31) & 0x1) << 5 - exp_and_man = (bits >> 21) & 0x1F - result = sign | exp_and_man - - # round to nearest even - remainder = bits & 0x1F_FFFF # truncated mantissa bits - do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) - result = torch.where(do_round_up, result + 1, result) - result = result.to(torch.uint8) - - if not packed: - return result - - # bit packing - val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) +from .float6_e3m2 import to_float6_e3m2 def pack_2bit(x: Tensor) -> Tensor: @@ -49,13 +21,16 @@ def unpack_4bit(x: Tensor) -> Tensor: return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) +# this is a literal adaptation of ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = _to_float6_e3m2_pt(tensor) + tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + # Pass 1 from original code tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) @@ -65,19 +40,34 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_2bit = pack_2bit((tensor_fp6 >> 4) & 0b11) tensor_4bit = pack_4bit(tensor_fp6 & 0b1111) + # Pass 2 from original code tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + # Pass 3 from original code tensor_2bit = unpack_2bit(tensor_2bit).view(-1, 16) - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = pack_2bit(tensor_2bit) - tensor_4bit = unpack_4bit(tensor_4bit).view(-1, 8) - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = pack_4bit(tensor_4bit) + + # permutation like the original code (BitInterleaving_2bit) + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 + # while we still unpack/pack the values from/to uint8 + # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + # tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] + # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + + # merged 3 permutations into 1 + tensor_2bit = tensor_2bit[:, [14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1]] + tensor_2bit = pack_2bit(tensor_2bit).view(-1) + + # permutation like the original code (BitInterleaving_4bit) + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 + # while we still unpack/pack the values from/to uint8 + # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 
2, 3]] + # tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] + # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + + # merged 3 permutations into 1 + tensor_4bit = tensor_4bit[:, [4, 0, 6, 2, 5, 1, 7, 3]] + tensor_4bit = pack_4bit(tensor_4bit).view(-1) return torch.cat([tensor_2bit, tensor_4bit], dim=0) From bccc4f63d93bcdac9a42cc6dd543e05ee1feda86 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 08:21:40 +0800 Subject: [PATCH 04/22] merge to_tc_float6_e3m2 --- torchao/dtypes/__init__.py | 3 +- torchao/dtypes/float6_e3m2.py | 68 +++++++++++++++++++++++++++++ torchao/dtypes/tc_float6_e3m2.py | 73 -------------------------------- 3 files changed, 70 insertions(+), 74 deletions(-) delete mode 100644 torchao/dtypes/tc_float6_e3m2.py diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d12a6da566..8d2bcd8b8c 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,7 +1,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq -from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2 +from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2 __all__ = [ "NF4Tensor", @@ -11,4 +11,5 @@ "to_aq", "to_float6_e3m2", "from_float6_e3m2", + "to_tc_float6_e3m2", ] diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 0c27838d06..d5575c655a 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -176,3 +176,71 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch. val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)).to(dtype) val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F).to(dtype) return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) + + +def _pack_2bit(x: Tensor) -> Tensor: + return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] + + +def _unpack_2bit(x: Tensor) -> Tensor: + return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) + + +def _pack_4bit(x: Tensor) -> Tensor: + return (x[..., ::2] << 4) | x[..., 1::2] + + +def _unpack_4bit(x: Tensor) -> Tensor: + return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h +def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + + # Pass 1 from original code + tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) + tensor_fp6 = tensor_fp6.permute(1, 0, 2) + tensor_fp6 = tensor_fp6.flatten() + + tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11) + tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111) + + # Pass 2 from original code + tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + + # Pass 3 from original code + tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) + tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) + + # permutation like the original code (BitInterleaving_2bit) + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to 
uint32 + # while we still unpack/pack the values from/to uint8 + # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + # tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] + # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + + # merged 3 permutations into 1 + tensor_2bit = tensor_2bit[:, [14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1]] + tensor_2bit = _pack_2bit(tensor_2bit).view(-1) + + # permutation like the original code (BitInterleaving_4bit) + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 + # while we still unpack/pack the values from/to uint8 + # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + # tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] + # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + + # merged 3 permutations into 1 + tensor_4bit = tensor_4bit[:, [4, 0, 6, 2, 5, 1, 7, 3]] + tensor_4bit = _pack_4bit(tensor_4bit).view(-1) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0) diff --git a/torchao/dtypes/tc_float6_e3m2.py b/torchao/dtypes/tc_float6_e3m2.py deleted file mode 100644 index 7ca969b1e8..0000000000 --- a/torchao/dtypes/tc_float6_e3m2.py +++ /dev/null @@ -1,73 +0,0 @@ -# https://arxiv.org/abs/2401.14112 - -import torch -from torch import Tensor -from .float6_e3m2 import to_float6_e3m2 - - -def pack_2bit(x: Tensor) -> Tensor: - return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] - - -def unpack_2bit(x: Tensor) -> Tensor: - return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) - - -def pack_4bit(x: Tensor) -> Tensor: - return (x[..., ::2] << 4) | x[..., 1::2] - - -def unpack_4bit(x: Tensor) -> Tensor: - return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) - - -# this is a literal adaptation of ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) - - # Pass 1 from original code - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) - tensor_fp6 = tensor_fp6.permute(1, 0, 2) - tensor_fp6 = tensor_fp6.flatten() - - tensor_2bit = pack_2bit((tensor_fp6 >> 4) & 0b11) - tensor_4bit = pack_4bit(tensor_fp6 & 0b1111) - - # Pass 2 from original code - tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - - # Pass 3 from original code - tensor_2bit = unpack_2bit(tensor_2bit).view(-1, 16) - tensor_4bit = unpack_4bit(tensor_4bit).view(-1, 8) - - # permutation like the original code (BitInterleaving_2bit) - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - # tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - - # merged 3 permutations into 1 - tensor_2bit = tensor_2bit[:, [14, 10, 6, 2, 
12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1]] - tensor_2bit = pack_2bit(tensor_2bit).view(-1) - - # permutation like the original code (BitInterleaving_4bit) - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - # tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - - # merged 3 permutations into 1 - tensor_4bit = tensor_4bit[:, [4, 0, 6, 2, 5, 1, 7, 3]] - tensor_4bit = pack_4bit(tensor_4bit).view(-1) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0) From cfa304c2a1f3d6bf0d4cec930b862d300b5a9bdf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 08:46:53 +0800 Subject: [PATCH 05/22] add more optimized version --- test/dtypes/test_float6_e3m2.py | 3 +- torchao/dtypes/float6_e3m2.py | 57 +++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index d94f2574c1..d07abe5636 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -6,8 +6,7 @@ parametrize, run_tests, ) -from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 -from torchao.dtypes.tc_float6_e3m2 import to_tc_float6_e3m2 +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index d5575c655a..3b157d6f21 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -196,7 +196,7 @@ def _unpack_4bit(x: Tensor) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: +def to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) @@ -218,29 +218,52 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 3 from original code - tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) - tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) - - # permutation like the original code (BitInterleaving_2bit) + # BitInterleaving_2bit # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 # while we still unpack/pack the values from/to uint8 - # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - # tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - # tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - - # merged 3 permutations into 1 - tensor_2bit = tensor_2bit[:, [14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1]] + tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - # permutation like the original code (BitInterleaving_4bit) + # BitInterleaving_4bit 
# the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 # while we still unpack/pack the values from/to uint8 - # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - # tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - # tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = _pack_4bit(tensor_4bit).view(-1) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0) + + +# more optimized version of to_tc_float6_e3m2_original() by merging ops +def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + + tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) + + tensor_2bit = (tensor_fp6 >> 4) & 0b11 + tensor_4bit = tensor_fp6 & 0b1111 + + tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 1, 3) + tensor_2bit = tensor_2bit.reshape(-1, 16) + tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] + tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - # merged 3 permutations into 1 - tensor_4bit = tensor_4bit[:, [4, 0, 6, 2, 5, 1, 7, 3]] + tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 1, 3) + tensor_4bit = tensor_4bit.reshape(-1, 8) + tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] tensor_4bit = _pack_4bit(tensor_4bit).view(-1) return torch.cat([tensor_2bit, tensor_4bit], dim=0) From ed9cb1de44d8154d638e322dee8d718746da9f50 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 15:49:08 +0800 Subject: [PATCH 06/22] add some notes --- torchao/dtypes/float6_e3m2.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 3b157d6f21..2e5ce38ee5 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -247,6 +247,8 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + # Section 5.2, Figure 5. + # 64x64 tile, divided into 64x16 slices, and further divided into 8x8 chunks (for FP16 tensor cores) tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) @@ -254,16 +256,27 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_2bit = (tensor_fp6 >> 4) & 0b11 tensor_4bit = tensor_fp6 & 0b1111 - tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) + tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-block + + # v1 tensor_2bit = tensor_2bit.permute(0, 2, 1, 3) - tensor_2bit = tensor_2bit.reshape(-1, 16) + tensor_2bit = tensor_2bit.reshape(-1, 16) # 16 x 2-bit = 32-bit tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] - tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) + # v2. 
this is slower + # tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) + + tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-block + + # v1 tensor_4bit = tensor_4bit.permute(0, 2, 1, 3) - tensor_4bit = tensor_4bit.reshape(-1, 8) + tensor_4bit = tensor_4bit.reshape(-1, 8) # 8 x 4-bit = 32-bit tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] - tensor_4bit = _pack_4bit(tensor_4bit).view(-1) + + # v2. this is slower + # tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) + + tensor_2bit = _pack_2bit(tensor_2bit).flatten() + tensor_4bit = _pack_4bit(tensor_4bit).flatten() return torch.cat([tensor_2bit, tensor_4bit], dim=0) From f609b6f82e6e7a1a79a2d8b3d4bf3d76acc1222a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 16:14:29 +0800 Subject: [PATCH 07/22] add from_tc_float6_e3m2 --- test/dtypes/test_float6_e3m2.py | 31 ++++++++++++++++++++++++------- torchao/dtypes/__init__.py | 3 ++- torchao/dtypes/float6_e3m2.py | 33 +++++++++++++++++++++++++++++---- 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index d07abe5636..5c1005d2e0 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -6,14 +6,14 @@ parametrize, run_tests, ) -from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2 +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2, from_tc_float6_e3m2 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) -class TestFp6(TestCase): +class TestFloat6E3M2(TestCase): @parametrize("device", _DEVICES) @parametrize("dtype", _DTYPES) @@ -121,9 +121,9 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing): torch.testing.assert_close(actual, expected) -class TestWeightPrepacking(TestCase): +class TestTcFloat6E3M2(TestCase): @parametrize("device", _DEVICES) - def test_weight_prepacking_correctness(self, device): + def test_to_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) expected = torchao.ops.prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8) @@ -131,16 +131,33 @@ def test_weight_prepacking_correctness(self, device): torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1)) @parametrize("device", _DEVICES) - def test_weight_prepacking_compile(self, device): + def test_to_tc_float6_e3m2_compile(self, device): x = torch.randn(256, 64, device=device) expected = to_tc_float6_e3m2(x) actual = torch.compile(to_tc_float6_e3m2)(x) torch.testing.assert_close(actual, expected) + @parametrize("device", _DEVICES) + def test_from_tc_float6_e3m2_correctness(self, device): + x = torch.randn(256, 64, device=device) + x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6 + + actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape) + torch.testing.assert_close(actual, x) + + @parametrize("device", _DEVICES) + def test_from_tc_float6_e3m2_compile(self, device): + M, N = 256, 64 + x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device) + + expected = from_tc_float6_e3m2(x, M, N) + actual = torch.compile(from_tc_float6_e3m2)(x, M, N) + torch.testing.assert_close(actual, expected) + -instantiate_parametrized_tests(TestFp6) -instantiate_parametrized_tests(TestWeightPrepacking) 
+instantiate_parametrized_tests(TestFloat6E3M2) +instantiate_parametrized_tests(TestTcFloat6E3M2) if __name__ == "__main__": diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8d2bcd8b8c..eff1bcac3b 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,7 +1,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq -from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2 +from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2, from_tc_float6_e3m2 __all__ = [ "NF4Tensor", @@ -12,4 +12,5 @@ "to_float6_e3m2", "from_float6_e3m2", "to_tc_float6_e3m2", + "from_tc_float6_e3m2", ] diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 2e5ce38ee5..6da0dc7fc4 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -259,8 +259,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-block # v1 - tensor_2bit = tensor_2bit.permute(0, 2, 1, 3) - tensor_2bit = tensor_2bit.reshape(-1, 16) # 16 x 2-bit = 32-bit + tensor_2bit = tensor_2bit.permute(0, 2, 1, 3).reshape(-1, 16) # 16 x 2-bit = 32-bit tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] # v2. this is slower @@ -269,8 +268,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-block # v1 - tensor_4bit = tensor_4bit.permute(0, 2, 1, 3) - tensor_4bit = tensor_4bit.reshape(-1, 8) # 8 x 4-bit = 32-bit + tensor_4bit = tensor_4bit.permute(0, 2, 1, 3).reshape(-1, 8) # 8 x 4-bit = 32-bit tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] # v2. 
this is slower @@ -280,3 +278,30 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_4bit = _pack_4bit(tensor_4bit).flatten() return torch.cat([tensor_2bit, tensor_4bit], dim=0) + + +def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: + assert tensor.ndim == 1 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + tensor_2bit = _unpack_2bit(tensor_2bit) + tensor_4bit = _unpack_4bit(tensor_4bit) + + tensor_2bit = tensor_2bit.view(-1, 16) + tensor_2bit = tensor_2bit[:, [4, 12, 0, 8, 5, 13, 1, 9, 6, 14, 2, 10, 7, 15, 3, 11]] + tensor_2bit = tensor_2bit.view(-1, 32, 8, 2).permute(0, 2, 1, 3).flatten() + + tensor_4bit = tensor_4bit.view(-1, 8) + tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] + tensor_4bit = tensor_4bit.view(-1, 32, 4, 2).permute(0, 2, 1, 3).flatten() + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) + tensor_fp6 = tensor_fp6.permute(0, 2, 4, 5, 1, 3, 6) + tensor_fp6 = tensor_fp6.reshape(M, N) + + return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype) From dec40af7627902f54b372bc0bfcd511bbe9e69cc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 16:41:33 +0800 Subject: [PATCH 08/22] add some docs --- .../weight_prepacking_annotated.cpp.bak | 289 ------------------ torchao/dtypes/float6_e3m2.py | 32 +- 2 files changed, 30 insertions(+), 291 deletions(-) delete mode 100644 torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak diff --git a/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak b/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak deleted file mode 100644 index d73c361b50..0000000000 --- a/torchao/csrc/fp6_llm/weight_prepacking_annotated.cpp.bak +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h - -#include -#include -#include - -using namespace std; - -void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location -{ - Padded_FP6[0] = FP6_Array[0] & 0xfc; - Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); - Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); - Padded_FP6[3] = FP6_Array[2]<<2; - Padded_FP6[4] = FP6_Array[3] & 0xfc; - Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); - Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); - Padded_FP6[7] = FP6_Array[5]<<2; -} - -unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) -{ - unsigned char out; - out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); - return out; -} - -unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); -{ - unsigned char out; - out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); - return out; -} - -// dealing with 4 1*8 blocks of FP6 -void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) -{ - // unpack 8 elemnts in a row of the 8x8 block - unsigned char Padded_8_FP8[4][8]; - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); // row of block1 - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); // row of block2 - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); // row of block3 - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); // row of block4 - - // - unsigned char Seg1_Byte1_T[4]; - unsigned char Seg1_Byte2_T[4]; - unsigned char Seg2_Byte1_T[4]; - unsigned char Seg2_Byte2_T[4]; - unsigned char Seg2_Byte3_T[4]; - unsigned char Seg2_Byte4_T[4]; - - // what is this 4? 
-> 2 elem of each row - for(int t=0; t<4; t++) - { - Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], // 2 elem of row1 and 2 elem of row2 - Padded_8_FP8[0][1+t*2], - Padded_8_FP8[1][0+t*2], - Padded_8_FP8[1][1+t*2]); - - Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], // 2 elem of row3 and 2 elem of row4 - Padded_8_FP8[2][1+t*2], - Padded_8_FP8[3][0+t*2], - Padded_8_FP8[3][1+t*2]); - - Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); // 2 elem of row1 - Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); // 2 elem of row2 - Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); // 2 elem of row3 - Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); // 2 elem of row4 - } - - for(int t=0; t<4; t++) - { - Seg_2bit[t].push_back(Seg1_Byte1_T[t]); - Seg_2bit[t].push_back(Seg1_Byte2_T[t]); - - Seg_4bit[t].push_back(Seg2_Byte1_T[t]); - Seg_4bit[t].push_back(Seg2_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte3_T[t]); - Seg_4bit[t].push_back(Seg2_Byte4_T[t]); - } - return; -} - -void BitInterleaving_2bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - - // 4 bytes -> 4x4 values, since each byte has 4 values - // LAYOUT - // [12,13,14,15] high address - // [ 8, 9,10,11] - // [ 4, 5, 6, 7] - // [ 0, 1, 2, 3] low address - - //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM - int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. - - // each 2-bit is an FP6 value - for(int i=0; i<16; i++) - Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; - - unsigned int output = 0x00000000; - for(int i=0; i<16; i++) - output |= ( Frags_2bit[i] >> (i*2) ); - // - *PTR_UINT = output; -} - -void BitInterleaving_4bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - - //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM - int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. - - // each 4-bit is an FP6 value - for(int i=0; i<8; i++) - Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; - // - unsigned int output = 0x00000000; - for(int i=0; i<8; i++) - output |= ( Frags_4bit[i] >> (i*4) ); - // - *PTR_UINT = output; -} - -/* - * Inputs: - * (1) unsigned char Weight_6bit [M*K*6/8] - * Outputs: - * (1) unsigned char Weight_2bit [M*K*2/8] - * (2) unsigned char Weight_4bit [M*K*4/8] - * - * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. 
- * 8 FP6 = 6 Bytes - * 8 FP4 = 4 Bytes - * 8 FP2 = 2 Bytes - */ -void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) -{ - assert(M % 64 == 0); - assert(K % 64 == 0); - // - unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); - unsigned char* Weight_2bit = reinterpret_cast(packed_weights); - unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; - // - vector A_Segment_2bit[32]; // 32 vector objects - vector A_Segment_4bit[32]; - // - size_t BytesPerRow = K*6/8; - - // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. - - // (M, N) - // -> (M/64, 64, N/16, 16) - (64x16) blocks - slice - // -> (M/64, N/16, 64, 16) - - // -> (M/64, N/16, n_slice, 16, 16) - (16x16) blocks - n_slice=4 - - // -> (M/64, N/16, n_slice, subblock_row, 8, subblock_col, 8) - (8x8) sub-block - // -> (M/64, N/16, n_slice, subblock_col, subblock_row, 8, 8) - note subblock order is swapped - - // -> (M/64, N/16, n_slice, subblock_col, subblock_row, 32, 2) - 2 elements - // 2-bit and 4-bit weights - - // -> (32, M/64, N/16, n_slice, subblock_col, subblock_row, 2) - - // 2-bit weight: 4 values are packed into 1 uint8 - // 4-bit weight: 2 values are packed into 1 uint8 - - for (size_t i = 0; i < M / 64; i++) // - { - for (size_t j = 0; j < K / 16; j++) - { - for(size_t k=0; k<64/16; k++) - { - // a 16x16 block - size_t row = i*64 + k*16; - size_t col = j*16; - - // sub-divie 16x16 blocks into 4 8x8 blocks - // |--------|--------| - // |ptr1 |ptr3 | - // | | | - // |--------|--------| - // |ptr2 |ptr4 | - // | | | - // |--------|--------| - unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; // start of the block - unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; // 8 rows below - unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; // 8 cols to the right - unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; // 8 cols to the right - - // Dealing with each 16*16 blocks then... - // do 1 row (of 8x8 block) at a time - // in an 8x8 block, split into 2 elems at a time (32 groups) - for(int l=0; l<8; l++) - Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], // 0, 4, 8, 16, ..., 28 - &A_Segment_4bit[l*4], - StartPTR_1+l*BytesPerRow, - StartPTR_2+l*BytesPerRow, - StartPTR_3+l*BytesPerRow, - StartPTR_4+l*BytesPerRow); - } - } - } - // Verifying the length of 2_bit segments and 4_bit segments - size_t BytesPerThread_2bit = M*K*2/8/32; - size_t BytesPerThread_4bit = M*K*4/8/32; - for(int i=0; i<32; i++) - { - assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); - assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); - } - - // Pass-2: Optimizing coleasced global memory access - // weight transpose - // (32, I, 4) -> (I, 32, 4) -> reverse last dim - for(size_t i=0; i (I, 32, 4) -> reverse last dim - for(size_t i=0; i -#include - -namespace torchao { - -/* - * Weight prepacking (Pytorch interface). - * [Input & Output] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
- * [Output] - * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; - */ -at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) -{ - size_t OC = fp6_tensor.size(0); - size_t IC = fp6_tensor.size(1); - TORCH_CHECK(IC % 3 == 0, "Expect packed input dim % 3 == 0, but receive ", IC, " instead."); - IC = IC * 16 / 3; - TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0), "Expect output dim % 256 == 0 and input dim % 64 == 0, but receive ", OC, " and ", IC, " instead."); - auto packed_tensor = at::empty_like(fp6_tensor); - auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); - return packed_tensor; -} - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::prepack_fp6_weight", &weight_matrix_prepacking_cpu); -} - -} diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 6da0dc7fc4..9ed906e5aa 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -196,7 +196,7 @@ def _unpack_4bit(x: Tensor) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: +def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) @@ -239,8 +239,25 @@ def to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: return torch.cat([tensor_2bit, tensor_4bit], dim=0) -# more optimized version of to_tc_float6_e3m2_original() by merging ops +# more optimized version of _to_tc_float6_e3m2_original() by merging ops +# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: + """Convert input tensor to TC-FP6 for use with FP6-LLM. This format has 3 exponent bits and 2 mantissa bits. + See https://arxiv.org/abs/2401.14112 for more information. + + Args: + tensor: Input tensor with shape (M, N), where M and N are multiples of 64. + + Returns: + :class:`torch.Tensor`: TC-FP6 tensor, stored as uint8 data with shape (M * N * 3 / 4,). + + Note: + This TC-FP6 format does not represent +/-inf and NaN. Thus, make sure that input tensor does + not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. + All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). + + See also :func:`from_tc_float6_e3m2` + """ assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) @@ -281,6 +298,17 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: + """Convert a TC-FP6 tensor (created by :func:`to_tc_float6_e3m2`) to FP32. + + Args: + tensor: TC-FP6 tensor, stored as uint8 data. + M: first dimension of the weight. + N: second dimension of the weight. + dtype: returned dtype. + + Returns: + :class:`torch.Tensor`: FP32 tensor. 
+ """ assert tensor.ndim == 1 assert (M % 64 == 0) and (N % 64 == 0) size_2bit = M * N // 4 From 5e5dfdcd1cb7ff371f48c7d442700cc415439b5b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 17:37:38 +0800 Subject: [PATCH 09/22] make fp6_llm.py --- test/dtypes/test_float6_e3m2.py | 39 +------ test/quantization/test_fp6_llm.py | 55 ++++++++++ torchao/dtypes/__init__.py | 4 +- torchao/dtypes/float6_e3m2.py | 157 ---------------------------- torchao/quantization/fp6_llm.py | 167 ++++++++++++++++++++++++++++++ 5 files changed, 224 insertions(+), 198 deletions(-) create mode 100644 test/quantization/test_fp6_llm.py create mode 100644 torchao/quantization/fp6_llm.py diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index 5c1005d2e0..c3365cffeb 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -1,12 +1,11 @@ import torch -import torchao from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, parametrize, run_tests, ) -from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2, from_tc_float6_e3m2 +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -121,43 +120,7 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing): torch.testing.assert_close(actual, expected) -class TestTcFloat6E3M2(TestCase): - @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) - - expected = torchao.ops.prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8) - actual = to_tc_float6_e3m2(x) - torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1)) - - @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_compile(self, device): - x = torch.randn(256, 64, device=device) - - expected = to_tc_float6_e3m2(x) - actual = torch.compile(to_tc_float6_e3m2)(x) - torch.testing.assert_close(actual, expected) - - @parametrize("device", _DEVICES) - def test_from_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) - x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6 - - actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape) - torch.testing.assert_close(actual, x) - - @parametrize("device", _DEVICES) - def test_from_tc_float6_e3m2_compile(self, device): - M, N = 256, 64 - x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device) - - expected = from_tc_float6_e3m2(x, M, N) - actual = torch.compile(from_tc_float6_e3m2)(x, M, N) - torch.testing.assert_close(actual, expected) - - instantiate_parametrized_tests(TestFloat6E3M2) -instantiate_parametrized_tests(TestTcFloat6E3M2) if __name__ == "__main__": diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py new file mode 100644 index 0000000000..5a97ead824 --- /dev/null +++ b/test/quantization/test_fp6_llm.py @@ -0,0 +1,55 @@ +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 +from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2 +from torchao.ops import prepack_fp6_weight + + +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestTcFloat6E3M2(TestCase): + 
@parametrize("device", _DEVICES) + def test_to_tc_float6_e3m2_correctness(self, device): + x = torch.randn(256, 64, device=device) + + expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8) + actual = to_tc_float6_e3m2(x) + torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1)) + + @parametrize("device", _DEVICES) + def test_to_tc_float6_e3m2_compile(self, device): + x = torch.randn(256, 64, device=device) + + expected = to_tc_float6_e3m2(x) + actual = torch.compile(to_tc_float6_e3m2)(x) + torch.testing.assert_close(actual, expected) + + @parametrize("device", _DEVICES) + def test_from_tc_float6_e3m2_correctness(self, device): + x = torch.randn(256, 64, device=device) + x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6 + + actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape) + torch.testing.assert_close(actual, x) + + @parametrize("device", _DEVICES) + def test_from_tc_float6_e3m2_compile(self, device): + M, N = 256, 64 + x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device) + + expected = from_tc_float6_e3m2(x, M, N) + actual = torch.compile(from_tc_float6_e3m2)(x, M, N) + torch.testing.assert_close(actual, expected) + + +instantiate_parametrized_tests(TestTcFloat6E3M2) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index eff1bcac3b..d12a6da566 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,7 +1,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq -from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2, to_tc_float6_e3m2, from_tc_float6_e3m2 +from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2 __all__ = [ "NF4Tensor", @@ -11,6 +11,4 @@ "to_aq", "to_float6_e3m2", "from_float6_e3m2", - "to_tc_float6_e3m2", - "from_tc_float6_e3m2", ] diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 9ed906e5aa..0c27838d06 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -176,160 +176,3 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch. 
val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)).to(dtype) val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F).to(dtype) return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) - - -def _pack_2bit(x: Tensor) -> Tensor: - return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] - - -def _unpack_2bit(x: Tensor) -> Tensor: - return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) - - -def _pack_4bit(x: Tensor) -> Tensor: - return (x[..., ::2] << 4) | x[..., 1::2] - - -def _unpack_4bit(x: Tensor) -> Tensor: - return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) - - -# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) - - # Pass 1 from original code - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) - tensor_fp6 = tensor_fp6.permute(1, 0, 2) - tensor_fp6 = tensor_fp6.flatten() - - tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11) - tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111) - - # Pass 2 from original code - tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - - # Pass 3 from original code - # BitInterleaving_2bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - - # BitInterleaving_4bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = _pack_4bit(tensor_4bit).view(-1) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0) - - -# more optimized version of _to_tc_float6_e3m2_original() by merging ops -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: - """Convert input tensor to TC-FP6 for use with FP6-LLM. This format has 3 exponent bits and 2 mantissa bits. - See https://arxiv.org/abs/2401.14112 for more information. - - Args: - tensor: Input tensor with shape (M, N), where M and N are multiples of 64. - - Returns: - :class:`torch.Tensor`: TC-FP6 tensor, stored as uint8 data with shape (M * N * 3 / 4,). - - Note: - This TC-FP6 format does not represent +/-inf and NaN. 
Thus, make sure that input tensor does - not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. - All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - - See also :func:`from_tc_float6_e3m2` - """ - assert tensor.ndim == 2 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) - - # Section 5.2, Figure 5. - # 64x64 tile, divided into 64x16 slices, and further divided into 8x8 chunks (for FP16 tensor cores) - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) - - tensor_2bit = (tensor_fp6 >> 4) & 0b11 - tensor_4bit = tensor_fp6 & 0b1111 - - tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-block - - # v1 - tensor_2bit = tensor_2bit.permute(0, 2, 1, 3).reshape(-1, 16) # 16 x 2-bit = 32-bit - tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] - - # v2. this is slower - # tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) - - tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-block - - # v1 - tensor_4bit = tensor_4bit.permute(0, 2, 1, 3).reshape(-1, 8) # 8 x 4-bit = 32-bit - tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] - - # v2. this is slower - # tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) - - tensor_2bit = _pack_2bit(tensor_2bit).flatten() - tensor_4bit = _pack_4bit(tensor_4bit).flatten() - - return torch.cat([tensor_2bit, tensor_4bit], dim=0) - - -def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: - """Convert a TC-FP6 tensor (created by :func:`to_tc_float6_e3m2`) to FP32. - - Args: - tensor: TC-FP6 tensor, stored as uint8 data. - M: first dimension of the weight. - N: second dimension of the weight. - dtype: returned dtype. - - Returns: - :class:`torch.Tensor`: FP32 tensor. 
- """ - assert tensor.ndim == 1 - assert (M % 64 == 0) and (N % 64 == 0) - size_2bit = M * N // 4 - size_4bit = M * N // 2 - assert tensor.numel() == size_2bit + size_4bit - - tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - tensor_2bit = _unpack_2bit(tensor_2bit) - tensor_4bit = _unpack_4bit(tensor_4bit) - - tensor_2bit = tensor_2bit.view(-1, 16) - tensor_2bit = tensor_2bit[:, [4, 12, 0, 8, 5, 13, 1, 9, 6, 14, 2, 10, 7, 15, 3, 11]] - tensor_2bit = tensor_2bit.view(-1, 32, 8, 2).permute(0, 2, 1, 3).flatten() - - tensor_4bit = tensor_4bit.view(-1, 8) - tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] - tensor_4bit = tensor_4bit.view(-1, 32, 4, 2).permute(0, 2, 1, 3).flatten() - - tensor_fp6 = (tensor_2bit << 4) | tensor_4bit - tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) - tensor_fp6 = tensor_fp6.permute(0, 2, 4, 5, 1, 3, 6) - tensor_fp6 = tensor_fp6.reshape(M, N) - - return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py new file mode 100644 index 0000000000..f6de8f6a87 --- /dev/null +++ b/torchao/quantization/fp6_llm.py @@ -0,0 +1,167 @@ +import torch +from torch import nn, Tensor +from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2 +from torchao.ops import fp16act_fp6weight_linear + + +def _pack_2bit(x: Tensor) -> Tensor: + return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] + + +def _unpack_2bit(x: Tensor) -> Tensor: + return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) + + +def _pack_4bit(x: Tensor) -> Tensor: + return (x[..., ::2] << 4) | x[..., 1::2] + + +def _unpack_4bit(x: Tensor) -> Tensor: + return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h +def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + + # Pass 1 from original code + tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) + tensor_fp6 = tensor_fp6.permute(1, 0, 2) + tensor_fp6 = tensor_fp6.flatten() + + tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11) + tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111) + + # Pass 2 from original code + tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) + + # Pass 3 from original code + # BitInterleaving_2bit + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 + # while we still unpack/pack the values from/to uint8 + tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] + tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] + tensor_2bit = _pack_2bit(tensor_2bit).view(-1) + + # BitInterleaving_4bit + # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 + # while we still 
unpack/pack the values from/to uint8 + tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] + tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] + tensor_4bit = _pack_4bit(tensor_4bit).view(-1) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0) + + +# more optimized version of _to_tc_float6_e3m2_original() by merging ops +# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h +def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + + # Section 5.2, Figure 5. + # 64x64 tile, divided into 64x16 slices, and further divided into 8x8 chunks (for FP16 tensor cores) + tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) + + tensor_2bit = (tensor_fp6 >> 4) & 0b11 + tensor_4bit = tensor_fp6 & 0b1111 + + tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-block + + # v1 + tensor_2bit = tensor_2bit.permute(0, 2, 1, 3).reshape(-1, 16) # 16 x 2-bit = 32-bit + tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] + + # v2. this is slower + # tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) + + tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-block + + # v1 + tensor_4bit = tensor_4bit.permute(0, 2, 1, 3).reshape(-1, 8) # 8 x 4-bit = 32-bit + tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] + + # v2. 
this is slower + # tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) + + tensor_2bit = _pack_2bit(tensor_2bit).flatten() + tensor_4bit = _pack_4bit(tensor_4bit).flatten() + + return torch.cat([tensor_2bit, tensor_4bit], dim=0) + + +def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: + assert tensor.ndim == 1 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + tensor_2bit = _unpack_2bit(tensor_2bit) + tensor_4bit = _unpack_4bit(tensor_4bit) + + tensor_2bit = tensor_2bit.view(-1, 16) + tensor_2bit = tensor_2bit[:, [4, 12, 0, 8, 5, 13, 1, 9, 6, 14, 2, 10, 7, 15, 3, 11]] + tensor_2bit = tensor_2bit.view(-1, 32, 8, 2).permute(0, 2, 1, 3).flatten() + + tensor_4bit = tensor_4bit.view(-1, 8) + tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] + tensor_4bit = tensor_4bit.view(-1, 32, 4, 2).permute(0, 2, 1, 3).flatten() + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) + tensor_fp6 = tensor_fp6.permute(0, 2, 4, 5, 1, 3, 6) + tensor_fp6 = tensor_fp6.reshape(M, N) + + return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype) + + +class Fp6LlmLinear(nn.Module): + def __init__(self, weight: Tensor, scales: Tensor, bias: Tensor | None = None): + super().__init__() + self.register_buffer("weight", weight) + self.register_buffer("scales", scales) + self.register_buffer("bias", bias) + + def forward(self, x: Tensor): + out = fp16act_fp6weight_linear(x, self.weight, self.scales, splitK=1) + if self.bias is not None: + out = out + self.bias.view(-1, 1) + return out + + @classmethod + def from_float(cls, linear: nn.Linear): + fp32_weight = linear.weight.detach().float() + scales = FLOAT6_E3M2_MAX / fp32_weight.amax(1) + tc_fp6_weight = to_tc_float6_e3m2(fp32_weight * scales.view(-1, 1)) + bias = linear.bias.detach() if linear.bias is not None else None + return cls(tc_fp6_weight, scales, bias) + + +def convert_fp6_llm(model: nn.Module, skip_fqn_list: list[str] | None = None, cur_fqn: str = "") -> None: + for name, child in model.named_children(): + new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" + + if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): + new_child = Fp6LlmLinear.from_float(child) + setattr(model, name, new_child) + else: + convert_fp6_llm(child, skip_fqn_list, new_fqn) From 5bdcd501ef51bf94a36dcaf9510a311b890d8c16 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 17:50:54 +0800 Subject: [PATCH 10/22] add test for linear --- test/quantization/test_fp6_llm.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 5a97ead824..9b27d3c8f7 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -1,3 +1,4 @@ +import pytest import torch from torch.testing._internal.common_utils import ( TestCase, @@ -6,14 +7,14 @@ run_tests, ) from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 -from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2 +from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear from torchao.ops import prepack_fp6_weight _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) -class 
TestTcFloat6E3M2(TestCase): +class TestFp6LlmLinear(TestCase): @parametrize("device", _DEVICES) def test_to_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) @@ -47,8 +48,31 @@ def test_from_tc_float6_e3m2_compile(self, device): actual = torch.compile(from_tc_float6_e3m2)(x, M, N) torch.testing.assert_close(actual, expected) + @pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") + @parametrize("bias", [False, True]) + def test_fp6_llm_linear_forward(self, bias, device): + N, OC, IC = 4, 256, 64 + linear = torch.nn.Linear(IC, OC, bias=bias, device=device) + fp6_linear = Fp6LlmLinear.from_float(linear) -instantiate_parametrized_tests(TestTcFloat6E3M2) + x = torch.randn(N, IC) + fp6_linear(x) + + @pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") + @parametrize("bias", [False, True]) + def test_fp6_llm_linear_compile(self, bias, device): + N, OC, IC = 4, 256, 64 + linear = torch.nn.Linear(IC, OC, bias=bias, device=device) + fp6_linear = Fp6LlmLinear.from_float(linear) + + x = torch.randn(N, IC) + expected = fp6_linear(x) + actual = torch.compile(fp6_linear)(x) + + torch.testing.assert_close(actual, expected) + + +instantiate_parametrized_tests(TestFp6LlmLinear) if __name__ == "__main__": From 708d48562ad4dbc9328b48d604bf831ad2a97637 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 20:15:45 +0800 Subject: [PATCH 11/22] fix fp6 llm --- test/quantization/test_fp6_llm.py | 22 +++++++++++++--------- torchao/quantization/fp6_llm.py | 16 ++++++++++------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 9b27d3c8f7..8627149fe8 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -48,24 +48,28 @@ def test_from_tc_float6_e3m2_compile(self, device): actual = torch.compile(from_tc_float6_e3m2)(x, M, N) torch.testing.assert_close(actual, expected) - @pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("bias", [False, True]) - def test_fp6_llm_linear_forward(self, bias, device): + def test_fp6_llm_linear_forward(self, bias): N, OC, IC = 4, 256, 64 - linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(linear) + device = "cuda" + + fp16_linear = torch.nn.Linear(IC, OC, bias=bias, device=device) + fp6_linear = Fp6LlmLinear.from_float(fp16_linear) - x = torch.randn(N, IC) + x = torch.randn(N, IC, device=device) fp6_linear(x) - @pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("bias", [False, True]) - def test_fp6_llm_linear_compile(self, bias, device): + def test_fp6_llm_linear_compile(self, bias): N, OC, IC = 4, 256, 64 + device = "cuda" + linear = torch.nn.Linear(IC, OC, bias=bias, device=device) fp6_linear = Fp6LlmLinear.from_float(linear) - - x = torch.randn(N, IC) + + x = torch.randn(N, IC, device=device) expected = fp6_linear(x) actual = torch.compile(fp6_linear)(x) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index f6de8f6a87..921ba027b3 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -142,18 +142,22 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Tensor | None = None): self.register_buffer("bias", bias) def forward(self, x: Tensor): 
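+        # NOTE (annotation, not in the original commit): `scales` is per output channel;
+        # from_float() below computes it as weight.abs().amax(1) / FLOAT6_E3M2_MAX, so the
+        # packed FP6 weight approximates W / scale and the kernel is handed the scales to
+        # undo that, giving out ≈ x @ W.T (+ bias) in the original weight range.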
- out = fp16act_fp6weight_linear(x, self.weight, self.scales, splitK=1) + out = fp16act_fp6weight_linear(x.half(), self.weight, self.scales, splitK=1) if self.bias is not None: - out = out + self.bias.view(-1, 1) + out = out + self.bias return out @classmethod def from_float(cls, linear: nn.Linear): fp32_weight = linear.weight.detach().float() - scales = FLOAT6_E3M2_MAX / fp32_weight.amax(1) - tc_fp6_weight = to_tc_float6_e3m2(fp32_weight * scales.view(-1, 1)) - bias = linear.bias.detach() if linear.bias is not None else None - return cls(tc_fp6_weight, scales, bias) + scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX + scales[scales == 0.0] = 1.0 # avoid 0 scale + + tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1)) + tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32) + + bias = linear.bias.detach().half() if linear.bias is not None else None + return cls(tc_fp6_weight, scales.half(), bias) def convert_fp6_llm(model: nn.Module, skip_fqn_list: list[str] | None = None, cur_fqn: str = "") -> None: From ce0ffc1016be63a18c30191c38e5724515fd7189 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 20:27:23 +0800 Subject: [PATCH 12/22] switch to v2 since it's faster --- test/quantization/test_fp6_llm.py | 1 - torchao/quantization/fp6_llm.py | 35 ++++++++++--------------------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 8627149fe8..5bd1f445f7 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -72,7 +72,6 @@ def test_fp6_llm_linear_compile(self, bias): x = torch.randn(N, IC, device=device) expected = fp6_linear(x) actual = torch.compile(fp6_linear)(x) - torch.testing.assert_close(actual, expected) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 921ba027b3..a4dd65bfff 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -83,25 +83,12 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_2bit = (tensor_fp6 >> 4) & 0b11 tensor_4bit = tensor_fp6 & 0b1111 - tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-block - - # v1 - tensor_2bit = tensor_2bit.permute(0, 2, 1, 3).reshape(-1, 16) # 16 x 2-bit = 32-bit - tensor_2bit = tensor_2bit[:, [2, 6, 10, 14, 0, 4, 8, 12, 3, 7, 11, 15, 1, 5, 9, 13]] - - # v2. this is slower - # tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) - - tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-block - - # v1 - tensor_4bit = tensor_4bit.permute(0, 2, 1, 3).reshape(-1, 8) # 8 x 4-bit = 32-bit - tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] - - # v2. 
this is slower - # tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) - + tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-tile + tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) tensor_2bit = _pack_2bit(tensor_2bit).flatten() + + tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-tile + tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) tensor_4bit = _pack_4bit(tensor_4bit).flatten() return torch.cat([tensor_2bit, tensor_4bit], dim=0) @@ -118,13 +105,13 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor tensor_2bit = _unpack_2bit(tensor_2bit) tensor_4bit = _unpack_4bit(tensor_4bit) - tensor_2bit = tensor_2bit.view(-1, 16) - tensor_2bit = tensor_2bit[:, [4, 12, 0, 8, 5, 13, 1, 9, 6, 14, 2, 10, 7, 15, 3, 11]] - tensor_2bit = tensor_2bit.view(-1, 32, 8, 2).permute(0, 2, 1, 3).flatten() + tensor_2bit = tensor_2bit.view(-1, 8) + tensor_2bit = tensor_2bit[:, [4, 0, 5, 1, 6, 2, 7, 3]] + tensor_2bit = tensor_2bit.view(-1, 32, 2, 8).permute(0, 3, 1, 2).flatten() - tensor_4bit = tensor_4bit.view(-1, 8) - tensor_4bit = tensor_4bit[:, [2, 6, 0, 4, 3, 7, 1, 5]] - tensor_4bit = tensor_4bit.view(-1, 32, 4, 2).permute(0, 2, 1, 3).flatten() + tensor_4bit = tensor_4bit.view(-1, 4) + tensor_4bit = tensor_4bit[:, [2, 0, 3, 1]] + tensor_4bit = tensor_4bit.view(-1, 32, 2, 4).permute(0, 3, 1, 2).flatten() tensor_fp6 = (tensor_2bit << 4) | tensor_4bit tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) From 59e39cefb0618a206de893bcb06b81229b5d0f7c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 20:31:47 +0800 Subject: [PATCH 13/22] fix type hint for old python --- torchao/quantization/fp6_llm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index a4dd65bfff..3548751d3c 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import nn, Tensor from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2 @@ -122,7 +124,7 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor class Fp6LlmLinear(nn.Module): - def __init__(self, weight: Tensor, scales: Tensor, bias: Tensor | None = None): + def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None): super().__init__() self.register_buffer("weight", weight) self.register_buffer("scales", scales) From cb43a7bebd791c22192cd0a5fd8d29826d4059ef Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 20:48:00 +0800 Subject: [PATCH 14/22] simplify further --- torchao/quantization/fp6_llm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 3548751d3c..3173ee4920 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -80,17 +80,18 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: # 64x64 tile, divided into 64x16 slices, and further divided into 8x8 chunks (for FP16 tensor cores) tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) tensor_2bit = (tensor_fp6 >> 4) & 0b11 tensor_4bit = tensor_fp6 & 0b1111 - tensor_2bit = tensor_2bit.view(-1, 8, 32, 2) # 8 chunks of 8x8, or 2 16x16 sub-tile - tensor_2bit 
= tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 3, 1) + # 8 chunks of 8x8, or 2 16x16 sub-tile + tensor_2bit = tensor_2bit.reshape(-1, 8, 64) # trigger a copy here + tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 1) tensor_2bit = _pack_2bit(tensor_2bit).flatten() - tensor_4bit = tensor_4bit.view(-1, 4, 32, 2) # 4 chunks of 8x8, or 1 16x16 sub-tile - tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 3, 1) + # 4 chunks of 8x8, or 1 16x16 sub-tile + tensor_4bit = tensor_4bit.reshape(-1, 4, 64) # trigger a copy here + tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 1) tensor_4bit = _pack_4bit(tensor_4bit).flatten() return torch.cat([tensor_2bit, tensor_4bit], dim=0) @@ -109,11 +110,11 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor tensor_2bit = tensor_2bit.view(-1, 8) tensor_2bit = tensor_2bit[:, [4, 0, 5, 1, 6, 2, 7, 3]] - tensor_2bit = tensor_2bit.view(-1, 32, 2, 8).permute(0, 3, 1, 2).flatten() + tensor_2bit = tensor_2bit.view(-1, 64, 8).permute(0, 2, 1).flatten() tensor_4bit = tensor_4bit.view(-1, 4) tensor_4bit = tensor_4bit[:, [2, 0, 3, 1]] - tensor_4bit = tensor_4bit.view(-1, 32, 2, 4).permute(0, 3, 1, 2).flatten() + tensor_4bit = tensor_4bit.view(-1, 64, 4).permute(0, 2, 1).flatten() tensor_fp6 = (tensor_2bit << 4) | tensor_4bit tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) From b90938ac45fffcd053dc7b0d9d0dff460b545e85 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 20:48:54 +0800 Subject: [PATCH 15/22] fix typing for old python --- torchao/quantization/fp6_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 3173ee4920..b8d4d53746 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -150,7 +150,7 @@ def from_float(cls, linear: nn.Linear): return cls(tc_fp6_weight, scales.half(), bias) -def convert_fp6_llm(model: nn.Module, skip_fqn_list: list[str] | None = None, cur_fqn: str = "") -> None: +def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None: for name, child in model.named_children(): new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" From 66cbe1df8fd83a94a99368d4446aa5c9e341108e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 21:00:32 +0800 Subject: [PATCH 16/22] add test --- test/quantization/test_fp6_llm.py | 21 ++++++++++++++++++--- torchao/quantization/fp6_llm.py | 3 +++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 5bd1f445f7..075c62a0c9 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -1,5 +1,6 @@ import pytest import torch +from torch import nn from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -7,7 +8,7 @@ run_tests, ) from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 -from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear +from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm from torchao.ops import prepack_fp6_weight @@ -54,8 +55,9 @@ def test_fp6_llm_linear_forward(self, bias): N, OC, IC = 4, 256, 64 device = "cuda" - fp16_linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(fp16_linear) + linear = torch.nn.Linear(IC, OC, 
bias=bias, device=device) + fp6_linear = Fp6LlmLinear.from_float(linear) + assert (fp6_linear.bias is not None) == bias x = torch.randn(N, IC, device=device) fp6_linear(x) @@ -74,6 +76,19 @@ def test_fp6_llm_linear_compile(self, bias): actual = torch.compile(fp6_linear)(x) torch.testing.assert_close(actual, expected) + def test_convert_fp6_llm(self): + device = "cuda" + model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device) + convert_fp6_llm(model) + + assert isinstance(model[0], Fp6LlmLinear) + assert model[0].bias is None + assert isinstance(model[1], Fp6LlmLinear) + assert model[1].bias is not None + + x = torch.randn(4, 64, device=device) + model(x) + instantiate_parametrized_tests(TestFp6LlmLinear) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index b8d4d53746..69261bbd2f 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -125,6 +125,9 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor class Fp6LlmLinear(nn.Module): + """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. + """ + def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None): super().__init__() self.register_buffer("weight", weight) From 5ed6767ef13f4853d1e74fd31f123b82511360b2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 22:27:13 +0800 Subject: [PATCH 17/22] eliminate indexing.faster on CUDA --- torchao/quantization/fp6_llm.py | 41 +++++++++++---------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 69261bbd2f..1e34eddce6 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -75,24 +75,16 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: assert (M % 64 == 0) and (N % 64 == 0) tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) - - # Section 5.2, Figure 5. 
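+    # NOTE (annotation, not in the original commit): the view/flip/permute below encodes
+    # the same 64x64 tile -> 64x16 slice -> 8x8 chunk layout (Section 5.2, Figure 5 of the
+    # paper) as the index-based version removed here, but without materializing index
+    # tensors, which is what makes it faster on CUDA per the commit message.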
- # 64x64 tile, divided into 64x16 slices, and further divided into 8x8 chunks (for FP16 tensor cores) - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor_fp6 = tensor_fp6.flip(3) tensor_2bit = (tensor_fp6 >> 4) & 0b11 - tensor_4bit = tensor_fp6 & 0b1111 + tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) + tensor_2bit = _pack_2bit(tensor_2bit.flatten()) - # 8 chunks of 8x8, or 2 16x16 sub-tile - tensor_2bit = tensor_2bit.reshape(-1, 8, 64) # trigger a copy here - tensor_2bit = tensor_2bit[:, [1, 3, 5, 7, 0, 2, 4, 6]].permute(0, 2, 1) - tensor_2bit = _pack_2bit(tensor_2bit).flatten() - - # 4 chunks of 8x8, or 1 16x16 sub-tile - tensor_4bit = tensor_4bit.reshape(-1, 4, 64) # trigger a copy here - tensor_4bit = tensor_4bit[:, [1, 3, 0, 2]].permute(0, 2, 1) - tensor_4bit = _pack_4bit(tensor_4bit).flatten() + tensor_4bit = tensor_fp6 & 0b1111 + tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) + tensor_4bit = _pack_4bit(tensor_4bit.flatten()) return torch.cat([tensor_2bit, tensor_4bit], dim=0) @@ -105,22 +97,17 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor assert tensor.numel() == size_2bit + size_4bit tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - tensor_2bit = _unpack_2bit(tensor_2bit) - tensor_4bit = _unpack_4bit(tensor_4bit) - tensor_2bit = tensor_2bit.view(-1, 8) - tensor_2bit = tensor_2bit[:, [4, 0, 5, 1, 6, 2, 7, 3]] - tensor_2bit = tensor_2bit.view(-1, 64, 8).permute(0, 2, 1).flatten() + tensor_2bit = _unpack_2bit(tensor_2bit) + tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) - tensor_4bit = tensor_4bit.view(-1, 4) - tensor_4bit = tensor_4bit[:, [2, 0, 3, 1]] - tensor_4bit = tensor_4bit.view(-1, 64, 4).permute(0, 2, 1).flatten() + tensor_4bit = _unpack_4bit(tensor_4bit) + tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) tensor_fp6 = (tensor_2bit << 4) | tensor_4bit - tensor_fp6 = tensor_fp6.view(M // 64, N // 16, 4, 2, 2, 8, 8) - tensor_fp6 = tensor_fp6.permute(0, 2, 4, 5, 1, 3, 6) - tensor_fp6 = tensor_fp6.reshape(M, N) - + tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype) From fa08a3d3aa7c0090ed37956acaeb301915ea86ab Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 22:31:33 +0800 Subject: [PATCH 18/22] skip fp6_llm on cpu --- test/quantization/test_fp6_llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 075c62a0c9..27da3c00f0 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -76,6 +76,7 @@ def test_fp6_llm_linear_compile(self, bias): actual = torch.compile(fp6_linear)(x) torch.testing.assert_close(actual, expected) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_convert_fp6_llm(self): device = "cuda" model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device) From 694549866dffdb4cf79a9c455b7484b6a4365740 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 23:43:41 +0800 Subject: [PATCH 19/22] improve error message --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 4 ++-- torchao/quantization/fp6_llm.py | 12 ++++++++++-- 2 files changed, 12 
insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 51413a0874..30b0978a1a 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -144,8 +144,8 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); - assert( num_in_channels%64 == 0 ); - assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched. + TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); + TORCH_CHECK((num_in_channels/16*3) == _weights.size(1)); // Making sure the K dimension is matched. // int M = num_out_channels; int K = num_in_channels; diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 1e34eddce6..8ebcbb8bea 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -120,6 +120,8 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None self.register_buffer("weight", weight) self.register_buffer("scales", scales) self.register_buffer("bias", bias) + self.out_features = weight.shape[0] + self.in_features = weight.shape[1] * 16 // 3 def forward(self, x: Tensor): out = fp16act_fp6weight_linear(x.half(), self.weight, self.scales, splitK=1) @@ -129,6 +131,8 @@ def forward(self, x: Tensor): @classmethod def from_float(cls, linear: nn.Linear): + assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) + fp32_weight = linear.weight.detach().float() scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX scales[scales == 0.0] = 1.0 # avoid 0 scale @@ -139,13 +143,17 @@ def from_float(cls, linear: nn.Linear): bias = linear.bias.detach().half() if linear.bias is not None else None return cls(tc_fp6_weight, scales.half(), bias) + def extra_repr(self) -> str: + return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' + def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None: for name, child in model.named_children(): new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): - new_child = Fp6LlmLinear.from_float(child) - setattr(model, name, new_child) + if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): + new_child = Fp6LlmLinear.from_float(child) + setattr(model, name, new_child) else: convert_fp6_llm(child, skip_fqn_list, new_fqn) From 70b5a4c69e198e6e6d316f0674c721fdc8e62553 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 23:49:10 +0800 Subject: [PATCH 20/22] add support for extra batch dims --- test/quantization/test_fp6_llm.py | 7 ++++--- torchao/quantization/fp6_llm.py | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 27da3c00f0..983126d7f3 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -50,16 +50,17 @@ def test_from_tc_float6_e3m2_compile(self, device): torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("leading_dims", [(4,), (2, 4)]) @parametrize("bias", [False, True]) - def test_fp6_llm_linear_forward(self, bias): - 
N, OC, IC = 4, 256, 64 + def test_fp6_llm_linear_forward(self, bias, leading_dims): + OC, IC = 256, 64 device = "cuda" linear = torch.nn.Linear(IC, OC, bias=bias, device=device) fp6_linear = Fp6LlmLinear.from_float(linear) assert (fp6_linear.bias is not None) == bias - x = torch.randn(N, IC, device=device) + x = torch.randn(*leading_dims, IC, device=device) fp6_linear(x) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 8ebcbb8bea..36285f2536 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -115,7 +115,7 @@ class Fp6LlmLinear(nn.Module): """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. """ - def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None): + def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None: super().__init__() self.register_buffer("weight", weight) self.register_buffer("scales", scales) @@ -123,11 +123,11 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None self.out_features = weight.shape[0] self.in_features = weight.shape[1] * 16 // 3 - def forward(self, x: Tensor): - out = fp16act_fp6weight_linear(x.half(), self.weight, self.scales, splitK=1) + def forward(self, x: Tensor) -> Tensor: + out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1) if self.bias is not None: out = out + self.bias - return out + return out.view(*x.shape[:-1], self.out_features) @classmethod def from_float(cls, linear: nn.Linear): From d6c6b6aeffc72732837dec3a0c2f41d3eba7ac1d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 23:58:29 +0800 Subject: [PATCH 21/22] cast output to original dtype --- torchao/quantization/fp6_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 36285f2536..fd79ef98a1 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -127,7 +127,7 @@ def forward(self, x: Tensor) -> Tensor: out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1) if self.bias is not None: out = out + self.bias - return out.view(*x.shape[:-1], self.out_features) + return out.view(*x.shape[:-1], self.out_features).to(x.dtype) @classmethod def from_float(cls, linear: nn.Linear): From d798eaf3a82e638a7840b68ef3e8416d904d4064 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 27 May 2024 00:28:48 +0800 Subject: [PATCH 22/22] fix precision error due to dtype --- test/quantization/test_fp6_llm.py | 4 ++-- torchao/quantization/fp6_llm.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 983126d7f3..635f78765c 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -60,7 +60,7 @@ def test_fp6_llm_linear_forward(self, bias, leading_dims): fp6_linear = Fp6LlmLinear.from_float(linear) assert (fp6_linear.bias is not None) == bias - x = torch.randn(*leading_dims, IC, device=device) + x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) fp6_linear(x) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -72,7 +72,7 @@ def test_fp6_llm_linear_compile(self, bias): linear = torch.nn.Linear(IC, OC, bias=bias, device=device) fp6_linear = 
Fp6LlmLinear.from_float(linear) - x = torch.randn(N, IC, device=device) + x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fp6_linear(x) actual = torch.compile(fp6_linear)(x) torch.testing.assert_close(actual, expected) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index fd79ef98a1..9f559d4164 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -124,6 +124,7 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None self.in_features = weight.shape[1] * 16 // 3 def forward(self, x: Tensor) -> Tensor: + # TODO: splitK map out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1) if self.bias is not None: out = out + self.bias
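
A minimal usage sketch of the API this series converges on (illustrative only, not part of
any patch above; it assumes a CUDA build of torchao with the FP6-LLM kernel compiled, and
linear layers whose in_features are multiples of 64 and out_features multiples of 256,
mirroring the tests):

    import torch
    from torch import nn
    from torchao.quantization.fp6_llm import Fp6LlmLinear, convert_fp6_llm

    model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).cuda()
    convert_fp6_llm(model)                 # swaps eligible nn.Linear modules in-place
    assert isinstance(model[0], Fp6LlmLinear) and model[0].bias is None

    x = torch.randn(4, 64, device="cuda", dtype=torch.half)
    y = model(x)                           # FP16 activations, FP6-packed weights; output dtype follows x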