
[do not land] printing shapes for autoquant #695

Draft: wants to merge 1 commit into main

30 changes: 22 additions & 8 deletions benchmarks/benchmark_aq.py
@@ -133,25 +133,39 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
WARMUP = 20
RUNS = 100

torch._dynamo.reset()
m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
benchmark_model(m_bf16, WARMUP, example_inputs)
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

torch._dynamo.reset()
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, example_inputs)
ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

torch._dynamo.reset()
m = torch.compile(m, mode='max-autotune', fullgraph=True)
benchmark_model(m, WARMUP, example_inputs)
elapsed_time = benchmark_model(m, RUNS, example_inputs)


print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")

if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
# all_shapes = set([
# (20, 2048, 2048),
# ])
all_shapes = set([
(6, 12288, 4096),
(6, 4096, 4096),
(6, 11008, 4096),
(6, 4096, 11008),
(6, 32000, 4096),
(1, 12288, 4096),
(1, 4096, 4096),
(1, 11008, 4096),
(1, 4096, 11008),
(1, 32000, 4096),
])

print("_int8da_int8w_api")
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
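For reference, a minimal sketch (not part of this PR) of what one (M, N, K) entry above corresponds to, using the same convention the generate.py change below applies when parsing AUTOQUANT_CACHE keys: the activation has shape (M, K), the linear weight has shape (N, K), and the output has shape (M, N).

import torch

# Hypothetical illustration, not code from this PR: M tokens hitting a linear
# layer with in_features=K and out_features=N, matching one shape from
# all_shapes above.
M, N, K = 6, 12288, 4096
activation = torch.randn(M, K, dtype=torch.bfloat16)
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16)
assert linear.weight.shape == (N, K)  # nn.Linear stores weight as (out_features, in_features)
out = linear(activation)
assert out.shape == (M, N)
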
13 changes: 12 additions & 1 deletion torchao/_models/llama/generate.py
@@ -234,6 +234,16 @@ def main(

# do autoquantization
model.finalize_autoquant()

from torchao.quantization.autoquant import AUTOQUANT_CACHE
shapes = []
for k in AUTOQUANT_CACHE.keys():
act = k[1]
w = k[2]
M, K = act
N = w[0]
shapes.append((M, N, K))
print("all shapes:", set(shapes))
else:
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)
@@ -375,10 +385,11 @@ def callback(x):
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--print_autoquant_m_n_k', action='store_true', help='Whether to print the M, N, K shapes in AUTOQUANT_CACHE for micro benchmarking.')
parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.print_autoquant_m_n_k, args.write_result
)
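
Usage note (a sketch, not part of the diff): with this patch applied, running the llama generation script with autoquant enabled, e.g. python torchao/_models/llama/generate.py --compile --quantization autoquant --print_autoquant_m_n_k, should print the deduplicated set of (M, N, K) shapes recorded in AUTOQUANT_CACHE after model.finalize_autoquant(); that set is what gets pasted into all_shapes in benchmarks/benchmark_aq.py above for micro benchmarking. The --quantization autoquant value is assumed from generate.py's existing interface and is not shown in this diff.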