Add HPU SFT example #10
base: main
examples/scripts/osft_hpu_example.py (new file, 188 lines)
#!/usr/bin/env python3
sample_name="OSFT Training Example for HPU"
"""
This script demonstrates how to do OSFT (Orthogonal Subspace Fine-Tuning) training on HPU
using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting models to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_hpu_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints \\
        --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
"""

import os
import sys
import time
from datetime import datetime
import argparse
import glob

from training_hub import osft


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description=sample_name)

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('--model-path', required=True,
                        help='Model path or HuggingFace name')

    # Optional overrides
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.3,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.3)')
    parser.add_argument('--effective-batch-size', type=int, default=128,
                        help='effective batch size')
    parser.add_argument('--max-seq-len', type=int, default=4096,
                        help='Maximum sequence length (default: 4096)')
    parser.add_argument('--checkpoint-at-epoch', action='store_true', default=False,
                        help='Store checkpoint after each epoch')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
                        help='Max tokens per GPU (default: 32768)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')
Comment on lines +78 to +79:

Fix inconsistent terminology: GPUs → HPUs. This is an HPU example, but the help text refers to "GPUs" instead of "HPUs".

Apply this diff:

-    parser.add_argument('--nproc-per-node', type=int, default=8,
-                        help='Number of GPUs (default: 8)')
+    parser.add_argument('--nproc-per-node', type=int, default=8,
+                        help='Number of HPUs per node (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=5e-6,
                        help='Learning rate for training (default: 5e-6)')
    parser.add_argument('--torch-compile', action='store_true', default=False,
                        help='Enable torch.compile, hpu only')
    parser.add_argument('--num-chunks', type=int, default=1,
                        help='Number of chunks to split dataset into for sequential training')

    args = parser.parse_args()

    # sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
Comment on lines +91 to +97:

Fix terminology inconsistency in banner output. As with the help text, the banner output refers to "GPUs" instead of "HPUs".

Apply this diff:

-    # sample configuration
-    print(f"🚀 {sample_name}")
+    # Sample configuration
+    print(f"🚀 {SAMPLE_NAME}")  # Update if sample_name is renamed
     print("=" * 50)
     print(f"Model: {args.model_path}")
     print(f"Data: {args.data_path}")
     print(f"Output: {args.ckpt_output_dir}")
-    print(f"GPUs: {args.nproc_per_node}")
+    print(f"HPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits:")
    print("   • Preserve model's strong general capabilities")
    print("   • Add domain-specific knowledge efficiently")
    print("   • No need for complex data mixing or replay buffers")
    print()

    # Training configuration
    start_time = time.time()

    try:
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,

            # Training parameters
            'num_epochs': args.num_epochs,
            'effective_batch_size': args.effective_batch_size,
            'learning_rate': args.learning_rate,
            'max_seq_len': args.max_seq_len,
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': "/dev/shm",  # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': False,
            'seed': 42,
            'lr_scheduler': 'cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': args.checkpoint_at_epoch,
            'save_final_checkpoint': True,

            # Single-node multi-GPU setup
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 103,
            'rdzv_endpoint': "127.0.0.1:29500",

            # HPU specific arguments
            'device': 'hpu',
            'torch_compile': args.torch_compile,
            'num_chunks': args.num_chunks,
Comment on lines +148 to +151:

Verify HPU-specific parameters are supported. The script passes the HPU-specific arguments device, torch_compile, and num_chunks to osft(). Verification scripts inspecting src/training_hub/algorithms/osft.py (the osft() signature, _validate_param_types / optional_param_types, and MiniTrainerOSFTBackend.execute_training) found no handling of these names: they are accepted as keyword arguments but silently discarded during backend execution. Either add these parameters to the set the backend actually validates and forwards (e.g., its optional parameter types), or adjust the example so it does not imply they take effect.
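If the first option is taken, a minimal sketch of the change might look like the following. It assumes optional_param_types in src/training_hub/algorithms/osft.py is a flat dict mapping accepted keyword names to expected types; the name appears in the verification above, but its exact shape is an assumption, not confirmed source.

# Hypothetical sketch only: extend the assumed optional_param_types mapping in
# src/training_hub/algorithms/osft.py so the HPU arguments are validated and
# forwarded instead of silently dropped. Names and layout are assumptions.
optional_param_types = {
    # ... existing entries ...
    'device': str,          # e.g. 'hpu' to select the Habana backend
    'torch_compile': bool,  # enable torch.compile (HPU-only flag in this example)
    'num_chunks': int,      # split the dataset into chunks for sequential training
}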
        }

        osft(**osft_params)

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original instruction-following abilities.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2-0.4")
        sys.exit(1)


if __name__ == "__main__":
    main()
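Both example scripts take --data-path pointing to a JSONL file. The exact schema is defined by training_hub rather than by this PR, so the following is only an illustrative sketch of a messages-style layout (one JSON object per line); verify it against your installed training_hub version.

# Illustrative sketch of a messages-style JSONL sample; the field names are an
# assumption about training_hub's expected format, not taken from this PR.
import json

sample = {
    "messages": [
        {"role": "user", "content": "Summarize the patient intake notes."},
        {"role": "assistant", "content": "The patient reports ..."},
    ]
}

with open("data.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(sample) + "\n")  # one JSON object per line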
sft_hpu_example.py (new file, 127 lines)
#!/usr/bin/env python3

sample_name="SFT Training Example for HPU"
"""
This script demonstrates how to do SFT training on HPU
using a single-node, multi-GPU setup with training_hub.

Example usage:
    python sft_hpu_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
from datetime import datetime
import argparse

from training_hub import sft


def main():
    # disable HPU backend autoloading
    os.environ['PT_HPU_AUTOLOAD'] = '0'

    parser = argparse.ArgumentParser(description=sample_name)

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('--model-path', required=True,
                        help='Model path or HuggingFace name')

    # Optional overrides
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--effective-batch-size', type=int, default=128,
                        help='effective batch size')
    parser.add_argument('--max-seq-len', type=int, default=16384,
                        help='Maximum sequence length')
    parser.add_argument('--checkpoint-at-epoch', action='store_true',
                        help='Store checkpoint after each epoch')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
                        help='Max tokens per GPU')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs')
    parser.add_argument('--torch-compile', action='store_true', default=False,
                        help='Enable torch.compile, hpu only')
    parser.add_argument('--num-chunks', type=int, default=1,
                        help='Number of chunks to split dataset into for sequential training')

    args = parser.parse_args()

    # sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()

    # Training configuration optimized for Llama 3.1 8B Instruct
    start_time = time.time()

    try:
        result = sft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # Training parameters
            num_epochs=args.num_epochs,
            effective_batch_size=args.effective_batch_size,
            learning_rate=1e-5,  # Lower LR for instruct model
            max_seq_len=args.max_seq_len,
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir="/dev/shm",  # Use RAM disk for speed
            warmup_steps=100,
            save_samples=0,  # 0 disables sample-based checkpointing, use epoch-based only

            # Checkpointing
            checkpoint_at_epoch=args.checkpoint_at_epoch,
            accelerate_full_state_at_epoch=False,  # Disable for smaller checkpoints (no auto-resumption)

            # Single-node multi-GPU setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            node_rank=0,
            rdzv_id=101,
            rdzv_endpoint="127.0.0.1:29500",

            # HPU specific arguments
            disable_flash_attn=True,
            device='hpu',
            torch_compile=args.torch_compile,
            num_chunks=args.num_chunks,
        )

        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print("✅ Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
        sys.exit(1)


if __name__ == "__main__":
    main()
Review comment:

Remove unused datetime import. The datetime module is imported on line 22 but is never used in the script; only time.time() is called for duration tracking. Remove this unused import.
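If that suggestion is applied, the import block of sft_hpu_example.py would reduce to something like the following (a sketch of the suggested cleanup, not a change that is part of this PR):

# sft_hpu_example.py imports after dropping the unused datetime import;
# everything else in the script stays as-is.
import os
import sys
import time
import argparse

from training_hub import sft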