Skip to content

Commit d4d5abe

Browse files
committed
autoquant using aqt
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent de4a1fb commit d4d5abe

File tree

6 files changed

+120
-50
lines changed

6 files changed

+120
-50
lines changed

test/integration/test_integration.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,13 +1173,13 @@ def test_on_dummy_distilbert(self):
11731173
class TestAutoQuant(unittest.TestCase):
11741174
@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
11751175
[
1176-
(16, 128, 128),
1177-
(64, 128, 128),
1176+
# (16, 128, 128),
1177+
# (64, 128, 128),
11781178
# (2**15, 128, 128), TODO: Runs out of shared memory on T4
1179-
(16, 128, 256),
1179+
(2, 128, 256),
11801180
# (64, 128, 256), # TODO: Runs out of shared memory on T4
1181-
(16, 256, 128),
1182-
(64, 256, 128),
1181+
# (16, 256, 128),
1182+
# (64, 256, 128),
11831183
# (256, 256, 128), TODO: Runs out of shared memory on T4
11841184
]))
11851185
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
@@ -1194,7 +1194,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
11941194
if m == 1:
11951195
self.skipTest(f"Shape {(m, k, n)} requires sm80+")
11961196
torch._inductor.config.epilogue_fusion = False
1197-
torch._inductor.config.use_mixed_mm = True
1197+
# torch._inductor.config.use_mixed_mm = True
11981198
torch._inductor.config.force_fuse_int_mm_with_mul = True
11991199
torch._dynamo.config.automatic_dynamic_shapes = False
12001200

torchao/_models/llama/benchmark_results.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ kv cache quantization:
2727
20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
2828
20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
2929
20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
30+
31+
20240806071013, tok/s=172.58, mem/s=1161.55 GB/s, peak_mem= 8.90 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
32+
20240806073549, tok/s=158.04, mem/s=1192.77 GB/s, peak_mem= 9.99 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,31 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
22

33

44
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
5-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
6-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
7-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
8-
# in readme
9-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
10-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
11-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
12-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
5+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
6+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
7+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
8+
# # in readme
9+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
10+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
11+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
12+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
1313
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
1414

1515
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
16-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
17-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
18-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
19-
# in readme
20-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
21-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
22-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
23-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
16+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
17+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
18+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
19+
# # in readme
20+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
21+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
22+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
23+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
2424
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
2525

26-
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
27-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
28-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
29-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
30-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
31-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
32-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
26+
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
27+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
28+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
29+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
30+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
31+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
32+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
155155
int_data, scale, zero_point = self.layout_tensor.get_plain()
156156
return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
157157

158+
@staticmethod
159+
def _quantized_linear_op(input_tensor, weight_tensor, bias):
160+
return _quantized_linear_op(input_tensor, weight_tensor, bias)
161+
158162
def __tensor_flatten__(self):
159163
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
160164

@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
832836
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
833837
# make the branches easier to understand in `_quantized_linear_op`
834838
try:
835-
return _quantized_linear_op(input_tensor, weight_tensor, bias)
839+
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
836840
except:
837841
if isinstance(input_tensor, AffineQuantizedTensor):
838842
input_tensor = input_tensor.dequantize()

torchao/quantization/autoquant.py

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import torch
22
import torchao
3+
from torchao.quantization.quant_primitives import (
4+
MappingType,
5+
ZeroPointDomain,
6+
)
37
from .subclass import ( # noqa
48
Int8DynamicallyQuantizedLinearWeight,
59
Int8WeightOnlyQuantizedLinearWeight,
610
QuantizedLinearWeightBase,
711
)
12+
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
13+
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
814
from torch.utils._python_dispatch import return_and_correct_aliasing
915
from .quant_primitives import (
1016
safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
252258
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
253259
w_qtensor = cls.from_float(weight)
254260
if _is_interpolate_mode(mode):
255-
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
261+
q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
256262
else:
257-
func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
263+
func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
258264
q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
259265
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
260266
if res < best_time*1.1:
@@ -263,10 +269,53 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
263269
print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
264270
return res
265271

266-
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
272+
###### TODO !!!!!!!!!!!!!!!
273+
# 1) make class method from_float (just duplicate code)
274+
# 2) undo changes to quant_api?
275+
# 3) point to new quantized_op location
276+
# 4) rewrite the dynamic autoquant test
277+
278+
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
267279
"""
268280
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
269281
"""
282+
@classmethod
283+
def from_float(cls, weight):
284+
in_features = weight.shape[1]
285+
# int8 dynamic quantization only has benefit when in_feature > 16
286+
# if in_features <= 16:
287+
# return weight
288+
289+
# avoid circular dep
290+
from torchao.dtypes import to_affine_quantized
291+
# weight settings
292+
mapping_type = MappingType.SYMMETRIC
293+
def get_weight_block_size(x):
294+
return (1, x.shape[1])
295+
target_dtype = torch.int8
296+
eps = torch.finfo(torch.float32).eps
297+
zero_point_dtype = torch.int64
298+
299+
# input settings
300+
def get_per_token_block_size(x):
301+
block_size = list(x.shape)
302+
for i in range(len(block_size)-1):
303+
block_size[i] = 1
304+
return block_size
305+
306+
input_mapping_type = MappingType.SYMMETRIC
307+
input_target_dtype = torch.int8
308+
input_eps = 1e-5
309+
input_quant_min = -127
310+
input_quant_max = 127
311+
layout_type = PlainLayoutType()
312+
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
313+
314+
block_size = get_weight_block_size(weight)
315+
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
316+
weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
317+
return weight
318+
270319
@classmethod
271320
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
272321
"""
@@ -298,12 +347,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
298347
)
299348
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
300349
with torch.no_grad():
301-
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
350+
w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
351+
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
302352
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
303353

304354
# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
305-
if res_matmul>=best_time:
306-
return res_matmul
355+
# if res_matmul>=best_time:
356+
# return res_matmul
307357

308358
# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
309359
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
@@ -313,18 +363,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
313363
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
314364
return res_f
315365

316-
class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
366+
class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
317367
"""
318368
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
319369
"""
370+
@classmethod
371+
def from_float(cls, weight):
372+
mapping_type = MappingType.SYMMETRIC
373+
target_dtype = torch.int8
374+
eps = torch.finfo(torch.float32).eps
375+
zero_point_dtype = torch.int64
376+
block_size = (1, weight.shape[1])
377+
return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
378+
320379

321-
class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
380+
class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
322381
"""
323382
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
324383
uses a different kernel
325384
"""
326385
@staticmethod
327-
def _quantized_op(act_mat, w_qtensor, bias):
386+
def _quantized_linear_op(act_mat, w_qtensor, bias):
328387
"""
329388
Performs the quantized linear operations
330389
@@ -339,8 +398,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
339398
orig_dtype = act_mat.dtype
340399
orig_shape = act_mat.shape
341400
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
342-
y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
343-
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
401+
y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
402+
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
344403
if bias is not None:
345404
y += bias
346405
return y.to(orig_dtype)
@@ -352,14 +411,14 @@ def _autoquant_test(cls, act_mat, *args):
352411
return torch.inf
353412
return super()._autoquant_test(act_mat, *args)
354413

355-
class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
414+
class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
356415
"""
357416
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
358417
uses a different kernel
359418
"""
360-
def _quantized_op(act_mat, w_qtensor, bias):
419+
def _quantized_linear_op(act_mat, w_qtensor, bias):
361420
orig_shape = act_mat.shape
362-
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
421+
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
363422
y=y.reshape(*orig_shape[:-1], y.shape[-1])
364423
if bias is not None:
365424
y += bias
@@ -377,7 +436,7 @@ def __init__(self):
377436
super().__init__()
378437

379438
@staticmethod
380-
def _quantized_op(act_mat, w_qtensor, bias):
439+
def _quantized_linear_op(act_mat, w_qtensor, bias):
381440
return torch.nn.functional.linear(act_mat, w_qtensor, bias)
382441

383442
@classmethod
@@ -389,7 +448,7 @@ def from_float(cls, weight):
389448
AQWeightOnlyQuantizedLinearWeight,
390449
AQWeightOnlyQuantizedLinearWeight2,
391450
# AQWeightOnlyQuantizedLinearWeight3,
392-
# TODO this gets picked in places where it makes perf worse, why?
451+
# # TODO this gets picked in places where it makes perf worse, why?
393452
AQInt8DynamicallyQuantizedLinearWeight,
394453
]
395454

0 commit comments

Comments
 (0)