 from torchao.prototype.quant_llm import (
     QuantLlmLinearWeight,
     quant_llm_fpx_weight_only,
+    fp6_llm_weight_only,
     to_scaled_tc_fpx,
     from_scaled_tc_fpx,
 )
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", _FPx_DTYPES)
+    def test_to_copy_device(self, ebits, mbits):
+        x = torch.randn(256, 64)
+        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
+        assert fpx.device.type == "cuda"
+        fpx = fpx.cpu()
+        assert fpx.device.type == "cpu"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("leading_dims", [(4,), (2, 4)])
|
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp6_llm_quantize(self):
+        N, OC, IC = 4, 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, device=device)
+        fpx_linear = copy.deepcopy(linear)
+        quantize_(fpx_linear, fp6_llm_weight_only())
+
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        expected = fpx_linear(x)
+        actual = torch.compile(fpx_linear, fullgraph=True)(x)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_parametrized_tests(TestQuantLlmLinearWeight)
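
For context, a minimal usage sketch of the new `fp6_llm_weight_only` API that `test_fp6_llm_quantize` exercises. This assumes `quantize_` is imported from `torchao.quantization` (as elsewhere in this test file) and that `fp6_llm_weight_only()` is a convenience wrapper over `quant_llm_fpx_weight_only` with the FP6 e3m2 layout; treat it as a sketch, not canonical API docs:

```python
import torch
from torchao.quantization import quantize_  # assumed import path
from torchao.prototype.quant_llm import fp6_llm_weight_only

# FP6-LLM kernels target CUDA with fp16 activations.
model = torch.nn.Sequential(torch.nn.Linear(64, 256)).half().cuda()

# Replace each nn.Linear weight in place with a packed FP6 QuantLlmLinearWeight.
quantize_(model, fp6_llm_weight_only())

x = torch.randn(4, 64, dtype=torch.half, device="cuda")
y = model(x)  # dispatches to the FP6 mixed-precision matmul
```

The tests above compile the quantized module with `torch.compile(..., fullgraph=True)` and compare against eager output, so the same pattern should also work under compilation.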