Commit 4b1088a

Add HPU SFT example
Signed-off-by: Sergey Plotnikov <[email protected]>
1 parent 8164824 commit 4b1088a

1 file changed: 121 additions & 0 deletions
#!/usr/bin/env python3
"""
This script demonstrates how to run SFT training on HPU
using a single-node, multi-device setup with training_hub.

Example usage:
    python sft_hpu_example.py \\
        --model-path /path/to/model \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints
"""

import argparse
import os
import sys
import time

from training_hub import sft

sample_name = "SFT Training Example for HPU"


def main():
    # Disable HPU backend autoloading
    os.environ['PT_HPU_AUTOLOAD'] = '0'

    parser = argparse.ArgumentParser(description=sample_name)

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('--model-path', required=True,
                        help='Model path or HuggingFace name')

    # Optional overrides
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--effective-batch-size', type=int, default=128,
                        help='Effective batch size (default: 128)')
    parser.add_argument('--max-seq-len', type=int, default=16384,
                        help='Maximum sequence length (default: 16384)')
    parser.add_argument('--checkpoint-at-epoch', action='store_true',
                        help='Store a checkpoint after each epoch')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
                        help='Max tokens per device (default: 32768)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of HPU devices (default: 8)')

    args = parser.parse_args()

    # Print the sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"Devices: {args.nproc_per_node}")
    print(f"Max tokens per device: {args.max_tokens_per_gpu:,}")
    print()

    # Training configuration optimized for Llama 3.1 8B Instruct
    start_time = time.time()

    try:
        result = sft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # Training parameters
            num_epochs=args.num_epochs,
            effective_batch_size=args.effective_batch_size,
            learning_rate=1e-5,  # Lower LR for an instruct model
            max_seq_len=args.max_seq_len,
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir="/dev/shm",  # Use RAM disk for speed
            warmup_steps=100,
            save_samples=0,  # 0 disables sample-based checkpointing; epoch-based only

            # Checkpointing
            checkpoint_at_epoch=args.checkpoint_at_epoch,
            accelerate_full_state_at_epoch=False,  # Smaller checkpoints, but no auto-resumption

            # Single-node multi-device setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            node_rank=0,
            rdzv_id=101,
            rdzv_endpoint="127.0.0.1:29500",

            # HPU-specific arguments
            disable_flash_attn=True,
            device='hpu',
        )

        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print("✅ Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Try reducing --max-tokens-per-gpu if you hit out-of-memory errors")
        sys.exit(1)


if __name__ == "__main__":
    main()
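
Note: --data-path expects a JSONL file. Below is a minimal sketch of writing one training record in Python. The chat-style "messages" schema is an assumption about training_hub's expected input format, not something this commit confirms; check the training_hub documentation for the exact schema.

# Hypothetical example: write a single chat-style record to data.jsonl.
# The "messages" layout is assumed, not confirmed by this commit.
import json

record = {
    "messages": [
        {"role": "user", "content": "What is supervised fine-tuning?"},
        {"role": "assistant", "content": "Training a pretrained model on labeled prompt/response pairs."},
    ]
}

with open("data.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")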

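On success, the script points at checkpoints under {ckpt_output_dir}/hf_format/, which suggests HuggingFace-format output. A minimal loading sketch with transformers, assuming a checkpoint subdirectory exists there (the path below is hypothetical):

# Load a finished checkpoint for inference. The subdirectory name is
# hypothetical; list <ckpt-output-dir>/hf_format/ to find the real one.
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/path/to/checkpoints/hf_format/final"  # hypothetical path
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)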