Add HPU SFT example #10
Conversation
Walkthrough
Adds two new HPU example scripts (SFT and OSFT) for single-node multi-GPU runs and defers importing …
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant Script as Example Script (sft/osft)
participant TH as training_hub.sft / training_hub.osft
participant PG as ProcessGroup / Rendezvous
participant HPU as HPU Workers
participant FS as Filesystem (checkpoints)
User->>Script: run with CLI args
Script->>Script: parse args, set PT_HPU_AUTOLOAD=0, print banner
Script->>TH: call sft/osft(config with HPU & rendezvous params)
TH->>PG: init single-node multi-GPU (nproc_per_node, rdzv)
PG->>HPU: spawn workers on HPU devices
TH->>TH: training loop, checkpointing, profiling
TH->>FS: write checkpoints
alt success
TH-->>Script: return / locate checkpoint path
Script-->>User: print duration and checkpoint location
else failure
TH-->>Script: raise exception
Script-->>User: print error, duration, suggest reducing max-tokens-per-gpu
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
This sample should be used with instructlab/training#660.
Actionable comments posted: 3
🧹 Nitpick comments (7)
examples/scripts/sft_hpu_example.py (7)
66-66: Remove unused variable assignment. result is not used; simplify to avoid linter F841 and keep intent clear.
Apply this diff:
-    result = sft(
+    sft(
36-49: Make /dev/shm usage optional and portable. Hard-coding /dev/shm breaks on non-Linux and triggers S108. Provide a CLI override and fall back only if available.
Apply this diff:
@@
     # Optional overrides
     parser.add_argument('--nproc-per-node', type=int, default=8,
                         help='Number of HPUs (processes) on this node')
+    parser.add_argument('--data-output-dir', default=None,
+                        help='Directory for processed/intermediate data (default: /dev/shm if available)')
@@
-        data_output_dir="/dev/shm",  # Use RAM disk for speed
+        data_output_dir=(
+            args.data_output_dir
+            if args.data_output_dir
+            else ("/dev/shm" if os.path.exists("/dev/shm") else None)
+        ),  # Prefer RAM disk when available
         warmup_steps=100,
         save_samples=0,  # 0 disables sample-based checkpointing, use epoch-based only

Also applies to: 80-83
50-51: Validate inputs and ensure output directory exists. Fail fast on bad paths; create the checkpoint dir if missing.
Apply this diff:
     args = parser.parse_args()
+    # Basic validation
+    if not os.path.exists(args.data_path):
+        sys.exit(f"Training data not found: {args.data_path}")
+    os.makedirs(args.ckpt_output_dir, exist_ok=True)
108-117: Avoid catching blind Exception; handle common cases explicitly. Handle KeyboardInterrupt gracefully, add OOM hints for RuntimeError/MemoryError, and otherwise re-raise.
Apply this diff:
-    except Exception as e:
-        end_time = time.time()
-        duration = end_time - start_time
-
-        print("=" * 50)
-        print(f"❌ Training failed after {duration/60:.1f} minutes")
-        print(f"Error: {e}")
-        print()
-        print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
-        sys.exit(1)
+    except KeyboardInterrupt:
+        end_time = time.time()
+        duration = end_time - start_time
+        print("=" * 50)
+        print(f"⏹️ Training canceled by user after {duration/60:.1f} minutes")
+        sys.exit(130)
+    except (RuntimeError, MemoryError) as e:
+        end_time = time.time()
+        duration = end_time - start_time
+        print("=" * 50)
+        print(f"❌ Training failed after {duration/60:.1f} minutes")
+        print(f"Error: {e}")
+        if "oom" in str(e).lower() or "out of memory" in str(e).lower():
+            print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
+        sys.exit(1)
47-48: Consider auto-detecting processes per node. Defaulting to 8 may oversubscribe most machines. Consider default=1 or autodetect from HABANA_VISIBLE_DEVICES to set a safer default.
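For illustration, a minimal sketch of such autodetection, assuming HABANA_VISIBLE_DEVICES holds a comma-separated list of device indices (the variable's exact semantics on a given Habana stack should still be verified):

```python
import argparse
import os

def detect_nproc_per_node(default: int = 1) -> int:
    """Best-effort count of local HPU workers from HABANA_VISIBLE_DEVICES."""
    visible = os.environ.get("HABANA_VISIBLE_DEVICES", "").strip()
    if not visible or visible.lower() == "all":
        return default  # fall back when unset or not an explicit list
    # Treat the value as a comma-separated list of device indices.
    return len([d for d in visible.split(",") if d.strip()])

parser = argparse.ArgumentParser()
parser.add_argument('--nproc-per-node', type=int,
                    default=detect_nproc_per_node(),
                    help='Number of HPUs (processes) on this node')
```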
1-1: Shebang present: make the file executable or drop it. Either mark the script executable (chmod +x) or remove the shebang to silence EXE001.
105-107: Avoid assuming hf_format/ in checkpoint path. examples/scripts/sft_hpu_example.py:105-107: repo search found no occurrences of "hf_format"; print the configured ckpt_output_dir instead of appending "/hf_format/".
Apply this diff:
- print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/") + print(f"📁 Checkpoints directory: {args.ckpt_output_dir}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/scripts/sft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/sft_hpu_example.py (1)
src/training_hub/algorithms/sft.py (1)
sft(177-249)
🪛 Ruff (0.13.1)
examples/scripts/sft_hpu_example.py
1-1: Shebang is present but file is not executable
(EXE001)
66-66: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
80-80: Probable insecure usage of temporary file or directory: "/dev/shm"
(S108)
108-108: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (2)
examples/scripts/sft_hpu_example.py (2)
52-64: General: solid example. Nice, concise banner and timing with a sensible default config for SFT on HPU.
Consider a quick smoke run on a small subset to validate CLI wiring.
95-98: Verify backend accepts disable_flash_attn and device kwargs. SFTAlgorithm.train forwards arbitrary kwargs to the backend (params.update(kwargs)) and InstructLabTrainingSFTBackend builds TrainingArgs(**training_params) without special handling for these keys. The repo has no evidence that TrainingArgs/TorchrunArgs accept disable_flash_attn or device; confirm the external instructlab.training API accepts them, or remove/translate these kwargs in the example or backend. Locations: examples/scripts/sft_hpu_example.py (flags added), src/training_hub/algorithms/sft.py (SFTAlgorithm.train and InstructLabTrainingSFTBackend.execute_training).
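As one hedged option (not the repository's current API), the backend could partition incoming kwargs against the dataclass fields it actually builds and fail loudly on leftovers; split_supported_kwargs and the usage comment below are illustrative names only:

```python
import dataclasses
from typing import Any, Dict, Tuple

def split_supported_kwargs(args_cls: type, params: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition params into keys that match args_cls dataclass fields and the leftovers."""
    field_names = {f.name for f in dataclasses.fields(args_cls)}
    supported = {k: v for k, v in params.items() if k in field_names}
    unsupported = {k: v for k, v in params.items() if k not in field_names}
    return supported, unsupported

# Hypothetical use inside the backend, before TrainingArgs(**training_params):
#   supported, dropped = split_supported_kwargs(TrainingArgs, training_params)
#   if dropped:
#       raise ValueError(f"Unsupported training parameters: {sorted(dropped)}")
```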
Signed-off-by: Sergey Plotnikov <[email protected]>
Signed-off-by: Sergey Plotnikov <[email protected]>
4b1088a to 84b62f1
Actionable comments posted: 2
♻️ Duplicate comments (5)
examples/scripts/sft_hpu_example.py (5)
5-5: Use HPU terminology instead of GPU for clarity. The docstring mentions "multi-GPU" but this is an HPU training example. This was flagged in a previous review.
7-11: Add --model-path to the example usage. The example is missing the required --model-path argument. This was flagged in a previous review and remains unresolved.
19-19: Set PT_HPU_AUTOLOAD before importing training_hub. The environment variable is set in main() (line 24) after importing training_hub. If the HPU backend reads this variable at import time, it's too late. This was flagged in a previous review.
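A minimal sketch of the suggested ordering, assuming the env var only needs to be visible before the import (whether the HPU backend actually reads PT_HPU_AUTOLOAD at import time still needs confirmation):

```python
import os

# Set before importing training_hub (and anything that pulls in the HPU stack),
# so that an import-time read of PT_HPU_AUTOLOAD sees the intended value.
os.environ.setdefault("PT_HPU_AUTOLOAD", "0")

from training_hub import sft  # noqa: E402  deliberately imported after the env var is set
```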
45-48: Use HPU terminology in argument help text. The help text for --max-tokens-per-gpu and --nproc-per-node mentions "GPU" instead of "HPU". This was flagged in a previous review.
62-63: Use HPU terminology in output messages. The print statements display "GPUs" and "GPU" but should say "HPUs" and "HPU" for consistency. This was flagged in a previous review.
🧹 Nitpick comments (2)
examples/scripts/sft_hpu_example.py (2)
92-92: Update comment to use HPU terminology. The comment mentions "multi-GPU" but should say "multi-HPU" for consistency with the script's purpose.
Apply this diff:
-    # Single-node multi-GPU setup
+    # Single-node multi-HPU setup
100-103: Maintain consistent spacing around equals signs. Lines 100-101 have spaces around = while lines 102-103 don't. Pick one style for consistency.
Apply this diff for consistent spacing:
-        disable_flash_attn = True,
-        device = 'hpu',
-        torch_compile = args.torch_compile,
-        num_chunks = args.num_chunks,
+        disable_flash_attn=True,
+        device='hpu',
+        torch_compile=args.torch_compile,
+        num_chunks=args.num_chunks,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/scripts/sft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/sft_hpu_example.py (1)
src/training_hub/algorithms/sft.py (1)
sft(169-248)
🪛 Ruff (0.14.1)
examples/scripts/sft_hpu_example.py
1-1: Shebang is present but file is not executable
(EXE001)
70-70: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
84-84: Probable insecure usage of temporary file or directory: "/dev/shm"
(S108)
114-114: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (1)
examples/scripts/sft_hpu_example.py (1)
112-112: No issues found. The hf_format/ subdirectory path is correct. Verification of the codebase confirms that checkpoints are consistently output to the hf_format/ subdirectory across all example scripts. Helper functions throughout the codebase explicitly construct checkpoint paths as {output_dir}/hf_format/samples_*, and docstrings document this structure. The print statement at line 112 is accurate.
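For readers following that layout, a small illustrative helper that resolves the newest checkpoint; the {output_dir}/hf_format/samples_* pattern is taken from the comment above, while the helper itself is only a sketch:

```python
import glob
import os
from typing import Optional

def latest_hf_checkpoint(ckpt_output_dir: str) -> Optional[str]:
    """Return the most recently modified {ckpt_output_dir}/hf_format/samples_* directory, if any."""
    candidates = glob.glob(os.path.join(ckpt_output_dir, "hf_format", "samples_*"))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
```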
Signed-off-by: Sergey Plotnikov <[email protected]>
Signed-off-by: Sergey Plotnikov <[email protected]>
I've added OSFT support. It should be used with Red-Hat-AI-Innovation-Team/mini_trainer#59.
Actionable comments posted: 4
🧹 Nitpick comments (3)
examples/scripts/osft_hpu_example.py (3)
1-2: Add executable permission and move module-level variable. The shebang indicates this script should be executable, but the file lacks the executable bit. Additionally, defining sample_name before the module docstring is unconventional in Python.
Apply these changes:
 #!/usr/bin/env python3
-sample_name="OSFT Training Example for HPU"
 """
 This script demonstrates how to do OSFT (Orthogonal Subspace Fine-Tuning) training
 on HPU using a single-node, multi-GPU setup with training_hub.

 OSFT allows continual training without catastrophic forgetting, making it ideal for:
 - Adapting models to specialized domains (medical, legal, technical)
 - Adding new knowledge without degrading general capabilities
 - Fine-tuning without complex replay mechanisms

 Example usage:
     python osft_hpu_example.py \\
       --data-path /path/to/data.jsonl \\
       --ckpt-output-dir /path/to/checkpoints \\
       --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
 """
+
+SAMPLE_NAME = "OSFT Training Example for HPU"

Also run:
chmod +x examples/scripts/osft_hpu_example.py
128-128: Make data output directory configurable. The data output directory is hardcoded to /dev/shm, which may not exist on all systems (e.g., Windows, certain container environments) and is flagged as a potential security concern by static analysis.
Consider making it configurable with a sensible default:
+    parser.add_argument('--data-output-dir', default="/dev/shm",
+                        help='Directory for processed data (default: /dev/shm)')

Then in the osft_params:
- 'data_output_dir': "/dev/shm", # Use RAM disk for speed + 'data_output_dir': args.data_output_dir, # Use RAM disk for speed
42-49: Use modification time instead of creation time for reliable checkpoint detection. The os.path.getctime() call may not reliably indicate the most recent checkpoint. On Unix systems, ctime tracks metadata changes rather than creation time, and files may have outdated timestamps if copied. Use os.path.getmtime() (modification time) instead, which is more reliable across platforms:
-    # Find the most recently created checkpoint
-    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)
+    # Find the most recently modified checkpoint
+    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getmtime)

The checkpoint pattern samples_*.0 is correct and intentional for OSFT implementations.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/scripts/osft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/osft_hpu_example.py (1)
src/training_hub/algorithms/osft.py (1)
osft(448-513)
🪛 Ruff (0.14.4)
examples/scripts/osft_hpu_example.py
1-1: Shebang is present but file is not executable
(EXE001)
46-46: Avoid specifying long messages outside the exception class
(TRY003)
128-128: Probable insecure usage of temporary file or directory: "/dev/shm"
(S108)
172-172: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (1)
examples/scripts/osft_hpu_example.py (1)
172-183: Good user-facing error handling. The broad exception catch is appropriate here for a top-level example script. It provides users with helpful troubleshooting guidance and a clean exit.
import os
import sys
import time
from datetime import datetime
import argparse
import glob


from training_hub import osft
🧩 Analysis chain
Verify datetime import usage.
The datetime module is imported but doesn't appear to be used in the script. Only time.time() is called for duration tracking.
Run this script to confirm the import is unused:
🏁 Script executed:
#!/bin/bash
# Check if datetime module is used beyond the import statement
rg -n '\bdatetime\.' examples/scripts/osft_hpu_example.py
Length of output: 0
Let me verify this more thoroughly by checking for any datetime usage patterns and reviewing the complete file:
---
🏁 Script executed:
#!/bin/bash
# Check for any datetime usage in the file (class, function calls, etc.)
rg -n 'datetime' examples/scripts/osft_hpu_example.py
Length of output: 32
🏁 Script executed:
#!/bin/bash
# Show the complete file to understand the context
cat examples/scripts/osft_hpu_example.py
Length of output: 7231
Remove unused datetime import.
The datetime module is imported on line 22 but is never used in the script. The timing operations use time.time() instead. Remove this unused import.
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 19 to 26, the datetime
module is imported but never used; remove the unused "from datetime import
datetime" import from the imports block so only required modules remain (e.g.,
keep os, sys, time, argparse, glob, and training_hub.osft) to eliminate the
unused import.
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')
Fix inconsistent terminology: GPUs → HPUs.
This is an HPU example, but the help text refers to "GPUs" instead of "HPUs".
Apply this diff:
- parser.add_argument('--nproc-per-node', type=int, default=8,
- help='Number of GPUs (default: 8)')
+ parser.add_argument('--nproc-per-node', type=int, default=8,
+ help='Number of HPUs per node (default: 8)')
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-    parser.add_argument('--nproc-per-node', type=int, default=8,
-                        help='Number of GPUs (default: 8)')
+    parser.add_argument('--nproc-per-node', type=int, default=8,
+                        help='Number of HPUs per node (default: 8)')
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 78 to 79, the argparse help
text incorrectly refers to "GPUs"; update the help string to say "Number of HPUs
(default: 8)" so the terminology matches this HPU example and zachronsistency;
modify the parser.add_argument call's help parameter accordingly.
    # sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
Fix terminology inconsistency in banner output.
Similar to the help text, the banner output refers to "GPUs" instead of "HPUs".
Apply this diff:
- # sample configuration
- print(f"🚀 {sample_name}")
+ # Sample configuration
+ print(f"🚀 {SAMPLE_NAME}") # Update if sample_name is renamed
print("=" * 50)
print(f"Model: {args.model_path}")
print(f"Data: {args.data_path}")
print(f"Output: {args.ckpt_output_dir}")
- print(f"GPUs: {args.nproc_per_node}")
+ print(f"HPUs: {args.nproc_per_node}")Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 91 to 97, the banner prints
"GPUs" but the script and help text refer to HPUs; update the label in the
f-string so it prints "HPUs" instead of "GPUs" (i.e., change the line that
prints f"GPUs: {args.nproc_per_node}" to use "HPUs" to keep terminology
consistent).
        # HPU specific arguments
        'device': 'hpu',
        'torch_compile': args.torch_compile,
        'num_chunks': args.num_chunks,
🧩 Analysis chain
Verify HPU-specific parameters are supported.
The script passes HPU-specific parameters (device='hpu', torch_compile, num_chunks) to the osft() function via **kwargs. Ensure these parameters are properly handled by the OSFT implementation.
Verify that the osft function accepts these parameters:
🏁 Script executed:
#!/bin/bash
# Check if the osft function signature or implementation handles device, torch_compile, and num_chunks
rg -n -A 20 'def osft\(' src/training_hub/algorithms/osft.py
rg -n 'device.*hpu|torch_compile|num_chunks' src/training_hub/
Length of output: 773
🏁 Script executed:
#!/bin/bash
# Get full osft function signature with more context
rg -n -A 50 'def osft\(' src/training_hub/algorithms/osft.py | head -60
Length of output: 1984
🏁 Script executed:
#!/bin/bash
# Check if osft accepts **kwargs
rg -n -B 2 -A 15 'def osft\(' src/training_hub/algorithms/osft.py | grep -E 'def osft|kwargs'
Length of output: 13
🏁 Script executed:
#!/bin/bash
# Search for any handling of device, torch_compile, num_chunks in the entire osft.py file
cat src/training_hub/algorithms/osft.py | head -100
Length of output: 4035
🏁 Script executed:
#!/bin/bash
# Look for kwargs unpacking or parameter forwarding in osft function
rg -n '\*\*' src/training_hub/algorithms/osft.py
Length of output: 556
🏁 Script executed:
#!/bin/bash
# Find where osft is defined and check parameter docs/annotations
fd -t f 'osft.py' src/training_hub/
Length of output: 35
🏁 Script executed:
#!/bin/bash
# Get the full osft function - with line numbers
wc -l src/training_hub/algorithms/osft.py
Length of output: 39
🏁 Script executed:
#!/bin/bash
# Extract full osft function definition
sed -n '448,550p' src/training_hub/algorithms/osft.py
Length of output: 2274
🏁 Script executed:
#!/bin/bash
# Search for how osft function might use or document these parameters
rg -n 'device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py -A 2 -B 2
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check examples directory for how osft is called
rg -n 'osft\(' examples/scripts/osft_hpu_example.py -A 10 -B 5
Length of output: 591
🏁 Script executed:
#!/bin/bash
# Look at the full context of the call in the example script around lines 148-151
sed -n '130,160p' examples/scripts/osft_hpu_example.py
Length of output: 996
🏁 Script executed:
#!/bin/bash
# Search for documentation or type hints that mention these parameters
rg -n 'device.*:.*str|torch_compile.*:|num_chunks.*:' src/training_hub/
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check if there's any kwargs handling in osft
rg -n 'kwargs' src/training_hub/algorithms/osft.py
Length of output: 675
🏁 Script executed:
#!/bin/bash
# Look for where osft function might reference these parameters implicitly
rg -n "kwargs\['device'\]|kwargs\['torch_compile'\]|kwargs\['num_chunks'\]" src/training_hub/Length of output: 0
🏁 Script executed:
#!/bin/bash
# Get the line count and then extract the full osft function
sed -n '448,500p' src/training_hub/algorithms/osft.py
Length of output: 1850
🏁 Script executed:
#!/bin/bash
# Check for the full function signature by looking for def and closing parenthesis
python3 << 'EOF'
import re
try:
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        content = f.read()
    # Find the osft function definition
    match = re.search(r'def osft\([^)]*\):', content, re.DOTALL)
    if match:
        print("=== osft function signature ===")
        print(match.group(0))
    else:
        print("Function signature pattern not found")
        # Try simpler pattern
        match = re.search(r'def osft\([\s\S]{0,1000}\):', content)
        if match:
            print("=== osft function (extended search) ===")
            print(match.group(0))
except Exception as e:
    print(f"Error: {e}")
EOF
Length of output: 36
🏁 Script executed:
#!/bin/bash
# Check what parameters the osft function actually accepts
python3 << 'EOF'
import inspect
import sys
sys.path.insert(0, '.')
try:
    # Try to find and parse the file
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        lines = f.readlines()
    # Find the osft function starting at line 448
    func_start = 447  # 0-indexed
    paren_count = 0
    found_start = False
    func_sig_lines = []
    for i in range(func_start, len(lines)):
        line = lines[i]
        func_sig_lines.append(line.rstrip())
        if 'def osft(' in line:
            found_start = True
            paren_count = line.count('(') - line.count(')')
        elif found_start:
            paren_count += line.count('(') - line.count(')')
        if found_start and ':' in line and paren_count == 0:
            break
    print("=== Full osft function signature ===")
    for line in func_sig_lines:
        print(line)
except Exception as e:
    print(f"Error: {e}")
EOF
Length of output: 1157
🏁 Script executed:
#!/bin/bash
# Check if device, torch_compile, num_chunks are passed in the example call
grep -n -B 5 -A 15 "osft(" examples/scripts/osft_hpu_example.py
Length of output: 921
🏁 Script executed:
#!/bin/bash
# Look for kwargs parameter or **kwargs in osft function
rg 'def osft' src/training_hub/algorithms/osft.py -A 40 | grep -E 'kwargs|\*\*'
Length of output: 67
🏁 Script executed:
#!/bin/bash
# Final check - get osft signature using ast-grep
ast-grep --pattern $'def osft($_) {
$$$
}'
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Simple extraction - first 100 lines after osft function start
awk '/^def osft\(/{flag=1; count=0} flag{print; count++; if(count>60) exit}' src/training_hub/algorithms/osft.py
Length of output: 2148
🏁 Script executed:
#!/bin/bash
# Check for backend parameter which might handle device routing
rg -n 'backend' src/training_hub/algorithms/osft.py | head -20
Length of output: 861
🏁 Script executed:
#!/bin/bash
# Look for any forwarding to other functions that might accept these params
rg -n 'device|torch_compile|num_chunks' src/training_hub/ --type py -B 3 -A 3
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Get osft function body to see what it does with parameters
sed -n '448,600p' src/training_hub/algorithms/osft.py | grep -E '^\s*(device|torch_compile|num_chunks|def|return)'
Length of output: 37
🏁 Script executed:
#!/bin/bash
# Try getting the exact lines of the function definition in one go
python3 << 'EOF'
with open('src/training_hub/algorithms/osft.py', 'r') as f:
    lines = f.readlines()
in_def = False
paren_depth = 0
for i in range(447, min(len(lines), 550)):
    line = lines[i]
    if 'def osft(' in line:
        in_def = True
    if in_def:
        print(f"{i+1}: {line}", end='')
        paren_depth += line.count('(') - line.count(')')
        if ':' in line and paren_depth == 0:
            break
EOF
Length of output: 1285
🏁 Script executed:
#!/bin/bash
# Also check where osft is being called with the parameters
python3 << 'EOF'
import re
with open('examples/scripts/osft_hpu_example.py', 'r') as f:
    content = f.read()
# Find the osft call
match = re.search(r'osft\([^)]*\)', content, re.DOTALL)
if match:
    print("=== osft function call in example ===")
    print(match.group(0)[:500])
EOF
Length of output: 57
🏁 Script executed:
#!/bin/bash
# Check the actual call site around line 148-151
sed -n '140,160p' examples/scripts/osft_hpu_example.py
Length of output: 614
🏁 Script executed:
#!/bin/bash
# Look for any wrapper or intermediate function that might accept **kwargs
rg -n 'def.*osft|osft(' examples/scripts/osft_hpu_example.py -B 3 -A 10
Length of output: 80
🏁 Script executed:
#!/bin/bash
# Find MiniTrainerOSFTBackend.execute_training implementation
rg -n 'def execute_training' src/training_hub/algorithms/osft.py -A 30
Length of output: 1395
🏁 Script executed:
#!/bin/bash
# Find _validate_param_types to see what parameters are accepted
rg -n '_validate_param_types' src/training_hub/algorithms/osft.py -A 50
Length of output: 4304
🏁 Script executed:
#!/bin/bash
# Check optional_param_types to see if device, torch_compile, num_chunks are listed
sed -n '220,260p' src/training_hub/algorithms/osft.py
Length of output: 1740
🏁 Script executed:
#!/bin/bash
# Search for how parameters are used in backend execution
rg -n 'device|torch_compile|num_chunks' src/training_hub/ -B 2 -A 2
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check if instructlab training handles these parameters
find src/training_hub -name '*.py' -type f | xargs grep -l 'device.*hpu\|torch_compile'
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look for parameter validation and what happens to unknown kwargs
sed -n '190,210p' src/training_hub/algorithms/osft.py
Length of output: 822
🏁 Script executed:
#!/bin/bash
# Check what the backend.execute_training actually does
sed -n '370,400p' src/training_hub/algorithms/osft.py
Length of output: 1473
🏁 Script executed:
#!/bin/bash
# Examine the full MiniTrainerOSFTBackend class
sed -n '316,445p' src/training_hub/algorithms/osft.py
Length of output: 5768
🏁 Script executed:
#!/bin/bash
# Check if these params are stripped out or cause errors
rg -n 'Unknown|unsupported|device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look at parameter validation logic
sed -n '217,280p' src/training_hub/algorithms/osft.py
Length of output: 2659
🏁 Script executed:
#!/bin/bash
# Check what happens to unrecognized kwargs
sed -n '190,215p' src/training_hub/algorithms/osft.py
Length of output: 975
🏁 Script executed:
#!/bin/bash
# Examine the full MiniTrainerOSFTBackend class and execute_training
sed -n '316,445p' src/training_hub/algorithms/osft.py
Length of output: 5768
🏁 Script executed:
#!/bin/bash
# Check if parameters flow through to instructlab training or are handled locally
rg -n 'instructlab' src/training_hub/algorithms/osft.py -B 2 -A 5
Length of output: 733
🏁 Script executed:
#!/bin/bash
# Look for parameter mapping or transformation
sed -n '335,370p' src/training_hub/algorithms/osft.py
Length of output: 2028
HPU-specific parameters are accepted but silently discarded during backend execution.
The osft() function accepts device, torch_compile, and num_chunks via **kwargs and passes them through the validation layer. However, in MiniTrainerOSFTBackend.execute_training() (lines 346-369), parameters are filtered into TorchrunArgs and TrainingArgs dataclass fields. Since these HPU parameters are not fields in either dataclass, they are silently dropped before training begins.
Either add these parameters to get_optional_params() in OSFTAlgorithm.train() with proper type hints, or remove them from the example script if they're not supported.
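To make the failure mode concrete, here is a toy reproduction of how filtering kwargs against dataclass fields silently drops anything the dataclass does not declare; FakeTrainingArgs is a stand-in, not the real TrainingArgs:

```python
import dataclasses

@dataclasses.dataclass
class FakeTrainingArgs:  # stand-in for the real TrainingArgs
    model_path: str
    max_tokens_per_gpu: int

requested = {
    "model_path": "my-model",
    "max_tokens_per_gpu": 4096,
    "device": "hpu",            # HPU-specific extras the dataclass doesn't declare
    "torch_compile": True,
    "num_chunks": 4,
}

field_names = {f.name for f in dataclasses.fields(FakeTrainingArgs)}
kept = {k: v for k, v in requested.items() if k in field_names}
dropped = {k: v for k, v in requested.items() if k not in field_names}

print(kept)     # {'model_path': 'my-model', 'max_tokens_per_gpu': 4096}
print(dropped)  # {'device': 'hpu', 'torch_compile': True, 'num_chunks': 4} silently discarded
```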
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 148 to 151, HPU-specific
kwargs ('device', 'torch_compile', 'num_chunks') are accepted but dropped before
backend execution; add these parameters to OSFTAlgorithm.train()'s
get_optional_params() with proper type hints (device: str, torch_compile: bool,
num_chunks: int) and ensure they are either mapped into the existing
TorchrunArgs/TrainingArgs dataclasses or the
MiniTrainerOSFTBackend.execute_training() is updated to accept and forward them
to the backend execution path; alternatively, if HPU support isn't intended,
remove those keys from the example script so they aren't silently ignored.
This PR adds a new sample based on https://github.com/Red-Hat-AI-Innovation-Team/training_hub/blob/main/examples/scripts/sft_llama_example.py that demonstrates how to run SFT on HPU. The sample allows configuring more parameters through the command line (e.g., batch size) and also adds some HPU-specific parameters (e.g., the number of chunks to split the dataset into during training). This PR should be used together with instructlab/training#660.
Summary by CodeRabbit
New Features
Bug Fixes