# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example of fine-tuning a model on a TPU using FSDPv2, TRL and PEFT.
#
# Run the script with:
# python finetune_lm_tpu.py [--model_id MODEL_ID] [--dataset_id DATASET_ID]
#
# This script has been tested on a TPU v5litepod-8.

import argparse

import torch
import torch_xla.runtime as xr
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer


# FSDPv2 requires SPMD to be enabled.
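# SPMD lets the XLA compiler shard model parameters, optimizer state and data
# across all TPU devices automatically.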
xr.use_spmd()


def format_dolly(example, tokenizer):
    """Format Dolly dataset examples using the tokenizer's chat template."""
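    # Each databricks/databricks-dolly-15k record exposes "instruction", "context"
    # and "response" fields (the "category" field is not used here).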
    user_content = example["instruction"]
    if len(example["context"]) > 0:
        user_content += f"\n\nContext: {example['context']}"

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant",
        },
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": example["response"]},
    ]

    return tokenizer.apply_chat_template(messages, tokenize=False)


def train(model_id, dataset):
    # Load the model with low_cpu_mem_usage to avoid materializing the full model
    # in CPU memory; FSDPv2 will handle sharding across the TPU devices.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_cache=False,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map=None,  # Let FSDP handle device placement.
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.pad_token is None:
        if model.config.model_type == "llama":
            # Vanilla Llama 3 models ship a dedicated right-padding token.
            tokenizer.pad_token = "<|finetune_right_pad_id|>"
        elif tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            raise ValueError(f"Cannot get or guess pad token for model {model_id}.")

    if tokenizer.chat_template is None:
        # Set chat template for the Llama 3.1 format.
        tokenizer.chat_template = (
            "{% for message in messages %}"
            "{% if message['role'] == 'system' %}"
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% elif message['role'] == 'user' %}"
            "<|start_header_id|>user<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% elif message['role'] == 'assistant' %}"
            "<|start_header_id|>assistant<|end_header_id|>\n\n{{ message['content'] }}<|eot_id|>"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
            "{% endif %}"
        )

    # Try to guess the DecoderLayer class name, based on common model architectures.
    transformer_layer_cls_to_wrap = model.model.layers[0].__class__.__name__
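    # For Llama models this resolves to "LlamaDecoderLayer".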

    # FSDP training arguments passed through to the Trainer.
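    # "xla": True enables XLA FSDP, "xla_fsdp_v2": True selects the SPMD-based
    # FSDPv2 implementation, and "xla_fsdp_grad_ckpt": True applies gradient
    # checkpointing to each wrapped decoder layer to reduce memory usage.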
    fsdp_training_args = {
        "fsdp": "full_shard",
        "fsdp_config": {
            "transformer_layer_cls_to_wrap": [transformer_layer_cls_to_wrap],
            "xla": True,
            "xla_fsdp_v2": True,
            "xla_fsdp_grad_ckpt": True,
        },
    }

    # Set up PEFT LoRA for fine-tuning.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=128,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj"],
        task_type="CAUSAL_LM",
    )
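    # Targeting only q_proj and k_proj keeps the number of trainable parameters
    # small; v_proj, o_proj or the MLP projections could also be added for
    # potentially better quality at the cost of extra memory.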

    sft_config = SFTConfig(
        gradient_checkpointing=False,  # Not supported on TPU; xla_fsdp_grad_ckpt is used instead.
        max_length=1024,
        per_device_train_batch_size=4,
        num_train_epochs=3,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required for FSDPv2.
        dataset_text_field="text",
        packing=True,
        **fsdp_training_args,
    )

    # Set up the trainer.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=sft_config,
        peft_config=lora_config,
        processing_class=tokenizer,
        formatting_func=lambda example: format_dolly(example, tokenizer),
    )

    trainer.train()
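    # Optionally persist the trained LoRA adapter and tokenizer afterwards, e.g.:
    # trainer.save_model(sft_config.output_dir)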


# =============================================================================
# Main Function
# =============================================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a TPU fine-tuning script.")

    parser.add_argument(
        "--model_id", "-m", type=str, default="meta-llama/Llama-3.2-1B", help="Model id to use for training."
    )
    parser.add_argument(
        "--dataset_id",
        "-d",
        type=str,
        default="databricks/databricks-dolly-15k",
        help="Dataset id to use for training.",
    )

    args = parser.parse_args()

    # NOTE: this section can be adapted to load any dataset you want.
    dataset_id = args.dataset_id
    dolly_dataset = load_dataset(dataset_id, split="train")

    train(
        model_id=args.model_id,
        dataset=dolly_dataset,
    )