
Commit 64480e7

Merge branch 'main' into fix_llama
2 parents bd01882 + 05224a9

66 files changed: +2037 −797 lines changed


benchmarks/benchmark_fp6_llm.py renamed to benchmarks/benchmark_fp6.py

Lines changed: 4 additions & 5 deletions
@@ -1,16 +1,15 @@
 import torch
 import pandas as pd
 import torch.nn.functional as F
-from torchao.prototype.quant_llm import QuantLlmLinearWeight
+from torchao.dtypes import to_affine_quantized_fpx
+from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
-    fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
-    scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)
-
+    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
+    fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
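
The renamed benchmark now builds the FP6 weight through the generic fpx path instead of constructing the removed QuantLlmLinearWeight by hand. A minimal sketch of that flow, assuming a CUDA device and fp16 inputs (shapes are placeholders; only to_affine_quantized_fpx, FpxTensorCoreLayoutType, and dequantize are taken from the diff above):

import torch
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.fpx import FpxTensorCoreLayoutType

# Quantize an fp16 weight to FP6 (3 exponent bits, 2 mantissa bits) in the
# tensor-core layout, then dequantize back to fp16 for a reference matmul.
float_data = torch.randn(256, 64, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)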

benchmarks/float8/bench_linear_float8.py

Lines changed: 28 additions & 3 deletions
@@ -91,6 +91,8 @@ def float8_pct_top_peak(self):
        return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


+# TODO(future PR): add option to measure GPU kernel time, as in other
+# scripts in this folder
 def main(
     sweep_path: Optional[Path] = None,
     compile: bool = True,
@@ -112,10 +114,33 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+
+    if scaling_type_input is ScalingType.STATIC:
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+    if scaling_type_weight is ScalingType.STATIC:
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+    if scaling_type_grad_output is ScalingType.STATIC:
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+
     config = Float8LinearConfig(
-        cast_config_input=CastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+        cast_config_input=cast_config_input,
+        cast_config_weight=cast_config_weight,
+        cast_config_grad_output=cast_config_grad_output,
     )

     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
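
The new branches only differ for ScalingType.STATIC, which requires a pre-computed scale. A minimal sketch of the resulting config, assuming CastConfig, Float8LinearConfig, and ScalingType are importable from torchao.float8 (the import path and the 1.0 scale are placeholders, mirroring the benchmark):

import torch
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType  # import path assumed

# Static scaling supplies a fixed scale tensor up front instead of deriving it
# from runtime statistics; dynamic scaling is kept for the other two casts.
config = Float8LinearConfig(
    cast_config_input=CastConfig(
        scaling_type=ScalingType.STATIC,
        static_scale=torch.tensor([1.0], device="cuda"),
    ),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)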

benchmarks/float8/profile_linear_float8.py

Lines changed: 27 additions & 5 deletions
@@ -263,13 +263,35 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+
+    if scaling_type_input is ScalingType.STATIC:
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+    if scaling_type_weight is ScalingType.STATIC:
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+    if scaling_type_grad_output is ScalingType.STATIC:
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+
     config = Float8LinearConfig(
-        cast_config_input=CastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
-        enable_amax_init=False,
-        enable_pre_and_post_forward=False,
+        cast_config_input=cast_config_input,
+        cast_config_weight=cast_config_weight,
+        cast_config_grad_output=cast_config_grad_output,
     )
+
     scaling_repr = "_".join(
         [
             s.short_str()

docs/source/api_ref_dtypes.rst

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,9 @@ torchao.dtypes
    :nosignatures:

    to_nf4
-   to_affine_quantized
+   to_affine_quantized_intx
+   to_affine_quantized_floatx
+   to_affine_quantized_intx_static
    AffineQuantizedTensor

 ..

scripts/hf_eval.py

Lines changed: 5 additions & 2 deletions
@@ -20,6 +20,7 @@
     int8_dynamic_activation_int8_weight,
     quantize_,
     autoquant,
+    fpx_weight_only,
 )
 from torchao.sparsity import (
     sparsify_,
@@ -59,6 +60,8 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars
     elif quantization == "int4wo":
         # note cannot quantize this model on cpu and run it on cuda at this time
         quantize_(model.to(device=device), int4_weight_only())
+    elif quantization == "fp6":
+        quantize_(model, fpx_weight_only(3, 2))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))

@@ -79,7 +82,7 @@ def all_linear(mod, name):
             return False
     torch.sparse.semi_structured._FORCE_CUTLASS = False
     sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
-
+
     if sparsity and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

@@ -111,7 +114,7 @@ def all_linear(mod, name):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply')
     parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
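
Selecting the new "fp6" choice is equivalent to the quantize_ call added above; a minimal standalone sketch, assuming a CUDA device and fp16 weights (the toy model stands in for the Hugging Face model the script loads):

import torch
from torchao.quantization import quantize_, fpx_weight_only

# Equivalent of passing "-q fp6": quantize linear weights in place to FP6
# (3 exponent bits, 2 mantissa bits).
model = torch.nn.Sequential(torch.nn.Linear(64, 256)).half().cuda()
quantize_(model, fpx_weight_only(3, 2))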

test/dtypes/test_affine_quantized.py

Lines changed: 4 additions & 5 deletions
@@ -8,16 +8,15 @@
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
-)
-from torchao.dtypes import (
-    to_affine_quantized,
+    float8_weight_only,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 import torch
 import unittest
 import tempfile

+
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
@@ -40,7 +39,8 @@ def test_tensor_core_layout_transpose(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
+        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
+                            int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
             l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
             ql = apply_quant(l)
             with tempfile.NamedTemporaryFile() as f:
@@ -69,6 +69,5 @@ def test_to_device(self):
             ql.cuda()


-
 if __name__ == "__main__":
     run_tests()
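
test_weights_only now also covers float8_weight_only. A minimal sketch of applying it outside the test loop, assuming it is exported from torchao.quantization alongside the other configs in this import block:

import torch
from torchao.quantization import float8_weight_only  # export location assumed

# Same pattern as the test loop: the config returns a callable that quantizes a module.
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = float8_weight_only()(l)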

test/prototype/test_quant_llm.py renamed to test/dtypes/test_fpx.py

Lines changed: 29 additions & 46 deletions
@@ -8,23 +8,27 @@
     parametrize,
     run_tests,
 )
-from torchao.prototype.quant_llm import (
-    QuantLlmLinearWeight,
-    quant_llm_fpx_weight_only,
-    fp6_llm_weight_only,
+from torchao.dtypes.fpx import (
+    FpxTensorCoreAQTLayout,
+    FpxTensorCoreLayoutType,
     to_scaled_tc_fpx,
     from_scaled_tc_fpx,
 )
-from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
+from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
-from torchao.quantization.quant_api import quantize_
+from torchao.quantization import (
+    quantize_,
+    fpx_weight_only,
+)
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _FPx_DTYPES = [(3, 2), (2, 2)]


-class TestQuantLlmLinearWeight(TestCase):
+class TestFpxTensorCoreAQTLayout(TestCase):
     @parametrize("device", _DEVICES)
     def test_pack_tc_fp6_correctness(self, device):
         x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
@@ -69,61 +73,40 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
-        x = torch.randn(256, 64)
-        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
-        assert fpx.device.type == "cuda"
-        fpx = fpx.cpu()
-        assert fpx.device.type == "cpu"
+        from torchao.quantization.quant_primitives import (
+            choose_qparams_affine_fpx,
+            quantize_affine_fpx,
+        )

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @parametrize("ebits,mbits", _FPx_DTYPES)
-    @parametrize("leading_dims", [(4,), (2, 4)])
-    @parametrize("bias", [False, True])
-    def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
-        OC, IC = 256, 64
-        device = "cuda"
-
-        fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
-        fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None
-
-        fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)
-
-        x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
-        out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
-        assert out.shape == leading_dims + (OC,)
+        x = torch.randn(256, 64)
+        scale = choose_qparams_affine_fpx(x, ebits, mbits)
+        x = quantize_affine_fpx(x, scale, ebits, mbits)
+        layout_type = FpxTensorCoreLayoutType(ebits, mbits)
+        fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
+        assert fpx_layout_tensor.device.type == "cuda"
+        fpx_layout_tensor = fpx_layout_tensor.cpu()
+        assert fpx_layout_tensor.device.type == "cpu"

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("bias", [False, True])
-    def test_quant_llm_quantize(self, ebits, mbits, bias):
-        N, OC, IC = 4, 256, 64
-        device = "cuda"
-
-        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
-        fpx_linear = copy.deepcopy(linear)
-        quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
-
-        x = torch.randn(N, IC, device=device, dtype=torch.half)
-        expected = fpx_linear(x)
-        actual = torch.compile(fpx_linear, fullgraph=True)(x)
-        torch.testing.assert_close(actual, expected)
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_fp6_llm_quantize(self):
+    def test_fpx_weight_only(self, ebits, mbits, bias):
         N, OC, IC = 4, 256, 64
         device = "cuda"

-        linear = torch.nn.Linear(IC, OC, device=device)
+        linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half)
         fpx_linear = copy.deepcopy(linear)
-        quantize_(fpx_linear, fp6_llm_weight_only())
+        quantize_(fpx_linear, fpx_weight_only(ebits, mbits))

         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fpx_linear(x)
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
+        # somehow compile now changes the result a bit
         torch.testing.assert_close(actual, expected)


-instantiate_parametrized_tests(TestQuantLlmLinearWeight)
+instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)


 if __name__ == "__main__":
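
The renamed test keeps the same eager-versus-compiled consistency check; sketched below for a single FP6 config, assuming a CUDA device, fp16 weights, and PyTorch 2.5+ as required by the new skip condition (shapes are placeholders):

import copy
import torch
from torchao.quantization import quantize_, fpx_weight_only

linear = torch.nn.Linear(64, 256, device="cuda", dtype=torch.half)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fpx_weight_only(3, 2))

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
# Compiled and eager results should agree up to small numerical differences.
torch.testing.assert_close(torch.compile(fpx_linear, fullgraph=True)(x), fpx_linear(x))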

test/float8/test_base.py

Lines changed: 32 additions & 6 deletions
@@ -134,6 +134,7 @@ def test_copy_(self):
         fp8_b.copy_(fp8_a)
         torch.testing.assert_close(fp8_a._data, fp8_b._data)

+    @pytest.mark.skip("broken")
     def test_weights_only_load(self):
         module = nn.Linear(16, 16)
         # Save model state dict
@@ -226,14 +227,16 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
-        "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+        "scaling_type_input",
+        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
+        "scaling_type_weight",
+        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC],
+        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
@@ -259,10 +262,33 @@ def test_linear(
             pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+
+        if scaling_type_input is ScalingType.STATIC:
+            cast_config_input = CastConfig(
+                scaling_type=scaling_type_input,
+                static_scale=torch.tensor([1.0], device="cuda"),
+            )
+        else:
+            cast_config_input = CastConfig(scaling_type=scaling_type_input)
+        if scaling_type_weight is ScalingType.STATIC:
+            cast_config_weight = CastConfig(
+                scaling_type=scaling_type_weight,
+                static_scale=torch.tensor([1.0], device="cuda"),
+            )
+        else:
+            cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+        if scaling_type_grad_output is ScalingType.STATIC:
+            cast_config_grad_output = CastConfig(
+                scaling_type=scaling_type_grad_output,
+                static_scale=torch.tensor([1.0], device="cuda"),
+            )
+        else:
+            cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+
         config = Float8LinearConfig(
-            cast_config_input=CastConfig(scaling_type=scaling_type_input),
-            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+            cast_config_input=cast_config_input,
+            cast_config_weight=cast_config_weight,
+            cast_config_grad_output=cast_config_grad_output,
             emulate=emulate,
         )
         self._test_linear_impl(
