Add granite training example #16
Conversation
Walkthrough

Adds three new example scripts: a model interpolator and two Granite 3.3 8B training examples (SFT and OSFT). The training scripts provide CLI-driven, single-node, multi-GPU orchestration, checkpoint discovery, optional post-training interpolation, and logging/error handling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor U as User
    participant CLI as interpolator.py (CLI)
    participant LoaderA as Base Model Loader
    participant LoaderB as Trained Model Loader
    participant Merger as StateDict Merger
    participant FS as Filesystem
    U->>CLI: run --model-path --trained-model-path --trained-model-weight --output-model-path
    CLI->>LoaderA: load base model & tokenizer
    CLI->>LoaderB: load trained model state_dict
    CLI->>Merger: compute (1-w)*base + w*trained
    Merger-->>CLI: interpolated state_dict
    CLI->>FS: save interpolated model and copy tokenizer
    CLI-->>U: report output path
```

```mermaid
sequenceDiagram
    autonumber
    actor U as User
    participant S as sft_granite_example.py
    participant TH as training_hub.sft
    participant CK as Checkpoint Store
    participant INT as interpolator (optional)
    U->>S: run with --data-path --ckpt-output-dir [overrides]
    S->>S: validate args (nproc_per_node >= 4)
    S->>TH: invoke sft(...) with config/topology
    TH-->>CK: write hf_format/samples_*.0 checkpoints
    TH-->>S: training completes
    S->>S: find_most_recent_checkpoint(output_dir)
    alt model_weight in (0,1)
        S->>INT: interpolate_models(base_model, latest_ckpt, weight)
        INT-->>S: interpolated model saved
    end
    S-->>U: print summary (duration, checkpoint, interpolated path)
    opt on error
        S-->>U: print error and troubleshooting tips
    end
```

```mermaid
sequenceDiagram
    autonumber
    actor U as User
    participant O as osft_granite_example.py
    participant TH as training_hub.osft
    participant CK as Checkpoint Store
    participant INT as interpolator (optional)
    U->>O: run with data/output args and OSFT overrides
    O->>O: ensure GPUs per node >= 4
    O->>TH: invoke osft(...) with OSFT config/topology
    TH-->>CK: write hf_format/samples_*.0 checkpoints
    TH-->>O: training completes
    O->>O: find_most_recent_checkpoint(output_dir)
    alt interpolation weight in (0,1)
        O->>INT: interpolate_models(base_model, latest_ckpt, weight)
        INT-->>O: interpolated model saved
    end
    O-->>U: print completion details or error with tips
```
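The interpolation step in the diagrams is a simple per-parameter linear blend. As a rough illustration only (the function name and structure here are assumptions, not the actual interpolator.py code reviewed below), the core operation amounts to:

```python
# Illustrative sketch of the (1 - w) * base + w * trained blend described above.
import torch

def blend_state_dicts(base_sd: dict, trained_sd: dict, w: float) -> dict:
    """Linearly interpolate two state dicts that share the same keys and shapes."""
    with torch.no_grad():
        return {key: (1.0 - w) * base_sd[key] + w * trained_sd[key] for key in base_sd}
```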
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 3
🧹 Nitpick comments (12)
examples/scripts/interpolator.py (3)
37-41: Guard against missing/shape-mismatched keys and avoid autograd overhead

Use `no_grad`, operate on the intersection of keys, and use in-place ops to reduce memory. Warn on skipped keys.

```diff
-    state_dict = model.state_dict()
-    original_model_weight = 1 - trained_model_weight
-    for key in state_dict.keys():
-        state_dict[key] = state_dict[key] * original_model_weight
+    state_dict = model.state_dict()
+    original_model_weight = 1 - trained_model_weight
@@
-    trained_state_dict = trained_model.state_dict()
-    for key in state_dict.keys():
-        state_dict[key] += trained_state_dict[key] * trained_model_weight
+    trained_state_dict = trained_model.state_dict()
+    common_keys = state_dict.keys() & trained_state_dict.keys()
+    missing_in_trained = state_dict.keys() - trained_state_dict.keys()
+    missing_in_base = trained_state_dict.keys() - state_dict.keys()
+    if missing_in_trained:
+        print(f"[interpolator] Skipping {len(missing_in_trained)} base-only params")
+    if missing_in_base:
+        print(f"[interpolator] Skipping {len(missing_in_base)} trained-only params")
+    with torch.no_grad():
+        for key in common_keys:
+            if state_dict[key].shape != trained_state_dict[key].shape:
+                print(f"[interpolator] Shape mismatch for {key}, skipping")
+                continue
+            state_dict[key].mul_(original_model_weight).add_(trained_state_dict[key], alpha=trained_model_weight)
```

Also applies to: 47-49
55-56: Prefer tokenizer from trained checkpoint (fall back to base if needed)

Ensures any tokenizer updates made during training are preserved.

```diff
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
+    except Exception:
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
```
90-94: Clarify dtype CLI and acceptable values

Make the CLI self-explanatory.

```diff
     parser.add_argument(
         "--torch-dtype",
         type=str,
         default="bfloat16",
-        help="torch dtype",
+        help='torch dtype: "auto", "bfloat16"/"bf16", "float16"/"fp16", or "float32"/"fp32"',
     )
```

examples/scripts/osft_granite_example.py (5)
70-72: Ensure data_output_dir exists

Create the directory to avoid downstream failures.

```diff
-data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
+data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
+os.makedirs(data_output_dir, exist_ok=True)
```
88-93: Broaden checkpoint glob pattern

Using "samples_*" matches both epoch- and sample-based patterns and avoids missing checkpoints.

```diff
-    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
+    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*")
```
159-195: Remove unused variable assignment

`result` is never used.

```diff
-    result = osft(
+    osft(
```
211-219: Make interpolator import robust regardless of working directory

Support running the script from the repo root or from examples/scripts.

```diff
-    if 0.0 < trained_model_weight and trained_model_weight < 1.0:
-        from interpolator import interpolate_models
+    if 0.0 < trained_model_weight < 1.0:
+        # Ensure sibling script is importable when run from repo root
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        if script_dir not in sys.path:
+            sys.path.insert(0, script_dir)
+        from interpolator import interpolate_models
```
221-232: Consider narrowing the exception or re-raising after logging

Catching Exception can hide bugs; at least re-raise after logging in non-interactive runs. No diff provided; consider catching specific exceptions or calling `raise` after `sys.exit(1)` in testing contexts.

examples/scripts/sft_granite_example.py (4)
63-65: Ensure data_output_dir exists

Create the directory proactively.

```diff
-data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
+data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
+os.makedirs(data_output_dir, exist_ok=True)
```
145-173: Remove unused variable assignment

`result` is not used.

```diff
-    result = sft(
+    sft(
```
186-194: Make interpolator import robust regardless of working directory

Same concern as the OSFT script.

```diff
-    if 0.0 < trained_model_weight and trained_model_weight < 1.0:
-        from interpolator import interpolate_models
+    if 0.0 < trained_model_weight < 1.0:
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        if script_dir not in sys.path:
+            sys.path.insert(0, script_dir)
+        from interpolator import interpolate_models
```
196-205: Consider narrowing broad exception handling

As with the OSFT script.
No diff provided; consider catching specific exceptions or re-raising after logging in CI contexts.
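For both scripts, the narrower handling could look roughly like the sketch below. This is only an illustration: `run_training` is a hypothetical placeholder for the `sft(...)`/`osft(...)` call, and the exception types listed are assumptions rather than what training_hub is confirmed to raise.

```python
def run_training() -> None:
    """Hypothetical placeholder for the sft(...) / osft(...) call in the example scripts."""
    raise FileNotFoundError("example: training data file not found")


try:
    run_training()
except (ValueError, FileNotFoundError, RuntimeError) as exc:
    # Narrow the catch to expected failure modes instead of a blind `except Exception`.
    print(f"Training failed: {exc}")
    print("Check --data-path, GPU availability, and the checkpoint output directory.")
    raise  # re-raise so CI and other non-interactive runs still fail loudly
```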
📜 Review details
📒 Files selected for processing (3)
- examples/scripts/interpolator.py (1 hunks)
- examples/scripts/osft_granite_example.py (1 hunks)
- examples/scripts/sft_granite_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/scripts/osft_granite_example.py (3)
- src/training_hub/algorithms/osft.py (1): osft (439-500)
- examples/scripts/sft_granite_example.py (2): find_most_recent_checkpoint (68-91), main (94-205)
- examples/scripts/interpolator.py (2): main (99-113), interpolate_models (18-58)
examples/scripts/interpolator.py (2)
- examples/scripts/osft_granite_example.py (1): main (100-232)
- examples/scripts/sft_granite_example.py (1): main (94-205)
examples/scripts/sft_granite_example.py (2)
- src/training_hub/algorithms/sft.py (1): sft (177-249)
- examples/scripts/interpolator.py (2): main (99-113), interpolate_models (18-58)
🪛 Ruff (0.13.3)
examples/scripts/osft_granite_example.py
1-1: Shebang is present but file is not executable
(EXE001)
71-71: Probable insecure usage of temporary file or directory: "/dev/shm/data/"
(S108)
92-92: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Avoid specifying long messages outside the exception class
(TRY003)
159-159: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
221-221: Do not catch blind exception: Exception
(BLE001)
examples/scripts/sft_granite_example.py
1-1: Shebang is present but file is not executable
(EXE001)
64-64: Probable insecure usage of temporary file or directory: "/dev/shm/data/"
(S108)
86-86: Avoid specifying long messages outside the exception class
(TRY003)
124-124: Avoid specifying long messages outside the exception class
(TRY003)
145-145: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
196-196: Do not catch blind exception: Exception
(BLE001)
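If the S108 hints above are worth addressing rather than resolving, one option is to let tempfile create a unique, permission-restricted directory instead of a fixed /dev/shm path. This is a sketch under the assumption that a per-run directory is acceptable to the training flow; the prefix and fallback are not taken from the example scripts.

```python
# Sketch only: unique data directory on the RAM disk when available, otherwise
# the system temp dir. tempfile.mkdtemp creates the directory with 0700 perms.
import os
import tempfile

shm = "/dev/shm"
base_dir = shm if os.path.isdir(shm) else None  # None falls back to the default temp dir
data_output_dir = tempfile.mkdtemp(prefix="training_data_", dir=base_dir)
print(f"Processed data directory: {data_output_dir}")
```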
Hi @mtake, thanks for the contribution! Could you add these scripts to the examples dir README, and also explain the usage of interpolator.py in the README as well? It all looks good at first glance; you can also take a look at the CodeRabbit suggestions and just press resolve on them if they are not relevant.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/scripts/interpolator.py (1)
19-28: Validate trained_model_weight is within [0, 1].

The function does not validate that `trained_model_weight` falls within the valid range before performing interpolation. Invalid weights can produce meaningless or unstable results.

Apply this diff to add validation:

```diff
 def interpolate_models(
     model_path: str,
     trained_model_path: str,
     trained_model_weight: float = 0.5,
     output_model_path: str | None = None,
     torch_dtype: str | torch.dtype | None = "bfloat16",
 ) -> str:
+    if not (0.0 <= trained_model_weight <= 1.0):
+        raise ValueError(f"trained_model_weight must be in [0, 1], got {trained_model_weight}")
+
     if output_model_path is None:
         output_model_path = f"{trained_model_path}_interp"
```
🧹 Nitpick comments (2)
examples/scripts/interpolator.py (2)
114-128: Add error handling and user feedback.

The main function lacks error handling and success feedback. Given this script can be invoked directly from the command line, users would benefit from clear error messages and confirmation when interpolation completes.
Apply this diff to add error handling and user feedback:
```diff
 def main():
     args = parse_arguments()
     model_path: str = args.model_path
     trained_model_path: str = args.trained_model_path
     trained_model_weight: float = args.trained_model_weight
     output_model_path: str | None = args.output_model_path
     torch_dtype: str | None = args.torch_dtype
-    interpolate_models(
-        model_path,
-        trained_model_path,
-        trained_model_weight=trained_model_weight,
-        output_model_path=output_model_path,
-        torch_dtype=torch_dtype,
-    )
+    try:
+        print(f"Starting model interpolation with weight {trained_model_weight}...")
+        output_path = interpolate_models(
+            model_path,
+            trained_model_path,
+            trained_model_weight=trained_model_weight,
+            output_model_path=output_model_path,
+            torch_dtype=torch_dtype,
+        )
+        print(f"✅ Interpolation completed successfully!")
+        print(f"📁 Output model saved to: {output_path}")
+    except Exception as e:
+        print(f"❌ Interpolation failed: {e}")
+        raise
```
41-42: Optional: Extract error message for consistency.

Ruff suggests extracting long messages outside the exception class for consistency. However, this message is concise and context-specific, so extraction is optional.
If you prefer to address the linter hint:
```diff
+_UNSUPPORTED_DTYPE_MSG = "Unsupported --torch-dtype: {dtype}"
+
 def interpolate_models(
     ...
         if _torch_dtype not in _map:
-            raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
+            raise ValueError(_UNSUPPORTED_DTYPE_MSG.format(dtype=torch_dtype))
```

Based on learnings
📜 Review details
📒 Files selected for processing (1)
- examples/scripts/interpolator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/interpolator.py (2)
- examples/scripts/osft_granite_example.py (1): main (100-232)
- examples/scripts/sft_granite_example.py (1): main (94-205)
🪛 Ruff (0.14.0)
examples/scripts/interpolator.py
42-42: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (4)
examples/scripts/interpolator.py (4)
14-17: Good: torch import added.

The torch import addresses the previous feedback and enables proper dtype handling in the interpolation logic.
29-45: LGTM: torch_dtype handling is robust.

The string-to-dtype mapping with validation addresses the previous feedback well. The code correctly handles both string literals (with case-insensitive mapping) and torch.dtype objects, and properly supports "auto" mode.
63-64: Verify state_dict key compatibility between checkpoints.

The code iterates over `state_dict.keys()` (from the original model) to blend weights from `trained_state_dict`. If `trained_state_dict` contains keys not present in `state_dict`, they are silently ignored. If `state_dict` has keys missing in `trained_state_dict`, this will raise a KeyError.

For fine-tuned versions of the same base model, the keys should match. However, consider adding a guard to verify key compatibility and provide a clear error message if they differ.
Apply this diff to add key compatibility validation:
```diff
     trained_state_dict = trained_model.state_dict()
+    if set(state_dict.keys()) != set(trained_state_dict.keys()):
+        raise ValueError(
+            "Model state_dict keys do not match. "
+            f"Original model has {len(state_dict)} keys, trained model has {len(trained_state_dict)} keys."
+        )
     for key in state_dict.keys():
         state_dict[key] += trained_state_dict[key] * trained_model_weight
```
76-111: LGTM: Clear and standard argument parsing.

The argument parser is well-structured with helpful descriptions and sensible defaults. The use of `ArgumentDefaultsHelpFormatter` improves user experience.
Actionable comments posted: 1
🧹 Nitpick comments (3)
examples/scripts/interpolator.py (3)
32-48: LGTM: torch_dtype handling is comprehensive.

The string-to-dtype mapping covers common formats (bfloat16/bf16, float16/fp16, float32/fp32) and properly validates input.
Optional: Address static analysis style suggestion
Ruff (TRY003) suggests extracting long error messages to reduce exception instantiation overhead. While the current inline messages are clear and readable for a script, you could optionally define message constants if preferred:
```diff
+UNSUPPORTED_DTYPE_MSG = "Unsupported --torch-dtype: {}"
+
     if torch_dtype is not None:
         if isinstance(torch_dtype, str):
             _torch_dtype = torch_dtype.lower()
             if _torch_dtype == "auto":
                 model_kwargs["torch_dtype"] = "auto"
             else:
                 _map = {
                     "bfloat16": torch.bfloat16,
                     "bf16": torch.bfloat16,
                     "float16": torch.float16,
                     "fp16": torch.float16,
                     "float32": torch.float32,
                     "fp32": torch.float32,
                 }
                 if _torch_dtype not in _map:
-                    raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
+                    raise ValueError(UNSUPPORTED_DTYPE_MSG.format(torch_dtype))
```

As per static analysis hints.
51-54: Consider specifying device placement for large models.

For large models like Granite 3.3 8B, loading without an explicit `device_map` may cause OOM errors or inefficient single-device placement. Consider adding a device placement strategy.

Add `device_map` to model loading:

```diff
 # load original model
 model = AutoModelForCausalLM.from_pretrained(
     model_path,
+    device_map="auto",  # Automatically distribute across available devices
     **model_kwargs,
 )
 state_dict = model.state_dict()
 original_model_weight = 1 - trained_model_weight
 for key in state_dict.keys():
     state_dict[key] = state_dict[key] * original_model_weight

 # load trained model
 trained_model = AutoModelForCausalLM.from_pretrained(
     trained_model_path,
+    device_map="auto",
     **model_kwargs,
 )
```

Alternatively, make `device_map` a configurable parameter if users want explicit control over device placement.

Also applies to: 61-64
73-74: Document tokenizer assumption.

The code copies the tokenizer from the base `model_path`, assuming both models use the same tokenizer. This is reasonable for interpolating checkpoints of the same model family, but worth documenting.

Add a brief comment or docstring note:

```diff
 # save interpolated model
 model.save_pretrained(output_model_path, state_dict=state_dict)

-# copy tokenizer
+# Copy tokenizer from base model (assumes both models use the same tokenizer)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 tokenizer.save_pretrained(output_model_path)
```

Or add to the function docstring:

```diff
 def interpolate_models(
     model_path: str,
     trained_model_path: str,
     trained_model_weight: float = 0.5,
     output_model_path: str | None = None,
     torch_dtype: str | torch.dtype | None = "bfloat16",
 ) -> str:
+    """
+    Interpolate two model checkpoints with linear weighting.
+
+    Both models must have the same architecture and tokenizer.
+    The tokenizer from model_path will be copied to the output.
+    """
```
📜 Review details
📒 Files selected for processing (1)
- examples/scripts/interpolator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/scripts/interpolator.py (2)
- examples/scripts/osft_granite_example.py (1): main (100-232)
- examples/scripts/sft_granite_example.py (1): main (94-205)
🪛 Ruff (0.14.0)
examples/scripts/interpolator.py
30-30: Avoid specifying long messages outside the exception class
(TRY003)
45-45: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
examples/scripts/interpolator.py (3)
14-16: LGTM: Imports correctly address previous feedback.

The torch import and transformers components are now properly included for dtype handling and model loading.
29-30: LGTM: Weight validation is correct.

The validation ensures `trained_model_weight` is within [0, 1] and provides a clear error message if not.
79-135: LGTM: CLI interface is well-structured.

The argument parser provides clear help messages, appropriate defaults, and correct type handling. The `main` function cleanly orchestrates the argument parsing and model interpolation.
@Maxusmusti Thank you for the comments. I added the scripts to README.md and addressed all of the bot's suggestions. Please review.
@mtake Thanks again for the contribution! Everything looks good; there are just a few minor nits that might be worth addressing if you have time, but if not this can be merged as-is. I tested the scripts and everything seems to be working as expected.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/scripts/interpolator.py (1)
66-67: Add validation that model state_dicts have matching keys.

The interpolation loop assumes both models have identical keys. If the trained model checkpoint has a different architecture, is corrupted, or is missing parameters, accessing `trained_state_dict[key]` will raise a `KeyError`. This issue was flagged in a previous review but hasn't been fully addressed yet.

Add key validation before interpolation:

```diff
 trained_model = AutoModelForCausalLM.from_pretrained(
     trained_model_path,
     **model_kwargs,
 )
 trained_state_dict = trained_model.state_dict()
+
+# Validate that both models have matching keys
+base_keys = set(state_dict.keys())
+trained_keys = set(trained_state_dict.keys())
+if base_keys != trained_keys:
+    missing_in_trained = base_keys - trained_keys
+    extra_in_trained = trained_keys - base_keys
+    error_parts = []
+    if missing_in_trained:
+        error_parts.append(f"Missing in trained model: {list(missing_in_trained)[:5]}")
+    if extra_in_trained:
+        error_parts.append(f"Extra in trained model: {list(extra_in_trained)[:5]}")
+    raise ValueError(f"Model architectures do not match. {'; '.join(error_parts)}")
+
 for key in state_dict.keys():
     state_dict[key] += trained_state_dict[key] * trained_model_weight
```
🧹 Nitpick comments (1)
examples/scripts/osft_granite_example.py (1)
74-97: Consider extracting the duplicated checkpoint helper.

The `find_most_recent_checkpoint` function is identical to the one in `sft_granite_example.py` (lines 67-90). To reduce duplication and improve maintainability, consider extracting this helper to a shared utility module.

For example, create `examples/scripts/training_utils.py`:

```python
"""Common utilities for training example scripts."""

import os
import glob


def find_most_recent_checkpoint(output_dir, pattern="samples_*"):
    """Find the most recent checkpoint in the training output directory."""
    checkpoint_pattern = os.path.join(output_dir, "hf_format", pattern)
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")
    return max(checkpoint_dirs, key=os.path.getctime)
```

Then import and use it in both scripts:

```python
from training_utils import find_most_recent_checkpoint
```
📜 Review details
📒 Files selected for processing (3)
- examples/README.md (3 hunks)
- examples/scripts/interpolator.py (1 hunks)
- examples/scripts/osft_granite_example.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/scripts/osft_granite_example.py (3)
- src/training_hub/algorithms/osft.py (1): osft (448-513)
- examples/scripts/sft_granite_example.py (2): find_most_recent_checkpoint (68-91), main (94-205)
- examples/scripts/interpolator.py (2): main (119-133), interpolate_models (19-78)
examples/scripts/interpolator.py (2)
- examples/scripts/osft_granite_example.py (1): main (100-232)
- examples/scripts/sft_granite_example.py (1): main (94-205)
🪛 LanguageTool
examples/README.md
[grammar] ~110-~110: There might be a mistake here.
Context: ...ne with linear interpolation. Script: - [interpolator.py](scripts/interpolator.py...
(QB_NEW_EN)
🪛 Ruff (0.14.0)
examples/scripts/osft_granite_example.py
1-1: Shebang is present but file is not executable
(EXE001)
71-71: Probable insecure usage of temporary file or directory: "/dev/shm/data/"
(S108)
92-92: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Avoid specifying long messages outside the exception class
(TRY003)
159-159: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
221-221: Do not catch blind exception: Exception
(BLE001)
examples/scripts/interpolator.py
30-30: Avoid specifying long messages outside the exception class
(TRY003)
45-45: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (4)
examples/README.md (1)
30-30: LGTM! Documentation additions are clear and well-structured.

The new Granite 3.3 8B example entries follow the existing documentation pattern, and the Model Interpolation section provides both CLI and Python usage examples with generic paths, as requested in past reviews.
Also applies to: 64-64, 106-123
examples/scripts/interpolator.py (1)
19-78: Other past review feedback has been addressed.

The function now includes:
- Proper weight validation (lines 29-30)
- Correct torch dtype handling with string-to-dtype mapping (lines 32-48)
- Success message indicating where the merged model was saved (line 76)
examples/scripts/osft_granite_example.py (2)
212-219: Interpolation logic correctly handles edge cases.

The condition `0.0 < trained_model_weight < 1.0` properly excludes edge cases where:
- `trained_model_weight == 0.0` would produce the base model (no interpolation needed)
- `trained_model_weight == 1.0` would produce the trained model (no interpolation needed)

This is the correct behavior and avoids unnecessary interpolation overhead.
1-236: Overall script structure is solid.

The script provides:
- Comprehensive CLI with sensible defaults
- Clear configuration printing before training
- Robust error handling with helpful troubleshooting tips
- Optional post-training interpolation
- Proper integration with training_hub.osft

The unused `result` variable (line 159) flagged by static analysis is acceptable since `osft()` performs side effects (training and checkpointing) even if the return value isn't used.
Hi @Maxusmusti |
Maxusmusti left a comment:
looks good to me!
This PR adds training examples for the granite-3.3-8b-instruct model.
granite-3.3-8b-instruct is the student model for Japanese, and we have been using this code to train it.