@mtake mtake commented Oct 13, 2025

This PR adds training examples for the granite-3.3-8b-instruct model.
granite-3.3-8b-instruct is the student model for Japanese, and we have been using this code to train it.

Summary by CodeRabbit

  • New Features

    • CLI tool to interpolate between two compatible model checkpoints and save the result.
    • Single-node multi-GPU example workflows for SFT and OSFT on Granite 3.3 8B Instruct with configurable training, checkpointing, scheduler, and optional post-training interpolation.
    • Automatic discovery of latest checkpoints and conditional interpolation via a weight parameter.
    • Improved run outputs: timing, completion summaries, and OOM mitigation guidance.
  • Documentation

    • Added an experimental "Model Interpolation" guide and example entries for SFT/OSFT (also referenced in memory estimation).

coderabbitai bot commented Oct 13, 2025

Walkthrough

Adds three new example scripts: a model interpolator and two Granite 3.3 8B training examples (SFT and OSFT). The training scripts provide CLI-driven single-node multi-GPU orchestration, checkpoint discovery, optional post-training interpolation, and logging/error handling.

Changes

| Cohort / File(s) | Summary of Changes |
| --- | --- |
| Interpolator script: examples/scripts/interpolator.py | New CLI script to linearly interpolate two model checkpoints. Implements interpolate_models(...), parse_arguments(), and main(); validates weight and torch_dtype, loads base and trained models/state_dicts, blends parameters by weight, saves interpolated model, and copies tokenizer. |
| SFT Granite example: examples/scripts/sft_granite_example.py | New single-node multi-GPU SFT example using training_hub.sft. Adds CLI, runtime defaults and overrides, validation (e.g., nproc_per_node >= 4), find_most_recent_checkpoint(...), training invocation, optional interpolation of latest checkpoint, timing, and error handling. |
| OSFT Granite example: examples/scripts/osft_granite_example.py | New OSFT example using training_hub.osft. Adds Granite-specific OSFT config, CLI, GPU-count validation, find_most_recent_checkpoint(...), training orchestration, optional post-training interpolation, checkpoint handling, logging, and error reporting. |
| Docs update: examples/README.md | Adds entries for Granite 3.3 8B SFT and OSFT example scripts and a new "Model Interpolation (Experimental / In-Development)" subsection documenting the interpolator CLI and Python usage (appears in two places). |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as User
  participant CLI as interpolator.py (CLI)
  participant LoaderA as Base Model Loader
  participant LoaderB as Trained Model Loader
  participant Merger as StateDict Merger
  participant FS as Filesystem

  U->>CLI: run --model-path --trained-model-path --trained-model-weight --output-model-path
  CLI->>LoaderA: load base model & tokenizer
  CLI->>LoaderB: load trained model state_dict
  CLI->>Merger: compute (1-w)*base + w*trained
  Merger-->>CLI: interpolated state_dict
  CLI->>FS: save interpolated model and copy tokenizer
  CLI-->>U: report output path
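
The blending step in the diagram above, (1-w)*base + w*trained, can be sketched in a few lines. This is only an illustration under stated assumptions: plain Python lists stand in for tensors, and the helper name interpolate_state_dicts and the parameter names are hypothetical, not the ones used in interpolator.py.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # plain lists stand in for tensors (assumption)


def interpolate_state_dicts(base: StateDict, trained: StateDict, weight: float) -> StateDict:
    """Return (1 - weight) * base + weight * trained, parameter by parameter."""
    if not 0.0 <= weight <= 1.0:
        raise ValueError(f"weight must be in [0, 1], got {weight}")
    if base.keys() != trained.keys():
        raise ValueError("state dict keys do not match")
    return {
        key: [(1.0 - weight) * b + weight * t for b, t in zip(base[key], trained[key])]
        for key in base
    }


# Hypothetical two-parameter "model" used only for illustration.
base = {"layer.weight": [0.0, 2.0], "layer.bias": [1.0]}
trained = {"layer.weight": [4.0, 2.0], "layer.bias": [3.0]}
blended = interpolate_state_dicts(base, trained, 0.25)
print(blended)
```

With a weight of 0.25, each blended value sits a quarter of the way from the base value toward the trained value.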
sequenceDiagram
  autonumber
  actor U as User
  participant S as sft_granite_example.py
  participant TH as training_hub.sft
  participant CK as Checkpoint Store
  participant INT as interpolator (optional)

  U->>S: run with --data-path --ckpt-output-dir [overrides]
  S->>S: validate args (nproc_per_node >= 4)
  S->>TH: invoke sft(...) with config/topology
  TH-->>CK: write hf_format/samples_*.0 checkpoints
  TH-->>S: training completes
  S->>S: find_most_recent_checkpoint(output_dir)
  alt model_weight in (0,1)
    S->>INT: interpolate_models(base_model, latest_ckpt, weight)
    INT-->>S: interpolated model saved
  end
  S-->>U: print summary (duration, checkpoint, interpolated path)
  opt on error
    S-->>U: print error and troubleshooting tips
  end
sequenceDiagram
  autonumber
  actor U as User
  participant O as osft_granite_example.py
  participant TH as training_hub.osft
  participant CK as Checkpoint Store
  participant INT as interpolator (optional)

  U->>O: run with data/output args and OSFT overrides
  O->>O: ensure GPUs per node >= 4
  O->>TH: invoke osft(...) with OSFT config/topology
  TH-->>CK: write hf_format/samples_*.0 checkpoints
  TH-->>O: training completes
  O->>O: find_most_recent_checkpoint(output_dir)
  alt interpolation weight in (0,1)
    O->>INT: interpolate_models(base_model, latest_ckpt, weight)
    INT-->>O: interpolated model saved
  end
  O-->>U: print completion details or error with tips

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my whiskers at the code,
I blend two weights where carrots grow,
SFT hops, OSFT leaps in tune,
Granite learns beneath the moon,
Checkpoints saved — I thump, bravo! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 20.00%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The pull request title "Add granite training example" is directly related to the main objective of the changeset, which is to add training examples for the Granite 3.3 8B model. The title accurately captures the primary intent and refers to a concrete, specific addition to the repository. While the title could be more comprehensive (e.g., acknowledging the plural "examples" for both SFT and OSFT, or mentioning the interpolator utility script that was also added), it is not misleading or off-topic, and it clearly identifies the core contribution without unnecessary noise. |
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (12)
examples/scripts/interpolator.py (3)

37-41: Guard against missing/shape-mismatched keys and avoid autograd overhead

Use no_grad, operate on the intersection of keys, and in-place ops to reduce memory. Warn on skipped keys.

-    state_dict = model.state_dict()
-    original_model_weight = 1 - trained_model_weight
-    for key in state_dict.keys():
-        state_dict[key] = state_dict[key] * original_model_weight
+    state_dict = model.state_dict()
+    original_model_weight = 1 - trained_model_weight
@@
-    trained_state_dict = trained_model.state_dict()
-    for key in state_dict.keys():
-        state_dict[key] += trained_state_dict[key] * trained_model_weight
+    trained_state_dict = trained_model.state_dict()
+    common_keys = state_dict.keys() & trained_state_dict.keys()
+    missing_in_trained = state_dict.keys() - trained_state_dict.keys()
+    missing_in_base = trained_state_dict.keys() - state_dict.keys()
+    if missing_in_trained:
+        print(f"[interpolator] Skipping {len(missing_in_trained)} base-only params")
+    if missing_in_base:
+        print(f"[interpolator] Skipping {len(missing_in_base)} trained-only params")
+    with torch.no_grad():
+        for key in common_keys:
+            if state_dict[key].shape != trained_state_dict[key].shape:
+                print(f"[interpolator] Shape mismatch for {key}, skipping")
+                continue
+            state_dict[key].mul_(original_model_weight).add_(trained_state_dict[key], alpha=trained_model_weight)

Also applies to: 47-49


55-56: Prefer tokenizer from trained checkpoint (fall back to base if needed)

Ensures any tokenizer updates during training are preserved.

-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
+    except Exception:
+        tokenizer = AutoTokenizer.from_pretrained(model_path)

90-94: Clarify dtype CLI and acceptable values

Make CLI self-explanatory.

-    parser.add_argument(
+    parser.add_argument(
         "--torch-dtype",
         type=str,
         default="bfloat16",
-        help="torch dtype",
+        help='torch dtype: "auto", "bfloat16"/"bf16", "float16"/"fp16", or "float32"/"fp32"',
     )
examples/scripts/osft_granite_example.py (5)

70-72: Ensure data_output_dir exists

Create the directory to avoid downstream failures.

-data_output_dir=f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)
+data_output_dir=f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)
+os.makedirs(data_output_dir, exist_ok=True)

88-93: Broaden checkpoint glob pattern

Using "samples_*" matches both epoch- and sample-based patterns; avoids missing checkpoints.

-    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
+    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*")
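
To make the difference between the two patterns concrete, here is a small sketch using fnmatch, which follows the same shell-style wildcard rules that glob applies to each path segment. The directory names are hypothetical examples, not names taken from an actual training run.

```python
from fnmatch import fnmatch

# Hypothetical checkpoint directory names, for illustration only.
names = ["samples_1000.0", "samples_2000", "epoch_1"]

# The original pattern only matches names that end in ".0".
narrow = [n for n in names if fnmatch(n, "samples_*.0")]

# The broadened pattern matches any name starting with "samples_".
broad = [n for n in names if fnmatch(n, "samples_*")]

print(narrow)
print(broad)
```

The broader pattern is a superset of the narrow one, so it would also pick up any unrelated directory that happens to start with "samples_".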

159-195: Remove unused variable assignment

result is never used.

-        result = osft(
+        osft(

211-219: Make interpolator import robust regardless of working directory

Support running script from repo root or from examples/scripts.

-        if 0.0 < trained_model_weight and trained_model_weight < 1.0:
-            from interpolator import interpolate_models
+        if 0.0 < trained_model_weight < 1.0:
+            # Ensure sibling script is importable when run from repo root
+            script_dir = os.path.dirname(os.path.abspath(__file__))
+            if script_dir not in sys.path:
+                sys.path.insert(0, script_dir)
+            from interpolator import interpolate_models

221-232: Consider narrowing the exception or re-raising after logging

Catching Exception can hide bugs; at least re-raise after logging in non-interactive runs.

No diff provided; consider catching specific exceptions, or re-raising after logging (rather than only calling sys.exit(1)) in testing contexts.

examples/scripts/sft_granite_example.py (4)

63-65: Ensure data_output_dir exists

Create the directory proactively.

-data_output_dir=f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)
+data_output_dir=f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)
+os.makedirs(data_output_dir, exist_ok=True)

145-173: Remove unused variable assignment

result is not used.

-        result = sft(
+        sft(

186-194: Make interpolator import robust regardless of working directory

Same concern as OSFT script.

-        if 0.0 < trained_model_weight and trained_model_weight < 1.0:
-            from interpolator import interpolate_models
+        if 0.0 < trained_model_weight < 1.0:
+            script_dir = os.path.dirname(os.path.abspath(__file__))
+            if script_dir not in sys.path:
+                sys.path.insert(0, script_dir)
+            from interpolator import interpolate_models

196-205: Consider narrowing broad exception handling

As with OSFT script.

No diff provided; consider catching specific exceptions or re-raising after logging in CI contexts.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fc2175d and a8d9030.

📒 Files selected for processing (3)
  • examples/scripts/interpolator.py (1 hunks)
  • examples/scripts/osft_granite_example.py (1 hunks)
  • examples/scripts/sft_granite_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/scripts/osft_granite_example.py (3)
src/training_hub/algorithms/osft.py (1)
  • osft (439-500)
examples/scripts/sft_granite_example.py (2)
  • find_most_recent_checkpoint (68-91)
  • main (94-205)
examples/scripts/interpolator.py (2)
  • main (99-113)
  • interpolate_models (18-58)
examples/scripts/interpolator.py (2)
examples/scripts/osft_granite_example.py (1)
  • main (100-232)
examples/scripts/sft_granite_example.py (1)
  • main (94-205)
examples/scripts/sft_granite_example.py (2)
src/training_hub/algorithms/sft.py (1)
  • sft (177-249)
examples/scripts/interpolator.py (2)
  • main (99-113)
  • interpolate_models (18-58)
🪛 Ruff (0.13.3)
examples/scripts/osft_granite_example.py

1-1: Shebang is present but file is not executable

(EXE001)


71-71: Probable insecure usage of temporary file or directory: "/dev/shm/data/"

(S108)


92-92: Avoid specifying long messages outside the exception class

(TRY003)


134-134: Avoid specifying long messages outside the exception class

(TRY003)


159-159: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


221-221: Do not catch blind exception: Exception

(BLE001)

examples/scripts/sft_granite_example.py

1-1: Shebang is present but file is not executable

(EXE001)


64-64: Probable insecure usage of temporary file or directory: "/dev/shm/data/"

(S108)


86-86: Avoid specifying long messages outside the exception class

(TRY003)


124-124: Avoid specifying long messages outside the exception class

(TRY003)


145-145: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


196-196: Do not catch blind exception: Exception

(BLE001)

@Maxusmusti
Contributor

Hi @mtake, thanks for the contribution! Could you add these scripts to the examples dir README, and also explain the usage of interpolator.py there? It all looks good at first glance. You can also take a look at the CodeRabbit suggestions and just press resolve on any that are not relevant.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/scripts/interpolator.py (1)

19-28: Validate trained_model_weight is within [0, 1].

The function does not validate that trained_model_weight falls within the valid range before performing interpolation. Invalid weights can produce meaningless or unstable results.

Apply this diff to add validation:

 def interpolate_models(
     model_path: str,
     trained_model_path: str,
     trained_model_weight: float = 0.5,
     output_model_path: str | None = None,
     torch_dtype: str | torch.dtype | None = "bfloat16",
 ) -> str:
+    if not (0.0 <= trained_model_weight <= 1.0):
+        raise ValueError(f"trained_model_weight must be in [0, 1], got {trained_model_weight}")
+
     if output_model_path is None:
         output_model_path = f"{trained_model_path}_interp"
🧹 Nitpick comments (2)
examples/scripts/interpolator.py (2)

114-128: Add error handling and user feedback.

The main function lacks error handling and success feedback. Given this script can be invoked directly from the command line, users would benefit from clear error messages and confirmation when interpolation completes.

Apply this diff to add error handling and user feedback:

 def main():
     args = parse_arguments()
     model_path: str = args.model_path
     trained_model_path: str = args.trained_model_path
     trained_model_weight: float = args.trained_model_weight
     output_model_path: str | None = args.output_model_path
     torch_dtype: str | None = args.torch_dtype
 
-    interpolate_models(
-        model_path,
-        trained_model_path,
-        trained_model_weight=trained_model_weight,
-        output_model_path=output_model_path,
-        torch_dtype=torch_dtype,
-    )
+    try:
+        print(f"Starting model interpolation with weight {trained_model_weight}...")
+        output_path = interpolate_models(
+            model_path,
+            trained_model_path,
+            trained_model_weight=trained_model_weight,
+            output_model_path=output_model_path,
+            torch_dtype=torch_dtype,
+        )
+        print(f"✅ Interpolation completed successfully!")
+        print(f"📁 Output model saved to: {output_path}")
+    except Exception as e:
+        print(f"❌ Interpolation failed: {e}")
+        raise

41-42: Optional: Extract error message for consistency.

Ruff (TRY003) flags long messages passed inline to an exception constructor. However, this message is concise and context-specific, so addressing the hint is optional.

If you prefer to address the linter hint:

+_UNSUPPORTED_DTYPE_MSG = "Unsupported --torch-dtype: {dtype}"
+
 def interpolate_models(
     ...
             if _torch_dtype not in _map:
-                raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
+                raise ValueError(_UNSUPPORTED_DTYPE_MSG.format(dtype=torch_dtype))

Based on learnings

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ec90208 and 210031b.

📒 Files selected for processing (1)
  • examples/scripts/interpolator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/interpolator.py (2)
examples/scripts/osft_granite_example.py (1)
  • main (100-232)
examples/scripts/sft_granite_example.py (1)
  • main (94-205)
🪛 Ruff (0.14.0)
examples/scripts/interpolator.py

42-42: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (4)
examples/scripts/interpolator.py (4)

14-17: Good: torch import added.

The torch import addresses the previous feedback and enables proper dtype handling in the interpolation logic.


29-45: LGTM: torch_dtype handling is robust.

The string-to-dtype mapping with validation addresses the previous feedback well. The code correctly handles both string literals (with case-insensitive mapping) and torch.dtype objects, and properly supports "auto" mode.


63-64: Verify state_dict key compatibility between checkpoints.

The code iterates over state_dict.keys() (from the original model) to blend weights from trained_state_dict. If trained_state_dict contains keys not present in state_dict, they are silently ignored. If state_dict has keys missing in trained_state_dict, this will raise a KeyError.

For fine-tuned versions of the same base model, the keys should match. However, consider adding a guard to verify key compatibility and provide a clear error message if they differ.

Apply this diff to add key compatibility validation:

     trained_state_dict = trained_model.state_dict()
+    if set(state_dict.keys()) != set(trained_state_dict.keys()):
+        raise ValueError(
+            "Model state_dict keys do not match. "
+            f"Original model has {len(state_dict)} keys, trained model has {len(trained_state_dict)} keys."
+        )
     for key in state_dict.keys():
         state_dict[key] += trained_state_dict[key] * trained_model_weight

76-111: LGTM: Clear and standard argument parsing.

The argument parser is well-structured with helpful descriptions and sensible defaults. The use of ArgumentDefaultsHelpFormatter improves user experience.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
examples/scripts/interpolator.py (3)

32-48: LGTM: torch_dtype handling is comprehensive.

The string-to-dtype mapping covers common formats (bfloat16/bf16, float16/fp16, float32/fp32) and properly validates input.

Optional: Address static analysis style suggestion

Ruff (TRY003) flags long error messages written inline at the raise site; it is a style hint rather than a correctness issue. While the current inline messages are clear and readable for a script, you could optionally define message constants if preferred:

+UNSUPPORTED_DTYPE_MSG = "Unsupported --torch-dtype: {}"
+
     if torch_dtype is not None:
         if isinstance(torch_dtype, str):
             _torch_dtype = torch_dtype.lower()
             if _torch_dtype == "auto":
                 model_kwargs["torch_dtype"] = "auto"
             else:
                 _map = {
                     "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
                     "float16": torch.float16, "fp16": torch.float16,
                     "float32": torch.float32, "fp32": torch.float32,
                 }
                 if _torch_dtype not in _map:
-                    raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
+                    raise ValueError(UNSUPPORTED_DTYPE_MSG.format(torch_dtype))
                 model_kwargs["torch_dtype"] = _map[_torch_dtype]

As per static analysis hints.


51-54: Consider specifying device placement for large models.

For large models like Granite 3.3 8B, loading without an explicit device_map may cause OOM errors or inefficient single-device placement. Consider adding device placement strategy.

Add device_map to model loading:

     # load original model
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
+        device_map="auto",  # Automatically distribute across available devices
         **model_kwargs,
     )
     state_dict = model.state_dict()
     original_model_weight = 1 - trained_model_weight
     for key in state_dict.keys():
         state_dict[key] = state_dict[key] * original_model_weight

     # load trained model
     trained_model = AutoModelForCausalLM.from_pretrained(
         trained_model_path,
+        device_map="auto",
         **model_kwargs,
     )

Alternatively, make device_map a configurable parameter if users want explicit control over device placement.

Also applies to: 61-64


73-74: Document tokenizer assumption.

The code copies the tokenizer from the base model_path, assuming both models use the same tokenizer. This is reasonable for interpolating checkpoints of the same model family, but worth documenting.

Add a brief comment or docstring note:

     # save interpolated model
     model.save_pretrained(output_model_path, state_dict=state_dict)

-    # copy tokenizer
+    # Copy tokenizer from base model (assumes both models use the same tokenizer)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.save_pretrained(output_model_path)

Or add to the function docstring:

 def interpolate_models(
     model_path: str,
     trained_model_path: str,
     trained_model_weight: float = 0.5,
     output_model_path: str | None = None,
     torch_dtype: str | torch.dtype | None = "bfloat16",
 ) -> str:
+    """
+    Interpolate two model checkpoints with linear weighting.
+    
+    Both models must have the same architecture and tokenizer.
+    The tokenizer from model_path will be copied to the output.
+    """
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 210031b and 2a67cef.

📒 Files selected for processing (1)
  • examples/scripts/interpolator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/interpolator.py (2)
examples/scripts/osft_granite_example.py (1)
  • main (100-232)
examples/scripts/sft_granite_example.py (1)
  • main (94-205)
🪛 Ruff (0.14.0)
examples/scripts/interpolator.py

30-30: Avoid specifying long messages outside the exception class

(TRY003)


45-45: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
examples/scripts/interpolator.py (3)

14-16: LGTM: Imports correctly address previous feedback.

The torch import and transformers components are now properly included for dtype handling and model loading.


29-30: LGTM: Weight validation is correct.

The validation ensures trained_model_weight is within [0, 1] and provides a clear error message if not.


79-135: LGTM: CLI interface is well-structured.

The argument parser provides clear help messages, appropriate defaults, and correct type handling. The main function cleanly orchestrates the argument parsing and model interpolation.

@mtake
Contributor Author

mtake commented Oct 15, 2025

@Maxusmusti Thank you for the comments. I added the scripts to README.md and addressed all of the bot's suggestions. Please review.

@Maxusmusti Maxusmusti left a comment

@mtake Thanks again for the contribution! Everything looks good. There are just a few minor nits that might be worth addressing if you have time, but if not, this can be merged as-is. I tested the scripts and everything seems to be working as expected.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/scripts/interpolator.py (1)

66-67: Add validation that model state_dicts have matching keys.

The interpolation loop assumes both models have identical keys. If the trained model checkpoint has a different architecture, is corrupted, or missing parameters, accessing trained_state_dict[key] will raise a KeyError. This issue was flagged in a previous review but hasn't been fully addressed yet.

Add key validation before interpolation:

     trained_model = AutoModelForCausalLM.from_pretrained(
         trained_model_path,
         **model_kwargs,
     )
     trained_state_dict = trained_model.state_dict()
+    
+    # Validate that both models have matching keys
+    base_keys = set(state_dict.keys())
+    trained_keys = set(trained_state_dict.keys())
+    if base_keys != trained_keys:
+        missing_in_trained = base_keys - trained_keys
+        extra_in_trained = trained_keys - base_keys
+        error_parts = []
+        if missing_in_trained:
+            error_parts.append(f"Missing in trained model: {list(missing_in_trained)[:5]}")
+        if extra_in_trained:
+            error_parts.append(f"Extra in trained model: {list(extra_in_trained)[:5]}")
+        raise ValueError(f"Model architectures do not match. {'; '.join(error_parts)}")
+    
     for key in state_dict.keys():
         state_dict[key] += trained_state_dict[key] * trained_model_weight
🧹 Nitpick comments (1)
examples/scripts/osft_granite_example.py (1)

74-97: Consider extracting the duplicated checkpoint helper.

The find_most_recent_checkpoint function is identical to the one in sft_granite_example.py (lines 67-90). To reduce duplication and improve maintainability, consider extracting this helper to a shared utility module.

For example, create examples/scripts/training_utils.py:

"""Common utilities for training example scripts."""
import os
import glob

def find_most_recent_checkpoint(output_dir, pattern="samples_*"):
    """Find the most recent checkpoint in the training output directory."""
    checkpoint_pattern = os.path.join(output_dir, "hf_format", pattern)
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    
    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")
    
    return max(checkpoint_dirs, key=os.path.getctime)

Then import and use it in both scripts:

from training_utils import find_most_recent_checkpoint
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2a67cef and 2cd1d2a.

📒 Files selected for processing (3)
  • examples/README.md (3 hunks)
  • examples/scripts/interpolator.py (1 hunks)
  • examples/scripts/osft_granite_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/scripts/osft_granite_example.py (3)
src/training_hub/algorithms/osft.py (1)
  • osft (448-513)
examples/scripts/sft_granite_example.py (2)
  • find_most_recent_checkpoint (68-91)
  • main (94-205)
examples/scripts/interpolator.py (2)
  • main (119-133)
  • interpolate_models (19-78)
examples/scripts/interpolator.py (2)
examples/scripts/osft_granite_example.py (1)
  • main (100-232)
examples/scripts/sft_granite_example.py (1)
  • main (94-205)
🪛 LanguageTool
examples/README.md

[grammar] ~110-~110: There might be a mistake here.
Context: ...ne with linear interpolation. Script: - [interpolator.py](scripts/interpolator.py...

(QB_NEW_EN)

🪛 Ruff (0.14.0)
examples/scripts/osft_granite_example.py

1-1: Shebang is present but file is not executable

(EXE001)


71-71: Probable insecure usage of temporary file or directory: "/dev/shm/data/"

(S108)


92-92: Avoid specifying long messages outside the exception class

(TRY003)


134-134: Avoid specifying long messages outside the exception class

(TRY003)


159-159: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


221-221: Do not catch blind exception: Exception

(BLE001)

examples/scripts/interpolator.py

30-30: Avoid specifying long messages outside the exception class

(TRY003)


45-45: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (4)
examples/README.md (1)

30-30: LGTM! Documentation additions are clear and well-structured.

The new Granite 3.3 8B example entries follow the existing documentation pattern, and the Model Interpolation section provides both CLI and Python usage examples with generic paths as requested in past reviews.

Also applies to: 64-64, 106-123

examples/scripts/interpolator.py (1)

19-78: Other past review feedback has been addressed.

The function now includes:

  • Proper weight validation (lines 29-30)
  • Correct torch dtype handling with string-to-dtype mapping (lines 32-48)
  • Success message indicating where the merged model was saved (line 76)
examples/scripts/osft_granite_example.py (2)

212-219: Interpolation logic correctly handles edge cases.

The condition 0.0 < trained_model_weight < 1.0 properly excludes edge cases where:

  • trained_model_weight == 0.0 would produce the base model (no interpolation needed)
  • trained_model_weight == 1.0 would produce the trained model (no interpolation needed)

This is the correct behavior and avoids unnecessary interpolation overhead.
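
The edge cases above follow directly from the blending formula. A tiny sketch with scalar values (the helper names blend and should_interpolate are hypothetical, chosen only for this illustration) shows why the endpoints need no interpolation pass:

```python
def blend(base: float, trained: float, weight: float) -> float:
    """Linear interpolation of a single value: (1 - w) * base + w * trained."""
    return (1.0 - weight) * base + weight * trained


def should_interpolate(trained_model_weight: float) -> bool:
    """Mirror of the strict open-interval check 0.0 < w < 1.0."""
    return 0.0 < trained_model_weight < 1.0


base_value, trained_value = 2.0, 6.0
for w in (0.0, 0.5, 1.0):
    print(w, should_interpolate(w), blend(base_value, trained_value, w))
```

At a weight of 0.0 the blend equals the base value and at 1.0 it equals the trained value, so only weights strictly between the two produce a genuinely new model.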


1-236: Overall script structure is solid.

The script provides:

  • Comprehensive CLI with sensible defaults
  • Clear configuration printing before training
  • Robust error handling with helpful troubleshooting tips
  • Optional post-training interpolation
  • Proper integration with training_hub.osft

The unused result variable (line 159) flagged by static analysis is acceptable since osft() performs side effects (training and checkpointing) even if the return value isn't used.

@mtake
Contributor Author

mtake commented Oct 17, 2025

Hi @Maxusmusti
I addressed all your comments in 2cd1d2a. The comments helped make the code more complete. Thank you!

@mtake mtake requested a review from Maxusmusti October 17, 2025 02:24
@Maxusmusti Maxusmusti left a comment

looks good to me!

@Maxusmusti Maxusmusti merged commit f485b6c into Red-Hat-AI-Innovation-Team:main Oct 17, 2025
4 checks passed
@mtake mtake deleted the granite_training_example branch October 17, 2025 16:28
@coderabbitai coderabbitai bot mentioned this pull request Nov 4, 2025