import torchao
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
import unittest
+from parameterized import parameterized


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
@@ -42,6 +43,98 @@ def test_nms(self):
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils)

+    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
+        # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
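+        # Each group of 16 fp6 values occupies 16 * 6 = 96 bits, i.e. three 32-bit words, hence the IC // 16 * 3 int32 columns.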
+        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
+        fp16_scale = torch.rand(OC).half() + 0.5
+        fp16_activation = torch.rand(BS, IC).half() + 0.5
+        return fp6_weight, fp16_scale, fp16_activation
+
+    def test_prepack_fp6_weight(self):
+        OC = 256
+        IC = 256
+        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)
+
+        # smoke test
+        torchao.ops.prepack_fp6_weight(fp6_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp16_to_fp6(self):
+        OC = 256
+        IC = 256
+
+        # in this fp6, we use 3 bits for exponent and 2 bits for mantissa
+        # also, we don't have nan/inf
+        fp6_absmax = 28.0  # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
+        fp6_absmin = 0.0625  # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
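+        # Keep the random weights within the fp6 dynamic range: clip to ±fp6_absmax and flush magnitudes below the smallest subnormal to zero.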
+        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
+        fp16_weight.clip_(-fp6_absmax, fp6_absmax)
+        fp16_weight[fp16_weight.abs() < fp6_absmin] = 0
+
+        # smoke test
+        torchao.ops.fp16_to_fp6(fp16_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp16act_fp6weight_linear(self):
+        BS = 2
+        OC = 256
+        IC = 256
+        splitK = 1
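+        # splitK is (presumably) the split-K factor for the GEMM reduction; 1 means no split, and the larger shapes below use bigger factors.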
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
+
+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
+        # smoke test
+        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp6_weight_dequant(self):
+        OC = 256
+        IC = 256
+        fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)
+
+        # smoke test
+        torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)
+
+    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
+    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
+
+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
+        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+
+        fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
+        results_fp16 = act_cuda @ fp16_weight.T
+
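+        # The fused fp6 kernel output should match the fp16 reference to within 1% mean relative error.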
+        error = (results_fp6 - results_fp16).abs()
+        relative_error = error / results_fp16.abs()
+        assert relative_error.mean() < 1e-2
+


if __name__ == "__main__":
    unittest.main()