
Commit 2623bab

finalizing PR

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 710e149 commit 2623bab

4 files changed (+33, -42 lines)


test/integration/test_integration.py

Lines changed: 6 additions & 6 deletions
@@ -1173,13 +1173,13 @@ def test_on_dummy_distilbert(self):
 class TestAutoQuant(unittest.TestCase):
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
-            # (16, 128, 128),
-            # (64, 128, 128),
+            (16, 128, 128),
+            (64, 128, 128),
             # (2**15, 128, 128), TODO: Runs out of shared memory on T4
-            (2, 128, 256),
+            (16, 128, 256),
             # (64, 128, 256), # TODO: Runs out of shared memory on T4
-            # (16, 256, 128),
-            # (64, 256, 128),
+            (16, 256, 128),
+            (64, 256, 128),
             # (256, 256, 128), TODO: Runs out of shared memory on T4
         ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
@@ -1194,7 +1194,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
         if m == 1:
             self.skipTest(f"Shape {(m, k, n)} requires sm80+")
         torch._inductor.config.epilogue_fusion = False
-        # torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.use_mixed_mm = True
         torch._inductor.config.force_fuse_int_mm_with_mul = True
         torch._dynamo.config.automatic_dynamic_shapes = False
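
For reference, the flags restored above are the compiler settings the autoquant test drives. A minimal standalone sketch of the same flow, assuming torchao's public autoquant entry point; the toy model, shape, and device are illustrative, not the test's exact setup:

import torch
import torchao

# mirror the test's compiler configuration
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True  # re-enabled by this commit
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._dynamo.config.automatic_dynamic_shapes = False

m, k, n = 16, 128, 128  # one of the (m, k, n) shapes re-enabled above
model = torch.nn.Sequential(torch.nn.Linear(k, n)).to("cuda").to(torch.bfloat16)
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)

# autoquant benchmarks candidate quantized layouts on the first call and
# swaps in the fastest one per linear layer
model = torchao.autoquant(torch.compile(model))
model(example_input)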

torchao/_models/llama/benchmark_results.txt

Lines changed: 0 additions & 3 deletions
@@ -27,6 +27,3 @@ kv cache quantization:
 20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
 20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
 20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
-
-20240806071013, tok/s=172.58, mem/s=1161.55 GB/s, peak_mem= 8.90 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240806073549, tok/s=158.04, mem/s=1192.77 GB/s, peak_mem= 9.99 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

Lines changed: 23 additions & 23 deletions
@@ -2,31 +2,31 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder


 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# # in readme
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# in readme
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# # in readme
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# in readme
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

-# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192

torchao/quantization/autoquant.py

Lines changed: 4 additions & 10 deletions
@@ -269,19 +269,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res

-###### TODO !!!!!!!!!!!!!!!
-# 1) make class method from_float (just duplicate code)
-# 2) undo changes to quant_api?
-# 3) point to new quantized_op location
-# 4) rewrite the dynamic autoquant test
-
 class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
     @classmethod
     def from_float(cls, weight):
-        in_features = weight.shape[1]
+        # in_features = weight.shape[1]
         # int8 dynamic quantization only has benefit when in_feature > 16
         # if in_features <= 16:
         #     return weight
@@ -352,8 +346,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
-        # if res_matmul>=best_time:
-        #     return res_matmul
+        if res_matmul>=best_time:
+            return res_matmul

         # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
         to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
@@ -448,7 +442,7 @@ def from_float(cls, weight):
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
     # AQWeightOnlyQuantizedLinearWeight3,
-    # # TODO this gets picked in places where it makes perf worse, why?
+    # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
 ]
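
The two re-enabled lines restore an early exit in the benchmarking path: the bare int-mm kernel time is a lower bound on the full dynamic-quant op, so once it loses to the current best there is no point timing the rest. A standalone sketch of the arithmetic, assuming the interpolated score has the form c*res_matmul + (1-c)*res_full and picking c = 0.85 purely for illustration (the actual INTERPOLATION_CONSTANT is defined elsewhere in autoquant.py):

INTERPOLATION_CONSTANT = 0.85  # illustrative value only

def full_op_time_to_beat(res_matmul: float, best_time: float,
                         c: float = INTERPOLATION_CONSTANT) -> float:
    """Full-op time threshold: the blended score c*res_matmul + (1-c)*res_full
    equals best_time exactly when res_full equals the returned value."""
    return best_time + c / (1 - c) * (best_time - res_matmul)

def benchmark_candidate(res_matmul: float, best_time: float, res_full: float) -> float:
    # early exit restored by this commit: if the (much faster) matmul kernel
    # alone is already beaten, the full op cannot win, so return immediately
    if res_matmul >= best_time:
        return res_matmul
    c = INTERPOLATION_CONSTANT
    return c * res_matmul + (1 - c) * res_full  # blended score vs. best_time

# example: matmul alone takes 0.8 ms and the best candidate so far takes 1.0 ms,
# so the full op may take up to 1.0 + 0.85/0.15 * (1.0 - 0.8) ~= 2.13 ms and still win
print(round(full_op_time_to_beat(0.8, 1.0), 2))  # 2.13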
