188 changes: 188 additions & 0 deletions examples/scripts/osft_hpu_example.py
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
sample_name="OSFT Training Example for HPU"
"""
This script demonstrates how to do OSFT (Orthogonal Subspace Fine-Tuning) training on HPU
using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting models to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
python osft_hpu_example.py \\
--data-path /path/to/data.jsonl \\
--ckpt-output-dir /path/to/checkpoints \\
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct
"""

import os
import sys
import time
from datetime import datetime
import argparse
import glob

from training_hub import osft
Comment on lines +19 to +26

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify datetime import usage.

The datetime module is imported but doesn't appear to be used in the script. Only time.time() is called for duration tracking.

Run this script to confirm the import is unused:


🏁 Script executed:

#!/bin/bash
# Check if datetime module is used beyond the import statement
rg -n '\bdatetime\.' examples/scripts/osft_hpu_example.py

Length of output: 0


Let me verify this more thoroughly by checking for any datetime usage patterns and reviewing the complete file:


---

🏁 Script executed:

#!/bin/bash
# Check for any datetime usage in the file (class, function calls, etc.)
rg -n 'datetime' examples/scripts/osft_hpu_example.py

Length of output: 32


🏁 Script executed:

#!/bin/bash
# Show the complete file to understand the context
cat examples/scripts/osft_hpu_example.py

Length of output: 7231


Remove unused datetime import.

The datetime module is imported on line 22 but is never used in the script. The timing operations use time.time() instead. Remove this unused import.
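
After the removal, the import block would look like this (a minimal sketch listing only the modules the script actually uses):

```python
import os
import sys
import time
import argparse
import glob

from training_hub import osft
```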

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 19 to 26, the datetime
module is imported but never used; remove the unused "from datetime import
datetime" import from the imports block so only required modules remain (e.g.,
keep os, sys, time, argparse, glob, and training_hub.osft) to eliminate the
unused import.


def find_most_recent_checkpoint(output_dir):
"""
Find the most recent checkpoint in the training output directory.

Args:
output_dir (str): Training output directory containing hf_format/ subdirectory

Returns:
str: Path to the most recent checkpoint

Raises:
ValueError: If no checkpoints are found
"""
# Get all checkpoint directories under hf_format
checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
checkpoint_dirs = glob.glob(checkpoint_pattern)

if not checkpoint_dirs:
raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

# Find the most recently created checkpoint
most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

return most_recent_checkpoint


def main():
parser = argparse.ArgumentParser(description=sample_name)

# Required parameters
parser.add_argument('--data-path', required=True,
help='Path to training data (JSONL format)')
parser.add_argument('--ckpt-output-dir', required=True,
help='Directory to save checkpoints')
parser.add_argument('--model-path', required=True,
help='Model path or HuggingFace name')

# Optional overrides
parser.add_argument('--num-epochs', type=int, default=3,
help='Number of epochs (default: 3)')
parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.3,
help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.3)')
parser.add_argument('--effective-batch-size', type=int, default=128,
help='effective batch size')
parser.add_argument('--max-seq-len', type=int, default=4096,
help='Maximum sequence length (default: 4096)')
parser.add_argument('--checkpoint-at-epoch', action='store_true', default=False,
help='Store checkpoint after each epoch')
parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
help='Max tokens per GPU (default: 32768)')
parser.add_argument('--nproc-per-node', type=int, default=8,
help='Number of GPUs (default: 8)')
Comment on lines +78 to +79

⚠️ Potential issue | 🟡 Minor

Fix inconsistent terminology: GPUs → HPUs.

This is an HPU example, but the help text refers to "GPUs" instead of "HPUs".

Apply this diff:

-    parser.add_argument('--nproc-per-node', type=int, default=8,
-                       help='Number of GPUs (default: 8)')
+    parser.add_argument('--nproc-per-node', type=int, default=8,
+                       help='Number of HPUs per node (default: 8)')
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 78 to 79, the argparse help
text incorrectly refers to "GPUs"; update the help string to say "Number of HPUs
(default: 8)" so the terminology matches this HPU example and zachronsistency;
modify the parser.add_argument call's help parameter accordingly.

parser.add_argument('--unmask-messages', action='store_true', default=False,
help='Unmask messages during training (default: False)')
parser.add_argument('--learning-rate', type=float, default=5e-6,
help='Learning rate for training (default: 5e-6)')
parser.add_argument('--torch-compile', action='store_true', default=False,
help='Enable torch.compile, hpu only')
parser.add_argument('--num-chunks', type=int, default=1,
help='Number of chunks to split dataset into for sequential training')

args = parser.parse_args()

# sample configuration
print(f"🚀 {sample_name}")
print("=" * 50)
print(f"Model: {args.model_path}")
print(f"Data: {args.data_path}")
print(f"Output: {args.ckpt_output_dir}")
print(f"GPUs: {args.nproc_per_node}")
Comment on lines +91 to +97

⚠️ Potential issue | 🟡 Minor

Fix terminology inconsistency in banner output.

Similar to the help text, the banner output refers to "GPUs" instead of "HPUs".

Apply this diff:

-    # sample configuration
-    print(f"🚀 {sample_name}")
+    # Sample configuration
+    print(f"🚀 {SAMPLE_NAME}")  # Update if sample_name is renamed
     print("=" * 50)
     print(f"Model: {args.model_path}")
     print(f"Data: {args.data_path}")
     print(f"Output: {args.ckpt_output_dir}")
-    print(f"GPUs: {args.nproc_per_node}")
+    print(f"HPUs: {args.nproc_per_node}")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 91 to 97, the banner prints
"GPUs" but the script and help text refer to HPUs; update the label in the
f-string so it prints "HPUs" instead of "GPUs" (i.e., change the line that
prints f"GPUs: {args.nproc_per_node}" to use "HPUs" to keep terminology
consistent).

print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
print()
print("📝 OSFT Benefits:")
print(" • Preserve model's strong general capabilities")
print(" • Add domain-specific knowledge efficiently")
print(" • No need for complex data mixing or replay buffers")
print()

# Training configuration
start_time = time.time()

try:
osft_params = {
# Model and data
'model_path': args.model_path,
'data_path': args.data_path,
'ckpt_output_dir': args.ckpt_output_dir,

# OSFT-specific parameters
'unfreeze_rank_ratio': args.unfreeze_rank_ratio,

# Training parameters
'num_epochs': args.num_epochs,
'effective_batch_size': args.effective_batch_size,
'learning_rate': args.learning_rate,
'max_seq_len': args.max_seq_len,
'max_tokens_per_gpu': args.max_tokens_per_gpu,

# Data processing
'data_output_dir': "/dev/shm", # Use RAM disk for speed
'warmup_steps': 0,
'unmask_messages': args.unmask_messages,

# Optimization
'use_liger': False,
'seed': 42,
'lr_scheduler': 'cosine', # Cosine scheduler works well with OSFT

# Checkpointing
'checkpoint_at_epoch': args.checkpoint_at_epoch,
'save_final_checkpoint': True,

# Single-node multi-GPU setup
'nproc_per_node': args.nproc_per_node,
'nnodes': 1,
'node_rank': 0,
'rdzv_id': 103,
'rdzv_endpoint': "127.0.0.1:29500",

# HPU specific arguments
'device': 'hpu',
'torch_compile': args.torch_compile,
'num_chunks': args.num_chunks,
Comment on lines +148 to +151

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify HPU-specific parameters are supported.

The script passes HPU-specific parameters (device='hpu', torch_compile, num_chunks) to the osft() function via **kwargs. Ensure these parameters are properly handled by the OSFT implementation.

Verify that the osft function accepts these parameters:


🏁 Script executed:

#!/bin/bash
# Check if the osft function signature or implementation handles device, torch_compile, and num_chunks
rg -n -A 20 'def osft\(' src/training_hub/algorithms/osft.py
rg -n 'device.*hpu|torch_compile|num_chunks' src/training_hub/

Length of output: 773



🏁 Script executed:

#!/bin/bash
# Get full osft function signature with more context
rg -n -A 50 'def osft\(' src/training_hub/algorithms/osft.py | head -60

Length of output: 1984


🏁 Script executed:

#!/bin/bash
# Check if osft accepts **kwargs
rg -n -B 2 -A 15 'def osft\(' src/training_hub/algorithms/osft.py | grep -E 'def osft|kwargs'

Length of output: 13


🏁 Script executed:

#!/bin/bash
# Search for any handling of device, torch_compile, num_chunks in the entire osft.py file
cat src/training_hub/algorithms/osft.py | head -100

Length of output: 4035


🏁 Script executed:

#!/bin/bash
# Look for kwargs unpacking or parameter forwarding in osft function
rg -n '\*\*' src/training_hub/algorithms/osft.py

Length of output: 556


🏁 Script executed:

#!/bin/bash
# Find where osft is defined and check parameter docs/annotations
fd -t f 'osft.py' src/training_hub/

Length of output: 35


🏁 Script executed:

#!/bin/bash
# Get the full osft function - with line numbers
wc -l src/training_hub/algorithms/osft.py

Length of output: 39


🏁 Script executed:

#!/bin/bash
# Extract full osft function definition
sed -n '448,550p' src/training_hub/algorithms/osft.py

Length of output: 2274


🏁 Script executed:

#!/bin/bash
# Search for how osft function might use or document these parameters
rg -n 'device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py -A 2 -B 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check examples directory for how osft is called
rg -n 'osft\(' examples/scripts/osft_hpu_example.py -A 10 -B 5

Length of output: 591


🏁 Script executed:

#!/bin/bash
# Look at the full context of the call in the example script around lines 148-151
sed -n '130,160p' examples/scripts/osft_hpu_example.py

Length of output: 996


🏁 Script executed:

#!/bin/bash
# Search for documentation or type hints that mention these parameters
rg -n 'device.*:.*str|torch_compile.*:|num_chunks.*:' src/training_hub/

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if there's any kwargs handling in osft
rg -n 'kwargs' src/training_hub/algorithms/osft.py

Length of output: 675


🏁 Script executed:

#!/bin/bash
# Look for where osft function might reference these parameters implicitly
rg -n "kwargs\['device'\]|kwargs\['torch_compile'\]|kwargs\['num_chunks'\]" src/training_hub/

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Get the line count and then extract the full osft function
sed -n '448,500p' src/training_hub/algorithms/osft.py

Length of output: 1850


🏁 Script executed:

#!/bin/bash
# Check for the full function signature by looking for def and closing parenthesis
python3 << 'EOF'
import re

try:
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        content = f.read()
    
    # Find the osft function definition
    match = re.search(r'def osft\([^)]*\):', content, re.DOTALL)
    if match:
        print("=== osft function signature ===")
        print(match.group(0))
    else:
        print("Function signature pattern not found")
        # Try simpler pattern
        match = re.search(r'def osft\([\s\S]{0,1000}\):', content)
        if match:
            print("=== osft function (extended search) ===")
            print(match.group(0))
except Exception as e:
    print(f"Error: {e}")
EOF

Length of output: 36


🏁 Script executed:

#!/bin/bash
# Check what parameters the osft function actually accepts
python3 << 'EOF'
import inspect
import sys
sys.path.insert(0, '.')

try:
    # Try to find and parse the file
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        lines = f.readlines()
    
    # Find the osft function starting at line 448
    func_start = 447  # 0-indexed
    paren_count = 0
    found_start = False
    func_sig_lines = []
    
    for i in range(func_start, len(lines)):
        line = lines[i]
        func_sig_lines.append(line.rstrip())
        
        if 'def osft(' in line:
            found_start = True
            paren_count = line.count('(') - line.count(')')
        elif found_start:
            paren_count += line.count('(') - line.count(')')
        
        if found_start and ':' in line and paren_count == 0:
            break
    
    print("=== Full osft function signature ===")
    for line in func_sig_lines:
        print(line)
        
except Exception as e:
    print(f"Error: {e}")
EOF

Length of output: 1157


🏁 Script executed:

#!/bin/bash
# Check if device, torch_compile, num_chunks are passed in the example call
grep -n -B 5 -A 15 "osft(" examples/scripts/osft_hpu_example.py

Length of output: 921


🏁 Script executed:

#!/bin/bash
# Look for kwargs parameter or **kwargs in osft function
rg 'def osft' src/training_hub/algorithms/osft.py -A 40 | grep -E 'kwargs|\*\*'

Length of output: 67


🏁 Script executed:

#!/bin/bash
# Final check - get osft signature using ast-grep
ast-grep --pattern $'def osft($_) {
  $$$
}'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Simple extraction - first 100 lines after osft function start
awk '/^def osft\(/{flag=1; count=0} flag{print; count++; if(count>60) exit}' src/training_hub/algorithms/osft.py

Length of output: 2148


🏁 Script executed:

#!/bin/bash
# Check for backend parameter which might handle device routing
rg -n 'backend' src/training_hub/algorithms/osft.py | head -20

Length of output: 861


🏁 Script executed:

#!/bin/bash
# Look for any forwarding to other functions that might accept these params
rg -n 'device|torch_compile|num_chunks' src/training_hub/ --type py -B 3 -A 3

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Get osft function body to see what it does with parameters
sed -n '448,600p' src/training_hub/algorithms/osft.py | grep -E '^\s*(device|torch_compile|num_chunks|def|return)'

Length of output: 37


🏁 Script executed:

#!/bin/bash
# Try getting the exact lines of the function definition in one go
python3 << 'EOF'
with open('src/training_hub/algorithms/osft.py', 'r') as f:
    lines = f.readlines()

in_def = False
paren_depth = 0
for i in range(447, min(len(lines), 550)):
    line = lines[i]
    if 'def osft(' in line:
        in_def = True
    
    if in_def:
        print(f"{i+1}: {line}", end='')
        paren_depth += line.count('(') - line.count(')')
        if ':' in line and paren_depth == 0:
            break
EOF

Length of output: 1285


🏁 Script executed:

#!/bin/bash
# Also check where osft is being called with the parameters
python3 << 'EOF'
import re
with open('examples/scripts/osft_hpu_example.py', 'r') as f:
    content = f.read()

# Find the osft call
match = re.search(r'osft\([^)]*\)', content, re.DOTALL)
if match:
    print("=== osft function call in example ===")
    print(match.group(0)[:500])
EOF

Length of output: 57


🏁 Script executed:

#!/bin/bash
# Check the actual call site around line 148-151
sed -n '140,160p' examples/scripts/osft_hpu_example.py

Length of output: 614


🏁 Script executed:

#!/bin/bash
# Look for any wrapper or intermediate function that might accept **kwargs
rg -n 'def.*osft|osft(' examples/scripts/osft_hpu_example.py -B 3 -A 10

Length of output: 80



🏁 Script executed:

#!/bin/bash
# Find MiniTrainerOSFTBackend.execute_training implementation
rg -n 'def execute_training' src/training_hub/algorithms/osft.py -A 30

Length of output: 1395


🏁 Script executed:

#!/bin/bash
# Find _validate_param_types to see what parameters are accepted
rg -n '_validate_param_types' src/training_hub/algorithms/osft.py -A 50

Length of output: 4304


🏁 Script executed:

#!/bin/bash
# Check optional_param_types to see if device, torch_compile, num_chunks are listed
sed -n '220,260p' src/training_hub/algorithms/osft.py

Length of output: 1740


🏁 Script executed:

#!/bin/bash
# Search for how parameters are used in backend execution
rg -n 'device|torch_compile|num_chunks' src/training_hub/ -B 2 -A 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if instructlab training handles these parameters
find src/training_hub -name '*.py' -type f | xargs grep -l 'device.*hpu\|torch_compile'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for parameter validation and what happens to unknown kwargs
sed -n '190,210p' src/training_hub/algorithms/osft.py

Length of output: 822


🏁 Script executed:

#!/bin/bash
# Check what the backend.execute_training actually does
sed -n '370,400p' src/training_hub/algorithms/osft.py

Length of output: 1473


🏁 Script executed:

#!/bin/bash
# Examine the full MiniTrainerOSFTBackend class
sed -n '316,445p' src/training_hub/algorithms/osft.py

Length of output: 5768


🏁 Script executed:

#!/bin/bash
# Check if these params are stripped out or cause errors
rg -n 'Unknown|unsupported|device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look at parameter validation logic
sed -n '217,280p' src/training_hub/algorithms/osft.py

Length of output: 2659


🏁 Script executed:

#!/bin/bash
# Check what happens to unrecognized kwargs
sed -n '190,215p' src/training_hub/algorithms/osft.py

Length of output: 975


🏁 Script executed:

#!/bin/bash
# Examine the full MiniTrainerOSFTBackend class and execute_training
sed -n '316,445p' src/training_hub/algorithms/osft.py

Length of output: 5768


🏁 Script executed:

#!/bin/bash
# Check if parameters flow through to instructlab training or are handled locally
rg -n 'instructlab' src/training_hub/algorithms/osft.py -B 2 -A 5

Length of output: 733


🏁 Script executed:

#!/bin/bash
# Look for parameter mapping or transformation
sed -n '335,370p' src/training_hub/algorithms/osft.py

Length of output: 2028


HPU-specific parameters are accepted but silently discarded during backend execution.

The osft() function accepts device, torch_compile, and num_chunks via **kwargs and passes them through the validation layer. However, in MiniTrainerOSFTBackend.execute_training() (lines 346-369), parameters are filtered into TorchrunArgs and TrainingArgs dataclass fields. Since these HPU parameters are not fields in either dataclass, they are silently dropped before training begins.
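
To illustrate the failure mode (a generic sketch, not the actual code in osft.py): filtering kwargs down to declared dataclass fields silently discards anything the dataclass does not know about.

```python
from dataclasses import dataclass, fields

@dataclass
class TrainingArgsSketch:  # hypothetical stand-in for the real TrainingArgs
    learning_rate: float = 5e-6
    num_epochs: int = 3

params = {
    'learning_rate': 1e-5,
    'num_epochs': 2,
    'device': 'hpu',         # HPU-specific keys from the example script
    'torch_compile': True,
    'num_chunks': 4,
}

# Only keys matching declared fields survive the filter; the HPU keys vanish
# without any warning or error.
allowed = {f.name for f in fields(TrainingArgsSketch)}
training_args = TrainingArgsSketch(**{k: v for k, v in params.items() if k in allowed})
print(training_args)  # TrainingArgsSketch(learning_rate=1e-05, num_epochs=2)
```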

Either add these parameters to get_optional_params() in OSFTAlgorithm.train() with proper type hints, or remove them from the example script if they're not supported.
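
A rough sketch of the first option, assuming get_optional_params() returns a mapping of parameter names to expected types (the exact structure in osft.py may differ, and the values still need to be forwarded to the backend):

```python
def get_optional_params(self) -> dict[str, type]:
    return {
        # ... existing optional parameters ...
        'device': str,          # e.g. 'hpu'
        'torch_compile': bool,  # enable torch.compile (HPU-only in the example)
        'num_chunks': int,      # split dataset into chunks for sequential training
    }
```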

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 148 to 151, HPU-specific
kwargs ('device', 'torch_compile', 'num_chunks') are accepted but dropped before
backend execution; add these parameters to OSFTAlgorithm.train()'s
get_optional_params() with proper type hints (device: str, torch_compile: bool,
num_chunks: int) and ensure they are either mapped into the existing
TorchrunArgs/TrainingArgs dataclasses or the
MiniTrainerOSFTBackend.execute_training() is updated to accept and forward them
to the backend execution path; alternatively, if HPU support isn't intended,
remove those keys from the example script so they aren't silently ignored.

}


osft(**osft_params)

end_time = time.time()
duration = end_time - start_time

most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

print("=" * 50)
print("✅ OSFT Training completed successfully!")
print(f"⏱️ Duration: {duration/3600:.2f} hours")
print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
print(f" Most recent checkpoint: {most_recent_checkpoint}")
print()
print("🎯 Your model has been successfully adapted!")
print(" The model now incorporates your domain-specific knowledge")
print(" while maintaining its original instruction-following abilities.")

except Exception as e:
end_time = time.time()
duration = end_time - start_time

print("=" * 50)
print(f"❌ Training failed after {duration/60:.1f} minutes")
print(f"Error: {e}")
print()
print("💡 Troubleshooting tips:")
print(" - Reduce --max-tokens-per-gpu if you see OOM errors")
print(" - For domain adaptation, try --unfreeze-rank-ratio between 0.2-0.4")
sys.exit(1)


if __name__ == "__main__":
main()

127 changes: 127 additions & 0 deletions examples/scripts/sft_hpu_example.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
sample_name="SFT Training Example for HPU"
"""
This script demonstrates how to do SFT training on HPU
using a single-node, multi-GPU setup with training_hub.

Example usage:
python sft_hpu_example.py \\
--data-path /path/to/data.jsonl \\
--ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
from datetime import datetime
import argparse

from training_hub import sft


def main():
#disable HPU backend autoloading
os.environ['PT_HPU_AUTOLOAD'] = '0'

parser = argparse.ArgumentParser(description=sample_name)

# Required parameters
parser.add_argument('--data-path', required=True,
help='Path to training data (JSONL format)')
parser.add_argument('--ckpt-output-dir', required=True,
help='Directory to save checkpoints')
parser.add_argument('--model-path', required=True,
help='Model path or HuggingFace name')

# Optional overrides
parser.add_argument('--num-epochs', type=int, default=3,
help='Number of epochs (default: 3)')
parser.add_argument('--effective-batch-size', type=int, default=128,
help='effective batch size')
parser.add_argument('--max-seq-len', type=int, default=16384,
help='Maximum sequence length')
parser.add_argument('--checkpoint-at-epoch', action='store_true',
help='Store checkpoint after each epoch')
parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
help='Max tokens per GPU')
parser.add_argument('--nproc-per-node', type=int, default=8,
help='Number of GPUs')
parser.add_argument('--torch-compile', action='store_true', default=False,
help='Enable torch.compile, hpu only')
parser.add_argument('--num-chunks', type=int, default=1,
help='Number of chunks to split dataset into for sequential training')

args = parser.parse_args()

# sample configuration
print(f"🚀 {sample_name}")
print("=" * 50)
print(f"Model: {args.model_path}")
print(f"Data: {args.data_path}")
print(f"Output: {args.ckpt_output_dir}")
print(f"GPUs: {args.nproc_per_node}")
print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
print()

# Training configuration optimized for Llama 3.1 8B Instruct
start_time = time.time()

try:
result = sft(
# Model and data
model_path=args.model_path,
data_path=args.data_path,
ckpt_output_dir=args.ckpt_output_dir,

# Training parameters
num_epochs=args.num_epochs,
effective_batch_size=args.effective_batch_size,
learning_rate=1e-5, # Lower LR for instruct model
max_seq_len=args.max_seq_len,
max_tokens_per_gpu=args.max_tokens_per_gpu,

# Data processing
data_output_dir="/dev/shm", # Use RAM disk for speed
warmup_steps=100,
save_samples=0, # 0 disables sample-based checkpointing, use epoch-based only

# Checkpointing
checkpoint_at_epoch=args.checkpoint_at_epoch,
accelerate_full_state_at_epoch=False, # Disable for smaller checkpoints (no auto-resumption)

# Single-node multi-GPU setup
nproc_per_node=args.nproc_per_node,
nnodes=1,
node_rank=0,
rdzv_id=101,
rdzv_endpoint="127.0.0.1:29500",

# HPU specific arguments
disable_flash_attn = True,
device = 'hpu',
torch_compile = args.torch_compile,
num_chunks = args.num_chunks,
)

end_time = time.time()
duration = end_time - start_time

print("=" * 50)
print("✅ Training completed successfully!")
print(f"⏱️ Duration: {duration/3600:.2f} hours")
print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")

except Exception as e:
end_time = time.time()
duration = end_time - start_time

print("=" * 50)
print(f"❌ Training failed after {duration/60:.1f} minutes")
print(f"Error: {e}")
print()
print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
sys.exit(1)


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion src/training_hub/profiling/memory_estimator.py
@@ -1,6 +1,5 @@
from typing import override
from transformers import AutoModel
from mini_trainer.osft_utils import MODEL_CONFIGS
import numpy as np

"""
@@ -322,6 +321,9 @@ def __init__(
raise ValueError("Ratio must be in the range [0, 1]")

# Check to see which terms need to be included in the search for valid layers
# MODEL_CONFIGS requires transformers>4.56.0, which conflicts with HPU enablement;
# temporarily moving the MODEL_CONFIGS import here to enable HPU support
from mini_trainer.osft_utils import MODEL_CONFIGS
self.target_terms = MODEL_CONFIGS['default']['patterns']
for key in MODEL_CONFIGS.keys():
if self.model_path.find(key) > -1: