Skip to content

Commit 9a20895

Browse files
authored
feat: added fine tuning example focused on TPUs (#3847)
* feat: added fine tuning example focused on TPUs The example script is focused on showing off fine-tuning is possible with current version of accelerate and it is written to be run on TPUs. It has been successfully run and tested on a TPU v5 litepod-8, and it shows how it is possible to perform a fine-tuning task on such hardware thanks to accelerate and FSDPv2, using transformers and Torch XLA. * review: add comment that explains how to launch TPU example script * chore(style): fix formatting
1 parent 139d14b commit 9a20895

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed

examples/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ with `pip install runhouse`, and you can refer to
255255
for hardware setup instructions, or this
256256
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
257257
258+
## Simple fine-tuning script that works on TPU
259+
260+
[finetune_lm_tpu.py](./finetune_lm_tpu.py) is a classic causal language modeling fine-tuning script that has been
261+
adapted to run well on TPUs. It has been successfully run and tested on a TPU v5 litepod-8, and it shows how it is
262+
possible to perform a fine-tuning task on such hardware thanks to accelerate and FSDPv2, using transformers and Torch XLA.
263+
258264
## Finer Examples
259265
260266
While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.

examples/finetune_lm_tpu.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Example of fine-tuning a model on a TPU using FSDPv2, TRL and PEFT.
16+
#
17+
# Run the script with:
18+
# python finetune_lm_tpu.py [--model_id MODEL_ID] [--dataset_id DATASET_ID]
19+
#
20+
# This script has been tested on a TPU v5 litepod-8.
21+
22+
import argparse
23+
24+
import torch
25+
import torch_xla.runtime as xr
26+
from datasets import load_dataset
27+
from peft import LoraConfig
28+
from transformers import AutoModelForCausalLM, AutoTokenizer
29+
from trl import SFTConfig, SFTTrainer
30+
31+
# FSDPv2 requires SPMD (XLA's sharding runtime) to be enabled. Doing it at
# module import time ensures it is active before the model is loaded in train().
xr.use_spmd()
35+
36+
def format_dolly(example, tokenizer):
    """Render one Dolly example as chat-formatted training text.

    Args:
        example: Mapping with ``"instruction"``, ``"context"`` and
            ``"response"`` keys (the databricks-dolly-15k schema).
        tokenizer: Tokenizer whose ``apply_chat_template`` turns a list of
            role/content messages into a single prompt string.

    Returns:
        The untokenized chat-template string for the full system/user/assistant
        conversation.
    """
    user_content = example["instruction"]
    # Dolly stores an empty string when there is no context; only append the
    # context section when it actually carries text.
    if example["context"]:
        user_content += f"\n\nContext: {example['context']}"

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant",
        },
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": example["response"]},
    ]

    return tokenizer.apply_chat_template(messages, tokenize=False)
52+
53+
54+
def train(model_id, dataset):
    """Fine-tune ``model_id`` on ``dataset`` with LoRA + FSDPv2 on TPU.

    Loads the model in bfloat16, configures the tokenizer (pad token and chat
    template), builds XLA-FSDPv2 sharding arguments, and runs TRL's SFTTrainer
    with a LoRA adapter.

    Args:
        model_id: Hugging Face model id to load (e.g. a Llama checkpoint).
        dataset: Training dataset; examples are formatted via format_dolly,
            so they must follow the Dolly instruction/context/response schema.
    """
    # Load model with low_cpu_mem_usage to avoid loading the full model into
    # CPU memory; FSDPv2 will handle sharding across TPUs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_cache=False,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map=None,  # Let FSDP handle device placement
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.pad_token is None:
        if model.config.model_type == "llama":
            # Vanilla Llama models reserve a dedicated right-padding token.
            tokenizer.pad_token = "<|finetune_right_pad_id|>"
        elif tokenizer.eos_token is not None:
            # Common fallback: reuse EOS as padding.
            tokenizer.pad_token = tokenizer.eos_token
        else:
            raise ValueError(f"Cannot get or guess pad token for model {model_id}.")

    if tokenizer.chat_template is None:
        # Set a chat template in the Llama 3.1 format so that
        # format_dolly's apply_chat_template call works for base checkpoints
        # that ship without one.
        tokenizer.chat_template = (
            "{% for message in messages %}"
            "{% if message['role'] == 'system' %}"
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% elif message['role'] == 'user' %}"
            "<|start_header_id|>user<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% elif message['role'] == 'assistant' %}"
            "<|start_header_id|>assistant<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
            "{% endif %}"
        )

    # Try to guess the DecoderLayer class name, based on common model
    # architectures (assumes the model exposes model.model.layers —
    # true for Llama-style decoders; verify for other architectures).
    transformer_layer_cls_to_wrap = model.model.layers[0].__class__.__name__

    # FSDP training arguments: wrap each decoder layer, run FSDP through XLA
    # in its v2 (SPMD-based) mode, and checkpoint gradients on the XLA side.
    fsdp_training_args = {
        "fsdp": "full_shard",
        "fsdp_config": {
            "transformer_layer_cls_to_wrap": [transformer_layer_cls_to_wrap],
            "xla": True,
            "xla_fsdp_v2": True,
            "xla_fsdp_grad_ckpt": True,
        },
    }

    # Set up PEFT LoRA for fine-tuning.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=128,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj"],
        task_type="CAUSAL_LM",
    )

    sft_config = SFTConfig(
        # HF-side gradient checkpointing is not supported on TPU; XLA's own
        # checkpointing is enabled via xla_fsdp_grad_ckpt above instead.
        gradient_checkpointing=False,
        max_length=1024,
        per_device_train_batch_size=4,
        num_train_epochs=3,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required for FSDPv2.
        dataset_text_field="text",
        packing=True,
        **fsdp_training_args,
    )

    # Set up the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=sft_config,
        peft_config=lora_config,
        processing_class=tokenizer,
        formatting_func=lambda example: format_dolly(example, tokenizer),
    )

    trainer.train()
142+
143+
144+
# =============================================================================
# Main Function
# =============================================================================
if __name__ == "__main__":
    # Command-line interface: choose which model to fine-tune and on which
    # dataset; both default to the values the example was tested with.
    arg_parser = argparse.ArgumentParser(description="Simple example of training script.")

    arg_parser.add_argument(
        "--model_id",
        "-m",
        type=str,
        default="meta-llama/Llama-3.2-1B",
        help="Model id to use for training.",
    )
    arg_parser.add_argument(
        "--dataset_id", "-d", type=str, default="databricks/databricks-dolly-15k", help="Dataset id to use for training."
    )

    cli_args = arg_parser.parse_args()

    # NOTE: this section can be adapted to load any dataset you want.
    training_dataset = load_dataset(cli_args.dataset_id, split="train")

    train(model_id=cli_args.model_id, dataset=training_dataset)

0 commit comments

Comments
 (0)