
Conversation


@splotnikv splotnikv commented Sep 23, 2025

This PR adds a new sample, based on https://github.com/Red-Hat-AI-Innovation-Team/training_hub/blob/main/examples/scripts/sft_llama_example.py, that demonstrates how to run SFT on HPU. The sample allows more parameters to be configured through the command line (e.g. batch size) and adds some HPU-specific parameters (e.g. the number of chunks to split the dataset into during training). This PR should be used together with instructlab/training#660.
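For illustration, the extra command-line surface described here can be sketched with argparse; the flag names below are assumptions for the sketch, and the authoritative ones are defined in examples/scripts/sft_hpu_example.py:

```python
import argparse

# Hypothetical flags mirroring the PR description; the real names live
# in examples/scripts/sft_hpu_example.py.
parser = argparse.ArgumentParser(description="SFT-on-HPU example (sketch)")
parser.add_argument("--batch-size", type=int, default=8,
                    help="Per-device training batch size")
parser.add_argument("--num-chunks", type=int, default=1,
                    help="HPU-specific: chunks to split the dataset into")

# Parse a sample command line instead of sys.argv for the demo.
args = parser.parse_args(["--batch-size", "16", "--num-chunks", "4"])
print(args.batch_size, args.num_chunks)  # → 16 4
```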

Summary by CodeRabbit

  • New Features

    • Added example scripts demonstrating supervised fine-tuning (SFT) and orthogonal subspace fine-tuning (OSFT) on HPU devices in a single-node, multi-GPU setup, with CLI options, run summaries, checkpointing, and runtime diagnostics.
  • Bug Fixes

    • Improved HPU environment compatibility by deferring certain model-configuration imports to avoid initialization conflicts.


coderabbitai bot commented Sep 23, 2025

Walkthrough

Adds two new HPU example scripts (SFT and OSFT) for single-node multi-GPU runs and defers importing MODEL_CONFIGS into OSFTEstimatorExperimental.__init__ to avoid global import issues on HPU environments.

Changes

Cohort / File(s) | Summary

  • New HPU SFT example script (examples/scripts/sft_hpu_example.py):
    Adds an executable example that parses CLI args, sets PT_HPU_AUTOLOAD=0, prints a run banner, and builds and calls training_hub.sft with model/data, epochs, batch/max-tokens, checkpointing, single-node multi-GPU rendezvous options, HPU/device flags, and optional torch_compile/num_chunks; it measures duration and reports success, or an error with OOM guidance.
  • New HPU OSFT example script (examples/scripts/osft_hpu_example.py):
    Adds an executable OSFT example that parses CLI args, prints run info, assembles osft params (OSFT-specific settings, training hyperparameters, RAM-disk/data options, single-node multi-GPU rendezvous, HPU flags, optional torch_compile/num_chunks), invokes osft(...), finds the most recent checkpoint on success, measures duration, and emits errors/tips on failure.
  • Deferred MODEL_CONFIGS import in estimator (src/training_hub/profiling/memory_estimator.py):
    Moves the MODEL_CONFIGS import from module scope into OSFTEstimatorExperimental.__init__ and populates self.target_terms after the import, deferring the heavy/global import to avoid HPU-related import conflicts while preserving the existing model-pattern selection logic.
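The deferred-import pattern described in this change can be sketched as follows; LazyEstimator and its hard-coded config dict are illustrative stand-ins, not the real OSFTEstimatorExperimental or MODEL_CONFIGS:

```python
class LazyEstimator:
    """Sketch of moving a module-scope import into __init__.

    In the real change, MODEL_CONFIGS is imported inside
    OSFTEstimatorExperimental.__init__; a local dict stands in for it here.
    """

    def __init__(self, model_pattern: str):
        # Deferred import: nothing heavy runs until an instance is created,
        # so merely importing this module stays safe on HPU environments.
        import json  # stand-in for the deferred MODEL_CONFIGS import

        model_configs = {"llama": {"target_terms": ["q_proj", "v_proj"]}}
        self.target_terms = model_configs.get(model_pattern, {}).get("target_terms", [])


est = LazyEstimator("llama")
print(est.target_terms)  # → ['q_proj', 'v_proj']
```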

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Script as Example Script (sft/osft)
  participant TH as training_hub.sft / training_hub.osft
  participant PG as ProcessGroup / Rendezvous
  participant HPU as HPU Workers
  participant FS as Filesystem (checkpoints)

  User->>Script: run with CLI args
  Script->>Script: parse args, set PT_HPU_AUTOLOAD=0, print banner
  Script->>TH: call sft/osft(config with HPU & rendezvous params)
  TH->>PG: init single-node multi-GPU (nproc_per_node, rdzv)
  PG->>HPU: spawn workers on HPU devices
  TH->>TH: training loop, checkpointing, profiling
  TH->>FS: write checkpoints
  alt success
    TH-->>Script: return / locate checkpoint path
    Script-->>User: print duration and checkpoint location
  else failure
    TH-->>Script: raise exception
    Script-->>User: print error, duration, suggest reducing max-tokens-per-gpu
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Pay attention to: OSFTEstimatorExperimental.__init__ change (import deferral and correctness of self.target_terms population), CLI argument handling and defaults in both example scripts, and checkpoint discovery logic in osft_hpu_example.py.

Poem

I twitch my whiskers at HPU light,
I queue the tokens, train through night,
One node, many devices hum—
Checkpoints saved, the work is done.
If memory nibbles at my clue,
I trim the tokens, hop anew. 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check ❓ Inconclusive: The PR title 'Add HPU SFT example' is overly vague and does not accurately reflect the full scope of changes, which include both an SFT example and an OSFT example, plus a modification to memory_estimator.py. Consider a more descriptive title such as 'Add HPU SFT and OSFT examples with memory estimator update'.
✅ Passed checks (1 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.


@splotnikv (Author) commented:

This sample should be used with instructlab/training#660

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (7)
examples/scripts/sft_hpu_example.py (7)

66-66: Remove unused variable assignment

result is not used; simplify to avoid linter F841 and keep intent clear.

Apply this diff:

-        result = sft(
+        sft(

36-49: Make /dev/shm usage optional and portable

Hard-coding /dev/shm breaks on non-Linux and triggers S108. Provide a CLI override and fall back only if available.

Apply this diff:

@@
     # Optional overrides
@@
     parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of HPUs (processes) on this node')
+    parser.add_argument('--data-output-dir', default=None,
+                       help='Directory for processed/intermediate data (default: /dev/shm if available)')
@@
-            data_output_dir="/dev/shm",        # Use RAM disk for speed
+            data_output_dir=(
+                args.data_output_dir
+                if args.data_output_dir
+                else ("/dev/shm" if os.path.exists("/dev/shm") else None)
+            ),                                  # Prefer RAM disk when available
             warmup_steps=100,
             save_samples=0,                    # 0 disables sample-based checkpointing, use epoch-based only

Also applies to: 80-83


50-51: Validate inputs and ensure output directory exists

Fail fast on bad paths; create checkpoint dir if missing.

Apply this diff:

     args = parser.parse_args()
 
+    # Basic validation
+    if not os.path.exists(args.data_path):
+        sys.exit(f"Training data not found: {args.data_path}")
+    os.makedirs(args.ckpt_output_dir, exist_ok=True)

108-117: Avoid catching blind Exception; handle common cases explicitly

Handle KeyboardInterrupt gracefully and OOM hints for RuntimeError/MemoryError; otherwise re-raise.

Apply this diff:

-    except Exception as e:
-        end_time = time.time()
-        duration = end_time - start_time
-        
-        print("=" * 50)
-        print(f"❌ Training failed after {duration/60:.1f} minutes")
-        print(f"Error: {e}")
-        print()
-        print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
-        sys.exit(1)
+    except KeyboardInterrupt:
+        end_time = time.time()
+        duration = end_time - start_time
+        print("=" * 50)
+        print(f"⏹️  Training canceled by user after {duration/60:.1f} minutes")
+        sys.exit(130)
+    except (RuntimeError, MemoryError) as e:
+        end_time = time.time()
+        duration = end_time - start_time
+        print("=" * 50)
+        print(f"❌ Training failed after {duration/60:.1f} minutes")
+        print(f"Error: {e}")
+        if "oom" in str(e).lower() or "out of memory" in str(e).lower():
+            print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
+        sys.exit(1)

47-48: Consider auto-detecting processes per node

Defaulting to 8 may oversubscribe most machines. Consider default=1 or autodetect from HABANA_VISIBLE_DEVICES to set a safer default.
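A minimal sketch of the suggested autodetection, assuming HABANA_VISIBLE_DEVICES holds a comma-separated device list (analogous to CUDA_VISIBLE_DEVICES):

```python
import os


def default_nproc_per_node() -> int:
    """Derive a safer nproc default from HABANA_VISIBLE_DEVICES.

    Assumption: when set, the variable is a comma-separated device list.
    Falls back to 1 (conservative) when it is unset or empty.
    """
    visible = os.environ.get("HABANA_VISIBLE_DEVICES", "")
    devices = [d for d in visible.split(",") if d.strip()]
    return len(devices) if devices else 1


os.environ["HABANA_VISIBLE_DEVICES"] = "0,1,2,3"
print(default_nproc_per_node())  # → 4
```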


1-1: Shebang present — make file executable or drop it

Either mark the script executable (chmod +x) or remove the shebang to silence EXE001.


105-107: Avoid assuming hf_format/ in checkpoint path

examples/scripts/sft_hpu_example.py:105-107 — repo search found no occurrences of "hf_format"; print the configured ckpt_output_dir instead of appending "/hf_format/".

Apply this diff:

-        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")
+        print(f"📁 Checkpoints directory: {args.ckpt_output_dir}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8164824 and 4b1088a.

📒 Files selected for processing (1)
  • examples/scripts/sft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/sft_hpu_example.py (1)
src/training_hub/algorithms/sft.py (1)
  • sft (177-249)
🪛 Ruff (0.13.1)
examples/scripts/sft_hpu_example.py

1-1: Shebang is present but file is not executable

(EXE001)


66-66: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


80-80: Probable insecure usage of temporary file or directory: "/dev/shm"

(S108)


108-108: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (2)
examples/scripts/sft_hpu_example.py (2)

52-64: General: solid example

Nice, concise banner and timing with a sensible default config for SFT on HPU.

Consider a quick smoke run on a small subset to validate CLI wiring.


95-98: Verify backend accepts disable_flash_attn and device kwargs

SFTAlgorithm.train forwards arbitrary kwargs to the backend (params.update(kwargs)), and InstructLabTrainingSFTBackend builds TrainingArgs(**training_params) without special handling for these keys. The repo shows no evidence that TrainingArgs/TorchrunArgs accept disable_flash_attn or device; confirm that the external instructlab.training API accepts them, or remove/translate these kwargs in the example or backend. Locations: examples/scripts/sft_hpu_example.py (flags added), src/training_hub/algorithms/sft.py (SFTAlgorithm.train and InstructLabTrainingSFTBackend.execute_training).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (5)
examples/scripts/sft_hpu_example.py (5)

5-5: Use HPU terminology instead of GPU for clarity.

The docstring mentions "multi-GPU" but this is an HPU training example. This was flagged in a previous review.


7-11: Add --model-path to the example usage.

The example is missing the required --model-path argument. This was flagged in a previous review and remains unresolved.


19-19: Set PT_HPU_AUTOLOAD before importing training_hub.

The environment variable is set in main() (line 24) after importing training_hub. If the HPU backend reads this variable at import time, it's too late. This was flagged in a previous review.


45-48: Use HPU terminology in argument help text.

The help text for --max-tokens-per-gpu and --nproc-per-node mentions "GPU" instead of "HPU". This was flagged in a previous review.


62-63: Use HPU terminology in output messages.

The print statements display "GPUs" and "GPU" but should say "HPUs" and "HPU" for consistency. This was flagged in a previous review.

🧹 Nitpick comments (2)
examples/scripts/sft_hpu_example.py (2)

92-92: Update comment to use HPU terminology.

The comment mentions "multi-GPU" but should say "multi-HPU" for consistency with the script's purpose.

Apply this diff:

-            # Single-node multi-GPU setup
+            # Single-node multi-HPU setup

100-103: Maintain consistent spacing around equals signs.

Lines 100-101 have spaces around = while lines 102-103 don't. Pick one style for consistency.

Apply this diff for consistent spacing:

-            disable_flash_attn = True,
-            device = 'hpu',
-            torch_compile = args.torch_compile,
-            num_chunks = args.num_chunks,
+            disable_flash_attn=True,
+            device='hpu',
+            torch_compile=args.torch_compile,
+            num_chunks=args.num_chunks,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b1088a and 84b62f1.

📒 Files selected for processing (1)
  • examples/scripts/sft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/sft_hpu_example.py (1)
src/training_hub/algorithms/sft.py (1)
  • sft (169-248)
🪛 Ruff (0.14.1)
examples/scripts/sft_hpu_example.py

1-1: Shebang is present but file is not executable

(EXE001)


70-70: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


84-84: Probable insecure usage of temporary file or directory: "/dev/shm"

(S108)


114-114: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (1)
examples/scripts/sft_hpu_example.py (1)

112-112: No issues found. The hf_format/ subdirectory path is correct.

Verification of the codebase confirms that checkpoints are consistently output to the hf_format/ subdirectory across all example scripts. Helper functions throughout the codebase explicitly construct checkpoint paths as {output_dir}/hf_format/samples_*, and docstrings document this structure. The print statement at line 112 is accurate.

@splotnikv (Author) commented:

I've added OSFT support. It should be used with Red-Hat-AI-Innovation-Team/mini_trainer#59.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
examples/scripts/osft_hpu_example.py (3)

1-2: Add executable permission and move module-level variable.

The shebang indicates this script should be executable, but the file lacks the executable bit. Additionally, placing the sample_name assignment before the module docstring means Python no longer treats that string as the module docstring.

Apply these changes:

 #!/usr/bin/env python3
-sample_name="OSFT Training Example for HPU"
 """
 This script demonstrates how to do OSFT (Orthogonal Subspace Fine-Tuning) training on HPU
 using a single-node, multi-GPU setup with training_hub.
 
 OSFT allows continual training without catastrophic forgetting, making it ideal for:
 - Adapting models to specialized domains (medical, legal, technical)
 - Adding new knowledge without degrading general capabilities
 - Fine-tuning without complex replay mechanisms
 
 Example usage:
     python osft_hpu_example.py \\
         --data-path /path/to/data.jsonl \\
         --ckpt-output-dir /path/to/checkpoints \\
         --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
 """
+
+SAMPLE_NAME = "OSFT Training Example for HPU"

Also run: chmod +x examples/scripts/osft_hpu_example.py


128-128: Make data output directory configurable.

The data output directory is hardcoded to /dev/shm, which may not exist on all systems (e.g., Windows, certain container environments) and is flagged as a potential security concern by static analysis.

Consider making it configurable with a sensible default:

+    parser.add_argument('--data-output-dir', default="/dev/shm",
+                       help='Directory for processed data (default: /dev/shm)')

Then in the osft_params:

-            'data_output_dir': "/dev/shm",      # Use RAM disk for speed
+            'data_output_dir': args.data_output_dir,      # Use RAM disk for speed

42-49: Use modification time instead of creation time for reliable checkpoint detection.

The os.path.getctime() call may not reliably indicate the most recent checkpoint. On Unix systems, ctime tracks metadata changes rather than creation time, and files may have outdated timestamps if copied. Use os.path.getmtime() (modification time) instead, which is more reliable across platforms:

-    # Find the most recently created checkpoint
-    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)
+    # Find the most recently modified checkpoint
+    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getmtime)

The checkpoint pattern samples_*.0 is correct and intentional for OSFT implementations.
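A self-contained sketch of mtime-based discovery over the samples_*.0 pattern (directory names below are made up for the demo):

```python
import glob
import os
import tempfile
import time

# Create two fake checkpoint directories with strictly ordered mtimes.
root = tempfile.mkdtemp()
for offset, name in enumerate(["samples_100.0", "samples_200.0"]):
    path = os.path.join(root, name)
    os.makedirs(path)
    stamp = time.time() - 100 + offset  # later entry gets the newer mtime
    os.utime(path, (stamp, stamp))

# Most-recent checkpoint by modification time, as the review suggests.
checkpoint_dirs = glob.glob(os.path.join(root, "samples_*.0"))
most_recent = max(checkpoint_dirs, key=os.path.getmtime)
print(os.path.basename(most_recent))  # → samples_200.0
```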

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a477b3b and 33ed975.

📒 Files selected for processing (1)
  • examples/scripts/osft_hpu_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/osft_hpu_example.py (1)
src/training_hub/algorithms/osft.py (1)
  • osft (448-513)
🪛 Ruff (0.14.4)
examples/scripts/osft_hpu_example.py

1-1: Shebang is present but file is not executable

(EXE001)


46-46: Avoid specifying long messages outside the exception class

(TRY003)


128-128: Probable insecure usage of temporary file or directory: "/dev/shm"

(S108)


172-172: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (1)
examples/scripts/osft_hpu_example.py (1)

172-183: Good user-facing error handling.

The broad exception catch is appropriate here for a top-level example script. It provides users with helpful troubleshooting guidance and a clean exit.

Comment on lines +19 to +26
import os
import sys
import time
from datetime import datetime
import argparse
import glob

from training_hub import osft

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify datetime import usage.

The datetime module is imported but doesn't appear to be used in the script. Only time.time() is called for duration tracking.

Run this script to confirm the import is unused:


🏁 Script executed:

#!/bin/bash
# Check if datetime module is used beyond the import statement
rg -n '\bdatetime\.' examples/scripts/osft_hpu_example.py

Length of output: 0


Let me verify this more thoroughly by checking for any datetime usage patterns and reviewing the complete file:


---

🏁 Script executed:

#!/bin/bash
# Check for any datetime usage in the file (class, function calls, etc.)
rg -n 'datetime' examples/scripts/osft_hpu_example.py

Length of output: 32


🏁 Script executed:

#!/bin/bash
# Show the complete file to understand the context
cat examples/scripts/osft_hpu_example.py

Length of output: 7231


Remove unused datetime import.

The datetime module is imported on line 22 but is never used in the script. The timing operations use time.time() instead. Remove this unused import.

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 19 to 26, the datetime
module is imported but never used; remove the unused "from datetime import
datetime" import from the imports block so only required modules remain (e.g.,
keep os, sys, time, argparse, glob, and training_hub.osft) to eliminate the
unused import.

Comment on lines +78 to +79
parser.add_argument('--nproc-per-node', type=int, default=8,
help='Number of GPUs (default: 8)')

⚠️ Potential issue | 🟡 Minor

Fix inconsistent terminology: GPUs → HPUs.

This is an HPU example, but the help text refers to "GPUs" instead of "HPUs".

Apply this diff:

-    parser.add_argument('--nproc-per-node', type=int, default=8,
-                       help='Number of GPUs (default: 8)')
+    parser.add_argument('--nproc-per-node', type=int, default=8,
+                       help='Number of HPUs per node (default: 8)')
🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 78 to 79, the argparse help
text incorrectly refers to "GPUs"; update the help string to say "Number of HPUs
(default: 8)" so the terminology matches this HPU example and stays consistent;
modify the parser.add_argument call's help parameter accordingly.

Comment on lines +91 to +97
# sample configuration
print(f"🚀 {sample_name}")
print("=" * 50)
print(f"Model: {args.model_path}")
print(f"Data: {args.data_path}")
print(f"Output: {args.ckpt_output_dir}")
print(f"GPUs: {args.nproc_per_node}")

⚠️ Potential issue | 🟡 Minor

Fix terminology inconsistency in banner output.

Similar to the help text, the banner output refers to "GPUs" instead of "HPUs".

Apply this diff:

-    # sample configuration
-    print(f"🚀 {sample_name}")
+    # Sample configuration
+    print(f"🚀 {SAMPLE_NAME}")  # Update if sample_name is renamed
     print("=" * 50)
     print(f"Model: {args.model_path}")
     print(f"Data: {args.data_path}")
     print(f"Output: {args.ckpt_output_dir}")
-    print(f"GPUs: {args.nproc_per_node}")
+    print(f"HPUs: {args.nproc_per_node}")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 91 to 97, the banner prints
"GPUs" but the script and help text refer to HPUs; update the label in the
f-string so it prints "HPUs" instead of "GPUs" (i.e., change the line that
prints f"GPUs: {args.nproc_per_node}" to use "HPUs" to keep terminology
consistent).

Comment on lines +148 to +151
# HPU specific arguments
'device': 'hpu',
'torch_compile': args.torch_compile,
'num_chunks': args.num_chunks,

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify HPU-specific parameters are supported.

The script passes HPU-specific parameters (device='hpu', torch_compile, num_chunks) to the osft() function via **kwargs. Ensure these parameters are properly handled by the OSFT implementation.

Verify that the osft function accepts these parameters:


🏁 Script executed:

#!/bin/bash
# Check if the osft function signature or implementation handles device, torch_compile, and num_chunks
rg -n -A 20 'def osft\(' src/training_hub/algorithms/osft.py
rg -n 'device.*hpu|torch_compile|num_chunks' src/training_hub/

Length of output: 773



🏁 Script executed:

#!/bin/bash
# Get full osft function signature with more context
rg -n -A 50 'def osft\(' src/training_hub/algorithms/osft.py | head -60

Length of output: 1984


🏁 Script executed:

#!/bin/bash
# Check if osft accepts **kwargs
rg -n -B 2 -A 15 'def osft\(' src/training_hub/algorithms/osft.py | grep -E 'def osft|kwargs'

Length of output: 13


🏁 Script executed:

#!/bin/bash
# Search for any handling of device, torch_compile, num_chunks in the entire osft.py file
cat src/training_hub/algorithms/osft.py | head -100

Length of output: 4035


🏁 Script executed:

#!/bin/bash
# Look for kwargs unpacking or parameter forwarding in osft function
rg -n '\*\*' src/training_hub/algorithms/osft.py

Length of output: 556


🏁 Script executed:

#!/bin/bash
# Find where osft is defined and check parameter docs/annotations
fd -t f 'osft.py' src/training_hub/

Length of output: 35


🏁 Script executed:

#!/bin/bash
# Get the full osft function - with line numbers
wc -l src/training_hub/algorithms/osft.py

Length of output: 39


🏁 Script executed:

#!/bin/bash
# Extract full osft function definition
sed -n '448,550p' src/training_hub/algorithms/osft.py

Length of output: 2274


🏁 Script executed:

#!/bin/bash
# Search for how osft function might use or document these parameters
rg -n 'device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py -A 2 -B 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check examples directory for how osft is called
rg -n 'osft\(' examples/scripts/osft_hpu_example.py -A 10 -B 5

Length of output: 591


🏁 Script executed:

#!/bin/bash
# Look at the full context of the call in the example script around lines 148-151
sed -n '130,160p' examples/scripts/osft_hpu_example.py

Length of output: 996


🏁 Script executed:

#!/bin/bash
# Search for documentation or type hints that mention these parameters
rg -n 'device.*:.*str|torch_compile.*:|num_chunks.*:' src/training_hub/

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if there's any kwargs handling in osft
rg -n 'kwargs' src/training_hub/algorithms/osft.py

Length of output: 675


🏁 Script executed:

#!/bin/bash
# Look for where osft function might reference these parameters implicitly
rg -n "kwargs\['device'\]|kwargs\['torch_compile'\]|kwargs\['num_chunks'\]" src/training_hub/

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Get the line count and then extract the full osft function
sed -n '448,500p' src/training_hub/algorithms/osft.py

Length of output: 1850


🏁 Script executed:

#!/bin/bash
# Check for the full function signature by looking for def and closing parenthesis
python3 << 'EOF'
import re

try:
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        content = f.read()
    
    # Find the osft function definition
    match = re.search(r'def osft\([^)]*\):', content, re.DOTALL)
    if match:
        print("=== osft function signature ===")
        print(match.group(0))
    else:
        print("Function signature pattern not found")
        # Try simpler pattern
        match = re.search(r'def osft\([\s\S]{0,1000}\):', content)
        if match:
            print("=== osft function (extended search) ===")
            print(match.group(0))
except Exception as e:
    print(f"Error: {e}")
EOF

Length of output: 36


🏁 Script executed:

#!/bin/bash
# Check what parameters the osft function actually accepts
python3 << 'EOF'
import inspect
import sys
sys.path.insert(0, '.')

try:
    # Try to find and parse the file
    with open('src/training_hub/algorithms/osft.py', 'r') as f:
        lines = f.readlines()
    
    # Find the osft function starting at line 448
    func_start = 447  # 0-indexed
    paren_count = 0
    found_start = False
    func_sig_lines = []
    
    for i in range(func_start, len(lines)):
        line = lines[i]
        func_sig_lines.append(line.rstrip())
        
        if 'def osft(' in line:
            found_start = True
            paren_count = line.count('(') - line.count(')')
        elif found_start:
            paren_count += line.count('(') - line.count(')')
        
        if found_start and ':' in line and paren_count == 0:
            break
    
    print("=== Full osft function signature ===")
    for line in func_sig_lines:
        print(line)
        
except Exception as e:
    print(f"Error: {e}")
EOF

Length of output: 1157


🏁 Script executed:

#!/bin/bash
# Check if device, torch_compile, num_chunks are passed in the example call
grep -n -B 5 -A 15 "osft(" examples/scripts/osft_hpu_example.py

Length of output: 921


🏁 Script executed:

#!/bin/bash
# Look for kwargs parameter or **kwargs in osft function
rg 'def osft' src/training_hub/algorithms/osft.py -A 40 | grep -E 'kwargs|\*\*'

Length of output: 67


🏁 Script executed:

#!/bin/bash
# Final check - get osft signature using ast-grep
ast-grep --pattern $'def osft($_) {
  $$$
}'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Simple extraction - first 100 lines after osft function start
awk '/^def osft\(/{flag=1; count=0} flag{print; count++; if(count>60) exit}' src/training_hub/algorithms/osft.py

Length of output: 2148


🏁 Script executed:

#!/bin/bash
# Check for backend parameter which might handle device routing
rg -n 'backend' src/training_hub/algorithms/osft.py | head -20

Length of output: 861


🏁 Script executed:

#!/bin/bash
# Look for any forwarding to other functions that might accept these params
rg -n 'device|torch_compile|num_chunks' src/training_hub/ --type py -B 3 -A 3

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Get osft function body to see what it does with parameters
sed -n '448,600p' src/training_hub/algorithms/osft.py | grep -E '^\s*(device|torch_compile|num_chunks|def|return)'

Length of output: 37


🏁 Script executed:

#!/bin/bash
# Try getting the exact lines of the function definition in one go
python3 << 'EOF'
with open('src/training_hub/algorithms/osft.py', 'r') as f:
    lines = f.readlines()

in_def = False
paren_depth = 0
for i in range(447, min(len(lines), 550)):
    line = lines[i]
    if 'def osft(' in line:
        in_def = True
    
    if in_def:
        print(f"{i+1}: {line}", end='')
        paren_depth += line.count('(') - line.count(')')
        if ':' in line and paren_depth == 0:
            break
EOF

Length of output: 1285


🏁 Script executed:

#!/bin/bash
# Also check where osft is being called with the parameters
python3 << 'EOF'
import re
with open('examples/scripts/osft_hpu_example.py', 'r') as f:
    content = f.read()

# Find the osft call
match = re.search(r'osft\([^)]*\)', content, re.DOTALL)
if match:
    print("=== osft function call in example ===")
    print(match.group(0)[:500])
EOF

Length of output: 57


🏁 Script executed:

#!/bin/bash
# Check the actual call site around line 148-151
sed -n '140,160p' examples/scripts/osft_hpu_example.py

Length of output: 614


🏁 Script executed:

#!/bin/bash
# Look for any wrapper or intermediate function that might accept **kwargs
rg -n 'def.*osft|osft(' examples/scripts/osft_hpu_example.py -B 3 -A 10

Length of output: 80



🏁 Script executed:

#!/bin/bash
# Find MiniTrainerOSFTBackend.execute_training implementation
rg -n 'def execute_training' src/training_hub/algorithms/osft.py -A 30

Length of output: 1395


🏁 Script executed:

#!/bin/bash
# Find _validate_param_types to see what parameters are accepted
rg -n '_validate_param_types' src/training_hub/algorithms/osft.py -A 50

Length of output: 4304


🏁 Script executed:

#!/bin/bash
# Check optional_param_types to see if device, torch_compile, num_chunks are listed
sed -n '220,260p' src/training_hub/algorithms/osft.py

Length of output: 1740


🏁 Script executed:

#!/bin/bash
# Search for how parameters are used in backend execution
rg -n 'device|torch_compile|num_chunks' src/training_hub/ -B 2 -A 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if instructlab training handles these parameters
find src/training_hub -name '*.py' -type f | xargs grep -l 'device.*hpu\|torch_compile'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for parameter validation and what happens to unknown kwargs
sed -n '190,210p' src/training_hub/algorithms/osft.py

Length of output: 822


🏁 Script executed:

#!/bin/bash
# Check what the backend.execute_training actually does
sed -n '370,400p' src/training_hub/algorithms/osft.py

Length of output: 1473


🏁 Script executed:

#!/bin/bash
# Examine the full MiniTrainerOSFTBackend class
sed -n '316,445p' src/training_hub/algorithms/osft.py

Length of output: 5768


🏁 Script executed:

#!/bin/bash
# Check if these params are stripped out or cause errors
rg -n 'Unknown|unsupported|device|torch_compile|num_chunks' src/training_hub/algorithms/osft.py

Length of output: 0






🏁 Script executed:

#!/bin/bash
# Look at parameter validation logic
sed -n '217,280p' src/training_hub/algorithms/osft.py

Length of output: 2659


🏁 Script executed:

#!/bin/bash
# Check what happens to unrecognized kwargs
sed -n '190,215p' src/training_hub/algorithms/osft.py

Length of output: 975




🏁 Script executed:

#!/bin/bash
# Check if parameters flow through to instructlab training or are handled locally
rg -n 'instructlab' src/training_hub/algorithms/osft.py -B 2 -A 5

Length of output: 733


🏁 Script executed:

#!/bin/bash
# Look for parameter mapping or transformation
sed -n '335,370p' src/training_hub/algorithms/osft.py

Length of output: 2028


HPU-specific parameters are accepted but silently discarded during backend execution.

The osft() function accepts device, torch_compile, and num_chunks via **kwargs and passes them through the validation layer. However, in MiniTrainerOSFTBackend.execute_training() (lines 346-369), parameters are filtered into TorchrunArgs and TrainingArgs dataclass fields. Since these HPU parameters are not fields in either dataclass, they are silently dropped before training begins.

Either add these parameters to get_optional_params() in OSFTAlgorithm.train() with proper type hints, or remove them from the example script if they're not supported.
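The silent-drop behavior can be reproduced in miniature; TrainingArgsSketch below is a made-up stand-in for the real TrainingArgs dataclass:

```python
from dataclasses import dataclass, fields


@dataclass
class TrainingArgsSketch:
    # Stand-in for the real TrainingArgs; only declared fields survive
    # the filtering step described above.
    model_path: str = ""
    num_epochs: int = 1


params = {"model_path": "m", "num_epochs": 3,
          "device": "hpu", "torch_compile": True, "num_chunks": 4}

# Filter params into dataclass fields, as the backend does; anything
# not declared on the dataclass is silently discarded.
allowed = {f.name for f in fields(TrainingArgsSketch)}
training_args = TrainingArgsSketch(**{k: v for k, v in params.items() if k in allowed})
dropped = sorted(set(params) - allowed)
print(dropped)  # → ['device', 'num_chunks', 'torch_compile']
```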

🤖 Prompt for AI Agents
In examples/scripts/osft_hpu_example.py around lines 148 to 151, HPU-specific
kwargs ('device', 'torch_compile', 'num_chunks') are accepted but dropped before
backend execution; add these parameters to OSFTAlgorithm.train()'s
get_optional_params() with proper type hints (device: str, torch_compile: bool,
num_chunks: int) and ensure they are either mapped into the existing
TorchrunArgs/TrainingArgs dataclasses or the
MiniTrainerOSFTBackend.execute_training() is updated to accept and forward them
to the backend execution path; alternatively, if HPU support isn't intended,
remove those keys from the example script so they aren't silently ignored.
