Skip to content

Commit f0cefe6

Browse files
committed
Update QAT recipe to match full finetune recipe (5/12/25)
**Summary:** Similar to meta-pytorch#1854. Update `qat_distributed` recipe to mirror `full_finetune_distributed` up until a6db644. The new major feature that is excluded from `qat_distributed` is FP8 finetuning (meta-pytorch#2546), since QAT FP8 is not supported in torchao yet. Diff between full finetune and QAT recipes: P1809370361 ``` diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py ``` **Test Plan:** Finetune: ``` tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \ epochs=1 \ batch_size=16 \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat \ output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \ metric_logger.log_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \ quantizer.groupsize=32 ``` Quantize: ``` tune run quantize --config quantization \ model._component_=torchtune.models.llama3_2.llama3_2_3b \ checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \ checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \ 'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \ checkpointer.model_type=LLAMA3 \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 ``` Eval: ``` tune run eleuther_eval --config eleuther_evaluation \ batch_size=1 \ 'tasks=[wikitext]' \ model._component_=torchtune.models.llama3_2.llama3_2_3b \ checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \ checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \ 'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \ checkpointer.model_type=LLAMA3 \ tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 ``` Results: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ----------------------- ------------------- ----------------- ---------------- ------------------- Llama3.2-3B_alpaca_full 4677.163 (+0.000%) 12.261 (+0.000%) 12.261 (+0.000%) 15.778 (+0.000%) Llama3.2-3B_alpaca_qat 1873.316 (-59.948%) 13.047 (+6.409%) 13.047 (+6.409%) 17.226 (+9.176%) experiment_name hellaswag_acc wikitext_word_perplexity ----------------------- ------------------------------ ------------------------------- Llama3.2-3B_alpaca_full 0.470 quant, 0.534 float 18.563 quant, 12.364 float Llama3.2-3B_alpaca_qat 0.511 quant, recovered 63.043% 13.792 quant, recovered 76.962% ```
1 parent a6db644 commit f0cefe6

File tree

1 file changed

+356
-211
lines changed

1 file changed

+356
-211
lines changed

0 commit comments

Comments
 (0)