Commit 3b928b9

gau-nernst authored and jerryzh168 committed

Fix FP6-LLM API and add .to(device) op (#595)

* fix
* add some ops for convenience

1 parent 8903a1a
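
For orientation, here is a minimal usage sketch assembled from the tests added in this commit; it is not part of the diff. The import path for quantize_ is an assumption (the test file's import of it is outside the shown hunks); everything else only uses names that appear in the diff below: fp6_llm_weight_only, QuantLlmLinearWeight.from_float, and .cuda()/.cpu() on the quantized weight.

    # Usage sketch based on the new tests; the quantize_ import path is assumed.
    import copy
    import torch
    from torchao.quantization.quant_api import quantize_  # assumed location
    from torchao.prototype.quant_llm import fp6_llm_weight_only, QuantLlmLinearWeight

    if torch.cuda.is_available():
        # Quantize a CUDA linear layer to FP6 (E3M2) with the fixed API.
        linear = torch.nn.Linear(64, 256, device="cuda")
        fpx_linear = copy.deepcopy(linear)
        quantize_(fpx_linear, fp6_llm_weight_only())

        x = torch.randn(4, 64, device="cuda", dtype=torch.half)
        print(fpx_linear(x).shape)  # torch.Size([4, 256])

        # The new detach/clone/_to_copy implementations make device moves work,
        # mirroring test_to_copy_device.
        w = QuantLlmLinearWeight.from_float(torch.randn(256, 64), 3, 2).cuda()
        w = w.cpu()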

2 files changed: +43 −2 lines changed

test/prototype/test_quant_llm.py

Lines changed: 24 additions & 0 deletions
@@ -11,6 +11,7 @@
 from torchao.prototype.quant_llm import (
     QuantLlmLinearWeight,
     quant_llm_fpx_weight_only,
+    fp6_llm_weight_only,
     to_scaled_tc_fpx,
     from_scaled_tc_fpx,
 )
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", _FPx_DTYPES)
+    def test_to_copy_device(self, ebits, mbits):
+        x = torch.randn(256, 64)
+        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
+        assert fpx.device.type == "cuda"
+        fpx = fpx.cpu()
+        assert fpx.device.type == "cpu"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("leading_dims", [(4,), (2, 4)])
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp6_llm_quantize(self):
+        N, OC, IC = 4, 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, device=device)
+        fpx_linear = copy.deepcopy(linear)
+        quantize_(fpx_linear, fp6_llm_weight_only())
+
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        expected = fpx_linear(x)
+        actual = torch.compile(fpx_linear, fullgraph=True)(x)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_parametrized_tests(TestQuantLlmLinearWeight)
 

torchao/prototype/quant_llm/quant_llm.py

Lines changed: 19 additions & 2 deletions
@@ -10,6 +10,7 @@
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
+aten = torch.ops.aten
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
 
 
@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
+@QuantLlmLinearWeight.implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
 
+@QuantLlmLinearWeight.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+
+
+@QuantLlmLinearWeight.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # only support device kwargs, ignore the rest
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+    )
+
+
 def quant_llm_fpx_weight_only(ebits: int, mbits: int):
     def apply_quant_llm(weight: Tensor) -> Tensor:
         out_dim, in_dim = weight.shape
@@ -445,4 +462,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
 
 
 def fp6_llm_weight_only():
-    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
+    return quant_llm_fpx_weight_only(3, 2)
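
A note on why these particular ops are enough: for a __torch_dispatch__ wrapper subclass, .to(device=...), .cuda(), and .cpu() decompose to aten._to_copy.default before they reach the subclass, so handling that one op (plus detach/clone for the usual aliasing bookkeeping) gives device movement. Below is a standalone sketch of that mechanism; PayloadTensor is purely illustrative and is not torchao code.

    # Standalone illustration (not torchao code): PayloadTensor is a hypothetical
    # wrapper subclass showing that .to(device)/.cuda()/.cpu() arrive at
    # __torch_dispatch__ as aten._to_copy.default.
    import torch

    aten = torch.ops.aten


    class PayloadTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, data):
            # The outer tensor only carries metadata; the real storage is the
            # wrapped payload kept in self._data.
            return torch.Tensor._make_wrapper_subclass(
                cls, data.shape, dtype=data.dtype, device=data.device
            )

        def __init__(self, data):
            self._data = data

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func is aten.detach.default or func is aten.clone.default:
                return PayloadTensor(args[0]._data)
            if func is aten._to_copy.default:
                # .cuda()/.cpu()/.to(device=...) land here with a device kwarg;
                # move the payload and re-wrap, analogous to _apply_fn_to_data above.
                return PayloadTensor(args[0]._data.to(device=kwargs.get("device")))
            raise NotImplementedError(f"{func} is not supported in this sketch")


    t = PayloadTensor(torch.randn(4, 4))
    if torch.cuda.is_available():
        t = t.cuda()      # dispatches aten._to_copy.default with device="cuda"
        print(t.device)   # cuda:0
        t = t.cpu()
        print(t.device)   # cpu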
