Commit f0cefe6
Update QAT recipe to match full finetune recipe (5/12/25)
**Summary:** Similar to meta-pytorch#1854. Update the `qat_distributed` recipe to mirror `full_finetune_distributed` up to a6db644. The one major new feature excluded from `qat_distributed` is FP8 finetuning (meta-pytorch#2546), since FP8 QAT is not yet supported in torchao.
Diff between the full finetune and QAT recipes (P1809370361), generated with:
```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
```
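For context on what the QAT recipe changes relative to the full finetune flow: the `quantizer` specified in the config is used to fake-quantize the model before the training loop and to convert it afterwards. A minimal sketch of that flow, assuming the quantizer exposes torchao's `prepare()`/`convert()` interface and using a toy stand-in model (not the recipe's actual code):
```
import torch.nn as nn
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

# Toy stand-in; in the recipe this is the Llama checkpoint being finetuned.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# Before training: swap linear layers for fake-quantized versions so the
# weights learn to tolerate int8-activation / int4-weight quantization.
model = quantizer.prepare(model)

# ... run the usual distributed finetuning loop on `model` ...

# After training: replace fake-quantized modules with real quantized ones.
# (The Test Plan below instead does this via the separate `tune run quantize` step.)
model = quantizer.convert(model)
```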
**Test Plan:**
Finetune:
```
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \
epochs=1 \
batch_size=16 \
dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat \
output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
metric_logger.log_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
quantizer.groupsize=32
```
Quantize:
```
tune run quantize --config quantization \
model._component_=torchtune.models.llama3_2.llama3_2_3b \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
checkpointer.model_type=LLAMA3 \
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
quantizer.groupsize=32
```
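The quantize step is essentially a `quantize()` call on the loaded finetuned checkpoint, producing the `-8da4w` checkpoint files referenced in the Eval step. A rough sketch, assuming the quantizer follows torchao's `quantize()` interface and using a placeholder model:
```
import torch.nn as nn
from torchtune.training.quantization import Int8DynActInt4WeightQuantizer

# Placeholder for the finetuned model loaded by the checkpointer.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# Post-training quantization to int8 dynamic activations + int4 weights (8da4w),
# matching the fake quantization scheme used during QAT finetuning.
quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
model = quantizer.quantize(model)
```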
Eval:
```
tune run eleuther_eval --config eleuther_evaluation \
batch_size=1 \
'tasks=[wikitext]' \
model._component_=torchtune.models.llama3_2.llama3_2_3b \
checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
checkpointer.model_type=LLAMA3 \
tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
quantizer.groupsize=32
```
Results:
```
experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved
----------------------- ------------------- ----------------- ---------------- -------------------
Llama3.2-3B_alpaca_full 4677.163 (+0.000%) 12.261 (+0.000%) 12.261 (+0.000%) 15.778 (+0.000%)
Llama3.2-3B_alpaca_qat 1873.316 (-59.948%) 13.047 (+6.409%) 13.047 (+6.409%) 17.226 (+9.176%)
experiment_name hellaswag_acc wikitext_word_perplexity
----------------------- ------------------------------ -------------------------------
Llama3.2-3B_alpaca_full 0.470 quant, 0.534 float 18.563 quant, 12.364 float
Llama3.2-3B_alpaca_qat 0.511 quant, recovered 63.043% 13.792 quant, recovered 76.962%
```
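The "recovered" columns report how much of the post-training-quantization degradation the QAT run wins back relative to the float baseline. A quick sketch of that computation (assuming this definition; small differences from the table come from rounding of the displayed values):
```
def recovered(float_metric: float, ptq_metric: float, qat_metric: float) -> float:
    """Fraction of the quantization degradation recovered by QAT.

    Works for accuracy (higher is better) and perplexity (lower is better)
    alike, since the sign of the gap cancels out.
    """
    return (qat_metric - ptq_metric) / (float_metric - ptq_metric)

# Wikitext word perplexity from the table above:
print(f"{recovered(12.364, 18.563, 13.792):.3%}")  # ~76.96%, vs. 76.962% reported
```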