From 503af8be7728acf483132647accdd84b1356dcd9 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 16 Aug 2024 16:24:59 -0700
Subject: [PATCH] [do not land] printing shapes for autoquant

Summary:
Generate shapes for micro benchmarking `benchmarks/benchmark_aq.py` but this
doesn't seem very helpful for predicting the perf for llama2:
https://gist.github.com/jerryzh168/efc0cb1be0a8a29c9edcd87cc01652f6

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 benchmarks/benchmark_aq.py        | 30 ++++++++++++++++++++++--------
 torchao/_models/llama/generate.py | 13 ++++++++++++-
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index 174038d206..00471703b0 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -133,25 +133,39 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     WARMUP = 20
     RUNS = 100
 
+    torch._dynamo.reset()
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
+    torch._dynamo.reset()
     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
+    torch._dynamo.reset()
     m = torch.compile(m, mode='max-autotune', fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
-
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
-    benchmark_model(m_bf16, WARMUP, example_inputs)
-    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
-
     print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
 
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
-    all_shapes = [
-        (20, 2048, 2048),
-    ]
+    # all_shapes = set([
+    #     (20, 2048, 2048),
+    # ])
+    all_shapes = set([
+        (6, 12288, 4096),
+        (6, 4096, 4096),
+        (6, 11008, 4096),
+        (6, 4096, 11008),
+        (6, 32000, 4096),
+        (1, 12288, 4096),
+        (1, 4096, 4096),
+        (1, 11008, 4096),
+        (1, 4096, 11008),
+        (1, 32000, 4096),
+    ])
 
     print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index bf1d870b52..75eb809fdb 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -234,6 +234,16 @@ def main(
 
             # do autoquantization
             model.finalize_autoquant()
+
+            from torchao.quantization.autoquant import AUTOQUANT_CACHE
+            shapes = []
+            for k in AUTOQUANT_CACHE.keys():
+                act = k[1]
+                w = k[2]
+                M, K = act
+                N = w[0]
+                shapes.append((M, N, K))
+            print("all shapes:", set(shapes))
     else:
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(model)
@@ -375,10 +385,11 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
+    parser.add_argument('--print_autoquant_m_n_k', action='store_true', help='Whether to print the M, N, K shapes in AUTOQUANT_CACHE for micro benchmarking.')
    parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
 
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.print_autoquant_m_n_k, args.write_result
     )