
Commit a1b853e

Add support for saving a quantized checkpoint in the llama code
Summary: The goal is to upload a torchao-quantized model to Hugging Face so that we can run the model there.

Test Plan: python generate.py -q int4wo-32 --save
1 parent 04e5a9e commit a1b853e
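The Test Plan writes a quantized state dict to disk (with the default checkpoint path and -q int4wo-32, a file named model-int4wo-32.pt in the working directory, following the naming added to generate.py below). Getting that artifact onto the Hugging Face Hub could then look roughly like the sketch below; the target repo id is an illustrative assumption and not part of this commit:

from huggingface_hub import HfApi

# Hypothetical example: the local file name matches the Test Plan above,
# but the target repo id is purely illustrative.
api = HfApi()
api.upload_file(
    path_or_fileobj="model-int4wo-32.pt",
    path_in_repo="model-int4wo-32.pt",
    repo_id="your-username/llama-2-7b-chat-int4wo-32",
    repo_type="model",
)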

File tree

2 files changed: +23 -4 lines changed


scripts/hf_eval.py

Lines changed: 14 additions & 3 deletions
@@ -40,12 +40,12 @@ def format_value(value):
 
     print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))
 
-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
-    if compile:
+    if quantization == "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     if quantization == "int8dq":
@@ -57,6 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         quantize_(model.to(device=device), int4_weight_only())
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+
+    if quantization != "autoquant" and compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -70,6 +74,12 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
 
     pretty_print_nested_results(result)
 
+    if save:
+        # This doesn't work yet: https://github.com/huggingface/transformers/issues/32364
+        # model.save_pretrained("quantized_model_test", safe_serialization=False)
+        file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt"
+        torch.save(model.state_dict(), file_name)
+
 
 if __name__ == '__main__':
     import argparse
@@ -81,8 +91,9 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
     parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--save', action='store_true', help='Whether to save the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
 
     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
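Because model.save_pretrained is blocked by the transformers issue referenced above, the script falls back to torch.save on the quantized state dict. A minimal sketch of loading such a file back into a freshly instantiated model, assuming the usual tensor-subclass loading pattern (the repo id, file name, dtype, and weights_only/assign settings are illustrative, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

# Hypothetical example: the file name follows the
# repo_id.split("/")[-1] + "-" + quantization + ".pt" convention used above.
repo_id = "meta-llama/Meta-Llama-3-8B"
file_name = "Meta-Llama-3-8B-int4wo.pt"

model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
# The quantized weights are torchao tensor subclasses, so they need full unpickling
# (weights_only=True would reject them unless the classes are allowlisted).
state_dict = torch.load(file_name, map_location="cuda", weights_only=False)
# assign=True replaces the existing parameters with the loaded quantized tensors
# rather than copying values into them.
model.load_state_dict(state_dict, assign=True)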

torchao/_models/llama/generate.py

Lines changed: 9 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
 import sys
 import time
 from pathlib import Path
@@ -165,6 +166,7 @@ def main(
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
     kv_cache_quantization: bool = False,
+    save: bool = False,
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
@@ -238,6 +240,11 @@ def main(
 
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
+    if save:
+        output_dir = str(checkpoint_path.cwd())
+        filename = str(checkpoint_path.name).split(".")[0]
+        torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))
+
     if compile:
         print("Compiling Model")
         global decode_one_token, prefill
@@ -362,6 +369,7 @@ def callback(x):
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
+    parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -372,5 +380,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
     )
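One subtlety in the save path above: Path.cwd() is a classmethod, so checkpoint_path.cwd() returns the directory generate.py is invoked from, not the directory containing the checkpoint. A small illustrative snippet (the concrete paths are hypothetical):

from pathlib import Path

ckpt = Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
# cwd() ignores the instance and returns the process working directory,
# so the quantized .pt lands wherever generate.py is run from.
print(ckpt.cwd())    # e.g. /home/user/ao/torchao/_models/llama
print(ckpt.parent)   # ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf

If the intent were to place the .pt file next to the original model.pth, checkpoint_path.parent would be the directory to use instead.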
