diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py
index b821504731..c3365cffeb 100644
--- a/test/dtypes/test_float6_e3m2.py
+++ b/test/dtypes/test_float6_e3m2.py
@@ -12,7 +12,7 @@
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 
 
-class TestFp6(TestCase):
+class TestFloat6E3M2(TestCase):
     @parametrize("device", _DEVICES)
     @parametrize("dtype", _DTYPES)
@@ -120,7 +120,7 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing):
         torch.testing.assert_close(actual, expected)
 
 
-instantiate_parametrized_tests(TestFp6)
+instantiate_parametrized_tests(TestFloat6E3M2)
 
 
 if __name__ == "__main__":
diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py
new file mode 100644
index 0000000000..635f78765c
--- /dev/null
+++ b/test/quantization/test_fp6_llm.py
@@ -0,0 +1,99 @@
+import pytest
+import torch
+from torch import nn
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
+from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
+from torchao.ops import prepack_fp6_weight
+
+
+_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+
+
+class TestFp6LlmLinear(TestCase):
+    @parametrize("device", _DEVICES)
+    def test_to_tc_float6_e3m2_correctness(self, device):
+        x = torch.randn(256, 64, device=device)
+
+        expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
+        actual = to_tc_float6_e3m2(x)
+        torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))
+
+    @parametrize("device", _DEVICES)
+    def test_to_tc_float6_e3m2_compile(self, device):
+        x = torch.randn(256, 64, device=device)
+
+        expected = to_tc_float6_e3m2(x)
+        actual = torch.compile(to_tc_float6_e3m2)(x)
+        torch.testing.assert_close(actual, expected)
+
+    @parametrize("device", _DEVICES)
+    def test_from_tc_float6_e3m2_correctness(self, device):
+        x = torch.randn(256, 64, device=device)
+        x = from_float6_e3m2(to_float6_e3m2(x))  # quantize and dequantize so that the values are exactly representable in FP6
+
+        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
+        torch.testing.assert_close(actual, x)
+
+    @parametrize("device", _DEVICES)
+    def test_from_tc_float6_e3m2_compile(self, device):
+        M, N = 256, 64
+        x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)
+
+        expected = from_tc_float6_e3m2(x, M, N)
+        actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("leading_dims", [(4,), (2, 4)])
+    @parametrize("bias", [False, True])
+    def test_fp6_llm_linear_forward(self, bias, leading_dims):
+        OC, IC = 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
+        fp6_linear = Fp6LlmLinear.from_float(linear)
+        assert (fp6_linear.bias is not None) == bias
+
+        x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
+        fp6_linear(x)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("bias", [False, True])
+    def test_fp6_llm_linear_compile(self, bias):
+        N, OC, IC = 4, 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
+        fp6_linear = Fp6LlmLinear.from_float(linear)
+
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        expected = fp6_linear(x)
+        actual = torch.compile(fp6_linear)(x)
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_convert_fp6_llm(self):
+        device = "cuda"
+        model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device)
+        convert_fp6_llm(model)
+
+        assert isinstance(model[0], Fp6LlmLinear)
+        assert model[0].bias is None
+        assert isinstance(model[1], Fp6LlmLinear)
+        assert model[1].bias is not None
+
+        x = torch.randn(4, 64, device=device)
+        model(x)
+
+
+instantiate_parametrized_tests(TestFp6LlmLinear)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
index 51413a0874..30b0978a1a 100644
--- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
+++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -144,8 +144,8 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
     int num_in_feats      = _in_feats.size(0);
     int num_in_channels   = _in_feats.size(1);
     int num_out_channels  = _weights.size(0);
-    assert( num_in_channels%64 == 0 );
-    assert( (num_in_channels/16*3) == _weights.size(1) );    // Making sure the K dimension is matched.
+    TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
+    TORCH_CHECK((num_in_channels/16*3) == _weights.size(1));    // Making sure the K dimension is matched.
     //
     int M = num_out_channels;
     int K = num_in_channels;
diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py
new file mode 100644
index 0000000000..9f559d4164
--- /dev/null
+++ b/torchao/quantization/fp6_llm.py
@@ -0,0 +1,160 @@
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2
+from torchao.ops import fp16act_fp6weight_linear
+
+
+def _pack_2bit(x: Tensor) -> Tensor:
+    return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]
+
+
+def _unpack_2bit(x: Tensor) -> Tensor:
+    return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2)
+
+
+def _pack_4bit(x: Tensor) -> Tensor:
+    return (x[..., ::2] << 4) | x[..., 1::2]
+
+
+def _unpack_4bit(x: Tensor) -> Tensor:
+    return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2)
+
+
+# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
+# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
+def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)
+
+    # Pass 1 from original code
+    tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8)
+    tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6)
+    tensor_fp6 = tensor_fp6.reshape(-1, 32, 2)
+    tensor_fp6 = tensor_fp6.permute(1, 0, 2)
+    tensor_fp6 = tensor_fp6.flatten()
+
+    tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11)
+    tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111)
+
+    # Pass 2 from original code
+    tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+    tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+
+    # Pass 3 from original code
+    # BitInterleaving_2bit
+    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
+    # while we still unpack/pack the values from/to uint8
+    tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16)
+    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
+    tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]]
+    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
+    tensor_2bit = _pack_2bit(tensor_2bit).view(-1)
+
+    # BitInterleaving_4bit
+    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
+    # while we still unpack/pack the values from/to uint8
+    tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8)
+    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
+    tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]]
+    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
+    tensor_4bit = _pack_4bit(tensor_4bit).view(-1)
+
+    return torch.cat([tensor_2bit, tensor_4bit], dim=0)
+
+
+# more optimized version of _to_tc_float6_e3m2_original() by merging ops
+# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
+def to_tc_float6_e3m2(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)
+    tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
+    tensor_fp6 = tensor_fp6.flip(3)
+
+    tensor_2bit = (tensor_fp6 >> 4) & 0b11
+    tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
+    tensor_2bit = _pack_2bit(tensor_2bit.flatten())
+
+    tensor_4bit = tensor_fp6 & 0b1111
+    tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
+    tensor_4bit = _pack_4bit(tensor_4bit.flatten())
+
+    return torch.cat([tensor_2bit, tensor_4bit], dim=0)
+
+
+def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor:
+    assert tensor.ndim == 1
+    assert (M % 64 == 0) and (N % 64 == 0)
+    size_2bit = M * N // 4
+    size_4bit = M * N // 2
+    assert tensor.numel() == size_2bit + size_4bit
+
+    tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])
+
+    tensor_2bit = _unpack_2bit(tensor_2bit)
+    tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
+    tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)
+
+    tensor_4bit = _unpack_4bit(tensor_4bit)
+    tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
+    tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)
+
+    tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
+    tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
+    return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)
+
+
+class Fp6LlmLinear(nn.Module):
+    """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
+    """
+
+    def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None:
+        super().__init__()
+        self.register_buffer("weight", weight)
+        self.register_buffer("scales", scales)
+        self.register_buffer("bias", bias)
+        self.out_features = weight.shape[0]
+        self.in_features = weight.shape[1] * 16 // 3
+
+    def forward(self, x: Tensor) -> Tensor:
+        # TODO: splitK map
+        out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
+        if self.bias is not None:
+            out = out + self.bias
+        return out.view(*x.shape[:-1], self.out_features).to(x.dtype)
+
+    @classmethod
+    def from_float(cls, linear: nn.Linear):
+        assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)
+
+        fp32_weight = linear.weight.detach().float()
+        scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX
+        scales[scales == 0.0] = 1.0  # avoid 0 scale
+
+        tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1))
+        tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32)
+
+        bias = linear.bias.detach().half() if linear.bias is not None else None
+        return cls(tc_fp6_weight, scales.half(), bias)
+
+    def extra_repr(self) -> str:
+        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
+
+
+def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None:
+    for name, child in model.named_children():
+        new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"
+
+        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)):
+            if (child.in_features % 64 == 0) and (child.out_features % 256 == 0):
+                new_child = Fp6LlmLinear.from_float(child)
+                setattr(model, name, new_child)
+        else:
+            convert_fp6_llm(child, skip_fqn_list, new_fqn)