Adding Unsloth for LoRA / QLoRA SFT Algorithm #23

base: main

Conversation
Walkthrough

Adds LoRA + SFT support: new Unsloth backend, LoRA algorithm and PEFT extender modules, example scripts and docs, an optional …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant API as lora_sft()
    participant Alg as LoRASFTAlgorithm
    participant Ext as LoRAPEFTExtender
    participant Backend as UnslothLoRABackend
    participant Unsloth
    participant Data as Dataset Prep
    participant Trainer as SFT Trainer
    User->>API: call(model_path, data_path, params)
    API->>Alg: instantiate(params)
    Alg->>Ext: apply PEFT defaults
    Alg->>Backend: train(params)
    Backend->>Unsloth: load model (opt: 4/8-bit, device_map)
    Unsloth-->>Backend: return model & tokenizer
    Backend->>Backend: apply LoRA config
    Backend->>Data: prepare dataset (messages / Alpaca mapping)
    Data-->>Backend: dataset ready
    Backend->>Trainer: build args & start training
    Trainer-->>Backend: training complete
    Backend->>Backend: save checkpoint & tokenizer
    Backend-->>Alg: return training result
    Alg-->>API: return artifact
    API-->>User: done
```
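To make the flow above concrete, a minimal call of the new convenience function (the model ID, data path, and hyperparameter values here are illustrative; they mirror the defaults documented in the `lora_sft` docstring further down):

```python
from training_hub import lora_sft

# Minimal LoRA + SFT run; paths and hyperparameters are illustrative only.
result = lora_sft(
    model_path="meta-llama/Llama-2-7b-hf",  # local path or HuggingFace model ID
    data_path="./training_data.jsonl",      # messages or Alpaca-style records
    ckpt_output_dir="./outputs",
    lora_r=16,
    lora_alpha=32,
    num_epochs=3,
    learning_rate=2e-4,
)
```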
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Signed-off-by: Mustafa Eyceoz <[email protected]>
Force-pushed from 379ec3a to 586b781
Actionable comments posted: 11
🧹 Nitpick comments (4)
examples/scripts/lora_example.py (1)

1-1: Consider making the script executable.
The shebang is present but the file is not marked executable. Consider running `chmod +x examples/scripts/lora_example.py` to match the shebang.

src/training_hub/algorithms/peft_extender.py (1)

106-111: Consider sorting `__all__` alphabetically.
While the current logical grouping (base class, implementation, functions) is reasonable, alphabetical sorting would align with common Python conventions and satisfy the linter.

```diff
 __all__ = [
-    'PEFTExtender',
     'LoRAPEFTExtender',
+    'PEFTExtender',
+    'apply_lora_defaults',
     'get_lora_parameters',
-    'apply_lora_defaults'
 ]
```
src/training_hub/algorithms/lora.py (2)

199-236: Effective batch size computation can overshoot the requested value
In `_build_training_args`, when `effective_batch_size` is provided you compute:

```python
micro_batch_size = effective_batch_size // (gradient_accumulation_steps * num_gpus)
micro_batch_size = max(1, micro_batch_size)
```

If `effective_batch_size < gradient_accumulation_steps * num_gpus`, this yields `micro_batch_size = 0 → 1`, so the actual effective batch becomes `gradient_accumulation_steps * num_gpus`, which is larger than the requested `effective_batch_size`. This may be acceptable, but it's worth either:

- Documenting that `effective_batch_size` is a lower bound, or
- Adjusting `gradient_accumulation_steps` (or warning) when the requested effective batch cannot be met exactly (a hedged warning sketch follows below).

```diff
-        if effective_batch_size is not None:
-            # Calculate micro_batch_size from effective_batch_size
-            micro_batch_size = effective_batch_size // (gradient_accumulation_steps * num_gpus)
-            micro_batch_size = max(1, micro_batch_size)
+        if effective_batch_size is not None:
+            # Calculate micro_batch_size from effective_batch_size
+            denom = gradient_accumulation_steps * num_gpus
+            micro_batch_size = max(1, effective_batch_size // max(1, denom))
+            if micro_batch_size * denom != effective_batch_size:
+                # Optional: log or warn about mismatch instead of silently overshooting.
+                pass
```
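A minimal sketch of the warning variant (the helper name and message wording are illustrative, not part of the patch):

```python
import warnings


def resolve_micro_batch_size(effective_batch_size: int,
                             gradient_accumulation_steps: int,
                             num_gpus: int) -> int:
    """Derive a per-GPU micro batch size, warning when the target cannot be met exactly."""
    denom = max(1, gradient_accumulation_steps * num_gpus)
    micro_batch_size = max(1, effective_batch_size // denom)
    actual = micro_batch_size * denom
    if actual != effective_batch_size:
        warnings.warn(
            f"effective_batch_size={effective_batch_size} is not divisible by "
            f"gradient_accumulation_steps * num_gpus ({denom}); "
            f"training will use an effective batch size of {actual} instead."
        )
    return micro_batch_size
```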
540-615: Optional param metadata omits `save_model`, which the backend reads
`UnslothLoRABackend.execute_training` checks `training_params.get('save_model', True)` before saving, but `LoRASFTAlgorithm.get_optional_params` does not list `save_model`. If callers rely on `get_optional_params()` to drive validation or CLI argument exposure, they won't discover that `save_model` is supported.

Consider adding `save_model: bool` to the optional parameter definitions.

```diff
 extended_params = {
     @@
     'enable_model_splitting': bool,
+    'save_model': bool,
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- README.md (3 hunks)
- examples/README.md (1 hunks)
- examples/docs/lora_usage.md (1 hunks)
- examples/scripts/lora_example.py (1 hunks)
- pyproject.toml (2 hunks)
- src/training_hub/__init__.py (1 hunks)
- src/training_hub/algorithms/lora.py (1 hunks)
- src/training_hub/algorithms/peft_extender.py (1 hunks)
- src/training_hub/profiling/memory_estimator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/training_hub/__init__.py (1)
- src/training_hub/algorithms/lora.py (3): lora_sft (623-823), LoRASFTAlgorithm (277-615), UnslothLoRABackend (11-272)

src/training_hub/algorithms/lora.py (3)
- src/training_hub/algorithms/__init__.py (6): Algorithm (6-22), Backend (24-30), AlgorithmRegistry (33-79), register_algorithm (40-44), register_backend (47-51), create_algorithm (82-96)
- src/training_hub/algorithms/sft.py (2): sft (169-248), SFTAlgorithm (40-161)
- src/training_hub/algorithms/peft_extender.py (7): LoRAPEFTExtender (26-90), get_lora_parameters (93-96), apply_lora_defaults (99-102), apply_peft_config (21-23), apply_peft_config (54-90), get_peft_params (16-18), get_peft_params (29-52)

examples/scripts/lora_example.py (1)
- src/training_hub/algorithms/lora.py (1): lora_sft (623-823)
🪛 LanguageTool
examples/docs/lora_usage.md
[style] ~186-~186: As an alternative to the over-used intensifier ‘very’, consider replacing this phrase.
Context: "...proc_per_node * micro_batch_size - For very large models, try enable_model_splitting=Tru..."
(EN_WEAK_ADJECTIVE)
🪛 Ruff (0.14.4)
src/training_hub/algorithms/peft_extender.py
106-111: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
src/training_hub/algorithms/lora.py
24-27: Avoid specifying long messages outside the exception class
(TRY003)
29-32: Avoid specifying long messages outside the exception class
(TRY003)
34-37: Avoid specifying long messages outside the exception class
(TRY003)
43-43: Local variable torchrun_params is assigned to but never used
Remove assignment to unused variable torchrun_params
(F841)
64-64: Local variable dataset_type is assigned to but never used
Remove assignment to unused variable dataset_type
(F841)
96-96: Local variable load_in_8bit is assigned to but never used
Remove assignment to unused variable load_in_8bit
(F841)
100-100: Redefinition of unused os from line 1
Remove definition: os
(F811)
208-208: Redefinition of unused os from line 1
Remove definition: os
(F811)
examples/scripts/lora_example.py
1-1: Shebang is present but file is not executable
(EXE001)
45-45: Do not catch blind exception: Exception
(BLE001)
89-89: Do not catch blind exception: Exception
(BLE001)
102-102: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?
(RUF001)
105-105: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
136-136: Do not catch blind exception: Exception
(BLE001)
221-221: Do not catch blind exception: Exception
(BLE001)
240-240: Do not catch blind exception: Exception
(BLE001)
246-246: Local variable result is assigned to but never used
Remove assignment to unused variable result
(F841)
263-263: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (20)
src/training_hub/profiling/memory_estimator.py (1)
1-4: LGTM! Good compatibility improvement.
The try/except fallback pattern correctly handles Python versions before 3.12, where `typing.override` was introduced.

examples/README.md (1)

84-112: LGTM! Documentation is clear and comprehensive.
The new LoRA + SFT section is well-structured and highlights the important difference that multi-GPU training requires `torchrun` (unlike SFT/OSFT). The quick example demonstrates the key parameters effectively.

README.md (3)

10-10: LGTM! Support matrix accurately reflects the new LoRA + SFT capability.

66-88: LGTM! Clear and informative documentation.
The LoRA + SFT section effectively communicates the key features and provides a practical code example with appropriate parameters.

111-120: LGTM! Installation instructions are clear.
The LoRA installation section provides straightforward guidance and helpful context about the included dependencies.

examples/docs/lora_usage.md (5)

1-41: LGTM! Excellent quick start documentation.
The introduction clearly explains LoRA benefits and provides a complete working example. The explicit differentiation between single-GPU (Python) and multi-GPU (torchrun) launch commands is very helpful.

54-75: LGTM! Parameter documentation is thorough and practical.
The core LoRA settings and QLoRA example provide clear guidance on parameter selection, with helpful context on when to use different values.

77-102: LGTM! Dataset format examples are clear and accurate.
The JSON examples for both Messages and Alpaca formats are correct and well-presented. The memory benefits section appropriately notes that savings depend on configuration.

141-166: LGTM! Advanced configuration examples are helpful.
The custom target modules and W&B integration examples provide practical guidance for common customization needs.

176-190: No issues found: the parameter is valid and properly documented.
The `enable_model_splitting` parameter is a confirmed, supported option in the LoRA implementation. It appears in the function signatures with documentation describing it as enabling `device_map="balanced"` for large models (default: False). The troubleshooting guidance on line 186 is accurate and correctly advises using this parameter for very large models.

src/training_hub/__init__.py (1)

8-28: LGTM! Public API exports are clean and consistent.
The additions to `__all__` properly expose the new LoRA functionality and follow the existing pattern used for SFT and OSFT algorithms.

pyproject.toml (1)

11-11: LGTM! License format simplification.
The change from dictionary to string format is valid and follows the more common convention.

examples/scripts/lora_example.py (5)

11-46: LGTM! Basic example is clear and practical.
The function provides a good starting point with a small model and reasonable default parameters. The broad exception handling is acceptable for example code.

51-90: LGTM! QLoRA example demonstrates appropriate parameter adjustments.
The function correctly shows how to configure LoRA for 4-bit quantization with higher rank and lower learning rate, which are sensible defaults for quantized training.

93-137: LGTM! Distributed training example is well-documented.
The function clearly explains the torchrun requirement and demonstrates appropriate parameters for multi-GPU training. The unused `result` variable is acceptable in example code.

140-199: LGTM! Sample data creation is helpful and well-structured.
The function creates appropriate sample datasets in both Messages and Alpaca formats, providing useful test data for users trying the examples.

202-299: LGTM! Comprehensive demonstration of dataset formats and usage.
The `data_format_examples` function clearly shows how to work with different dataset formats, and the main block provides excellent usage instructions and benefits summary. The unused variables and broad exception handling are acceptable for example code.

src/training_hub/algorithms/peft_extender.py (3)

12-23: LGTM! Clean abstraction design.
The `PEFTExtender` ABC provides a clear interface for adding PEFT functionality to algorithms. The two abstract methods establish a good contract for parameter discovery and configuration.

26-90: LGTM! Comprehensive LoRA implementation.
The `LoRAPEFTExtender` provides thorough coverage of LoRA parameters, including quantization and advanced options. The default values are sensible (especially dropout=0.0 for Unsloth optimization), and the metadata tagging (peft_enabled, peft_type) is helpful for downstream processing.

93-102: LGTM! Useful convenience functions.
The wrapper functions provide a clean API for common operations without requiring users to instantiate the extender directly.
```python
    def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any:
        """Prepare dataset for training."""
        from datasets import load_dataset

        # Load dataset
        if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'):
            dataset = load_dataset('json', data_files=params['data_path'], split='train')
        else:
            dataset = load_dataset(params['data_path'], split='train')

        # Handle different dataset formats
        dataset_type = params.get('dataset_type', 'chat_template')

        if dataset_type == 'chat_template':
            # Convert messages format using chat template
            def format_chat_template(examples):
                # examples['messages'] is a list of conversations (batched)
                texts = []
                for conversation in examples['messages']:
                    text = tokenizer.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_chat_template, batched=True)

        elif dataset_type == 'alpaca':
            # Convert alpaca format to text
            def format_alpaca(examples):
                texts = []
                for i in range(len(examples['instruction'])):
                    instruction = examples['instruction'][i]
                    input_text = examples.get('input', [''] * len(examples['instruction']))[i]
                    output = examples['output'][i]

                    if input_text:
                        text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                    else:
                        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_alpaca, batched=True)

        return dataset
```
Dataset field customisation is ignored; only hard-coded columns are used

`LoRASFTAlgorithm.train` exposes `dataset_type`, `field_messages`, `field_instruction`, `field_input`, and `field_output`, but `_prepare_dataset`:

- Always assumes `messages` for chat data and `instruction`/`input`/`output` for Alpaca, ignoring the `field_*` parameters.
- Only implements two dataset types (`chat_template` and `alpaca`) even though docstrings mention other formats like `input_output`.
- Recomputes `dataset_type` here, while the earlier `dataset_type` variable in `execute_training` is set but never used.

This will break for datasets that use different column names and makes the `field_*` parameters misleading.

I'd suggest:

- Respecting the `field_*` names from `params` with sensible defaults, and
- Explicitly validating/handling only the dataset types you actually support.

```diff
-        # Handle different dataset formats
-        dataset_type = params.get('dataset_type', 'chat_template')
+        # Handle different dataset formats
+        dataset_type = params.get('dataset_type', 'chat_template')
+        field_messages = params.get('field_messages', 'messages')
+        field_instruction = params.get('field_instruction', 'instruction')
+        field_input = params.get('field_input', 'input')
+        field_output = params.get('field_output', 'output')
@@
-        if dataset_type == 'chat_template':
+        if dataset_type == 'chat_template':
@@
-                for conversation in examples['messages']:
+                for conversation in examples[field_messages]:
@@
-        elif dataset_type == 'alpaca':
+        elif dataset_type == 'alpaca':
@@
-                for i in range(len(examples['instruction'])):
-                    instruction = examples['instruction'][i]
-                    input_text = examples.get('input', [''] * len(examples['instruction']))[i]
-                    output = examples['output'][i]
+                n = len(examples[field_instruction])
+                for i in range(n):
+                    instruction = examples[field_instruction][i]
+                    input_text = examples.get(field_input, [''] * n)[i]
+                    output = examples[field_output][i]
+                    ...
```

You might also want to raise a clear error for unsupported `dataset_type` values instead of silently returning the raw dataset (a minimal sketch follows below).
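A minimal sketch of such a guard (the supported set reflects only the two branches `_prepare_dataset` currently implements):

```python
# Hypothetical guard; not part of the current patch.
SUPPORTED_DATASET_TYPES = {"chat_template", "alpaca"}


def validate_dataset_type(dataset_type: str) -> None:
    """Fail fast instead of silently returning an unformatted dataset."""
    if dataset_type not in SUPPORTED_DATASET_TYPES:
        raise ValueError(
            f"Unsupported dataset_type {dataset_type!r}; "
            f"expected one of {sorted(SUPPORTED_DATASET_TYPES)}"
        )
```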
📝 Committable suggestion
```python
    def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any:
        """Prepare dataset for training."""
        from datasets import load_dataset

        # Load dataset
        if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'):
            dataset = load_dataset('json', data_files=params['data_path'], split='train')
        else:
            dataset = load_dataset(params['data_path'], split='train')

        # Handle different dataset formats
        dataset_type = params.get('dataset_type', 'chat_template')
        field_messages = params.get('field_messages', 'messages')
        field_instruction = params.get('field_instruction', 'instruction')
        field_input = params.get('field_input', 'input')
        field_output = params.get('field_output', 'output')

        if dataset_type == 'chat_template':
            # Convert messages format using chat template
            def format_chat_template(examples):
                # examples[field_messages] is a list of conversations (batched)
                texts = []
                for conversation in examples[field_messages]:
                    text = tokenizer.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_chat_template, batched=True)

        elif dataset_type == 'alpaca':
            # Convert alpaca format to text
            def format_alpaca(examples):
                texts = []
                n = len(examples[field_instruction])
                for i in range(n):
                    instruction = examples[field_instruction][i]
                    input_text = examples.get(field_input, [''] * n)[i]
                    output = examples[field_output][i]

                    if input_text:
                        text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                    else:
                        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_alpaca, batched=True)

        return dataset
```
🤖 Prompt for AI Agents
In src/training_hub/algorithms/lora.py around lines 149-196, _prepare_dataset
currently hardcodes column names and dataset types (expects 'messages' for chat
and 'instruction'/'input'/'output' for alpaca), ignores
params['field_messages']/['field_instruction']/['field_input']/['field_output'],
and recomputes dataset_type instead of using the value passed in params; update
the function to read dataset_type = params.get('dataset_type', 'chat_template')
once and then select behavior accordingly, use field names from params with
sensible defaults (e.g., field_messages='messages',
field_instruction='instruction', field_input='input', field_output='output')
when extracting columns inside formatters, explicitly handle only supported
dataset_type values (chat_template, alpaca, input_output, etc.) and raise a
clear ValueError for unsupported types, and return the mapped dataset; also
ensure batching logic uses the provided field names and preserves existing
tokenize/add_generation_prompt behavior for chat_template.
```python
            # LoRA parameters (from PEFT extender)
            'lora_r': lora_r,
            'lora_alpha': lora_alpha,
            'lora_dropout': lora_dropout,
            'target_modules': target_modules,
            'use_rslora': use_rslora,
            'use_dora': use_dora,
            'init_lora_weights': init_lora_weights,
            'rank_pattern': rank_pattern,
            'alpha_pattern': alpha_pattern,
            'loftq_config': loftq_config,
            # Quantization parameters (QLoRA)
            'load_in_4bit': load_in_4bit,
            'load_in_8bit': load_in_8bit,
            'bnb_4bit_quant_type': bnb_4bit_quant_type,
            'bnb_4bit_compute_dtype': bnb_4bit_compute_dtype,
            'bnb_4bit_use_double_quant': bnb_4bit_use_double_quant,
            # Extended training parameters
```
🧩 Analysis chain
Several advertised knobs are currently unused in this backend
The algorithm and PEFT extender expose a broad set of knobs:
- QLoRA-related: `bnb_4bit_quant_type`, `bnb_4bit_compute_dtype`, `bnb_4bit_use_double_quant`
- Training/meta: `early_stopping_patience`, `wandb_watch`
- Multi-GPU: `enable_model_splitting` (used), plus advanced LoRA options discussed earlier.

However, in the Unsloth backend:

- The `bnb_4bit_*` fields are never read.
- `early_stopping_patience` is never converted into a callback or trainer argument.
- `wandb_watch` is never used with W&B APIs.
- `max_tokens_per_gpu` is accepted and surfaced via `get_optional_params` but not used in batch sizing or truncation logic.
This is not strictly a bug, but the API surface suggests these parameters have an effect when they currently don’t.
I’d recommend either wiring them into the backend (where supported by TRL/Unsloth/W&B) or explicitly documenting them as “reserved/unimplemented” for now to reduce surprises.
Also applies to: 576-607
🌐 Web query:
Check which of the following SFT/TRL/Unsloth features have first-class support and the expected integration points: QLoRA `bnb_4bit_*` options, `max_tokens_per_gpu`-style sequence length limits, early stopping callbacks, and W&B `wandb.watch` integration.
💡 Result:
Summary — first-class support / integration points
- QLoRA bnb_4bit_* options (bnb_4bit_quant_type, bnb_4bit_use_double_quant, bnb_4bit_compute_dtype, etc.): First‑class support via BitsAndBytesConfig used by Transformers/QLoRA workflows (used in AutoModelForCausalLM.from_pretrained and many examples). Use BitsAndBytesConfig / quantization_config to pass bnb_4bit_* flags. [1][2]
- max_tokens_per_gpu‑style per‑GPU token limits: Supported by several training stacks (Megatron-style / Slime / related training frameworks) as a first‑class runtime parameter for dynamic/batched packing and OOM control (often used with --use-dynamic-batch-size). Not a single HF‑Transformers standard name, but widely supported in production training frameworks. Cite Slime / examples that document max_tokens_per_gpu semantics. [3][4]
- Early stopping callbacks: First‑class support via Transformers callbacks (e.g., transformers.EarlyStoppingCallback) and TRL's SFTTrainer accepts standard callbacks — you can pass EarlyStoppingCallback to SFTTrainer. [5]
- Weights & Biases wandb.watch integration: Supported — Hugging Face provides WandbCallback/W&B integration (Trainer integrations and env vars like WANDB_WATCH); many SFT/TRL workflows also call wandb.init / wandb.watch in trainer setup/examples. Use Transformers' WandbCallback or call wandb.init/wandb.watch yourself inside trainer code. [6][7]
Sources
- QLoRA / BitsAndBytes examples and BitsAndBytesConfig usage. [1]
- vLLM / BitsAndBytesConfig source and QLoRA examples showing bnb_4bit_* options. [2]
- Slime / Megatron-style docs describing max_tokens_per_gpu behavior for dynamic batching and OOM control. [3]
- Training-hub / examples showing max_tokens_per_gpu in training configs. [4]
- TRL SFTTrainer example using transformers.EarlyStoppingCallback. [5]
- Hugging Face transformers callbacks / WandbCallback documentation (WANDB_WATCH etc.). [6]
- Example trainers calling wandb.watch (project examples / community trainers). [7]
If you want, I can open the specific repo/docs you use (SFT/TRL/Unsloth code) and verify the exact API names and where to pass each option.
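To ground the early-stopping and W&B points above, a hedged sketch of how those hooks are typically attached around TRL's `SFTTrainer` (the `build_trainer` helper and the config values are assumptions for illustration; exact `SFTConfig` argument names should be checked against the installed TRL/Transformers versions):

```python
import os

from transformers import EarlyStoppingCallback
from trl import SFTConfig, SFTTrainer


def build_trainer(model, train_dataset, eval_dataset, patience: int = 3) -> SFTTrainer:
    """Attach early stopping and W&B gradient watching to an SFT run."""
    # The HF WandbCallback reads WANDB_WATCH ("gradients", "all", "false").
    os.environ.setdefault("WANDB_WATCH", "gradients")

    config = SFTConfig(
        output_dir="./outputs",
        eval_strategy="steps",        # "evaluation_strategy" on older Transformers
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        load_best_model_at_end=True,  # required for EarlyStoppingCallback
        metric_for_best_model="eval_loss",
        report_to="wandb",
    )
    return SFTTrainer(
        model=model,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],
    )
```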
I need to inspect the actual backend code to verify whether these parameters are truly unused. Let me examine the file and trace parameter flow.
🏁 Script executed:
#!/bin/bash
# Inspect the target lines and surrounding context
echo "=== Lines 470-487 context ==="
sed -n '460,500p' src/training_hub/algorithms/lora.py
echo -e "\n=== Lines 576-607 context ==="
sed -n '570,620p' src/training_hub/algorithms/lora.py

Length of output: 3553
🏁 Script executed:
#!/bin/bash
# Search for where parameters are initialized/defined
echo "=== Parameter definitions ==="
rg -n "early_stopping_patience|wandb_watch|bnb_4bit_quant_type|bnb_4bit_compute_dtype|bnb_4bit_use_double_quant|max_tokens_per_gpu" src/training_hub/algorithms/lora.py -C 2Length of output: 5711
🏁 Script executed:
#!/bin/bash
# Search for Unsloth backend implementation
echo "=== Unsloth backend references ==="
rg -n "unsloth|Unsloth" src/training_hub/algorithms/lora.py -C 3Length of output: 4907
🏁 Script executed:
#!/bin/bash
# Search for SFTTrainer initialization and parameter passing
echo "=== SFTTrainer setup ==="
rg -n "SFTTrainer|from_pretrained" src/training_hub/algorithms/lora.py -C 5Length of output: 1907
Let me trace where these specific parameters are actually consumed in the backend:
---
🏁 Script executed:
```shell
#!/bin/bash
# Examine _build_training_args full implementation
echo "=== _build_training_args implementation ==="
sed -n '199,280p' src/training_hub/algorithms/lora.py
```

Length of output: 3636
🏁 Script executed:
#!/bin/bash
# Examine _apply_lora_config full implementation
echo "=== _apply_lora_config implementation ==="
sed -n '125,180p' src/training_hub/algorithms/lora.py

Length of output: 2376
🏁 Script executed:
#!/bin/bash
# Search for BitsAndBytesConfig usage
echo "=== BitsAndBytesConfig usage ==="
rg -n "BitsAndBytesConfig" src/training_hub/algorithms/lora.py -C 3Length of output: 32
🏁 Script executed:
#!/bin/bash
# Search for callback creation/EarlyStoppingCallback/wandb callbacks
echo "=== Callback usage ==="
rg -n "Callback|callback|wandb|WandbCallback" src/training_hub/algorithms/lora.py -C 2Length of output: 2846
🏁 Script executed:
#!/bin/bash
# Check what happens with max_tokens_per_gpu, early_stopping_patience, wandb_watch
echo "=== Parameter consumption trace ==="
rg -n "max_tokens_per_gpu|get\('early_stopping_patience|get\('wandb_watch" src/training_hub/algorithms/lora.pyLength of output: 271
Wire unused parameters into backend or explicitly document as reserved
Verification confirms the review comment is accurate. The backend accepts these parameters but does not use them:
- `bnb_4bit_quant_type`, `bnb_4bit_compute_dtype`, `bnb_4bit_use_double_quant`: Defined and passed to `algorithm_params`, but never read. Only `load_in_4bit` is used in `_load_unsloth_model` (line 115); no `BitsAndBytesConfig` is created.
- `early_stopping_patience`: Defined and passed, but never converted to a callback or passed to `SFTConfig`. No `EarlyStoppingCallback` instantiation anywhere.
- `wandb_watch`: Defined and passed, but only `wandb_project` and `wandb_run_name` are used in `SFTConfig` (lines 268-269). No `wandb.watch()` call or `WandbCallback` with watch parameter.
- `max_tokens_per_gpu`: Defined and passed (line 464), but never read or used in batch sizing/truncation logic.
Additionally, in _apply_lora_config, use_rslora (line 143) and loftq_config (line 144) are hardcoded instead of reading from params, suggesting a broader pattern of parameter passthrough without integration.
These parameters do have first-class support in the underlying libraries (Transformers/TRL/Unsloth/W&B). Either integrate them into the backend (as suggested in the review) or explicitly mark them as reserved/unimplemented to set user expectations correctly.
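For the quantization knobs specifically, a minimal sketch of how they could be translated into a `BitsAndBytesConfig` (the helper name and defaults are assumptions; whether Unsloth's `FastLanguageModel.from_pretrained` accepts a `quantization_config` pass-through the same way plain Transformers loaders do still needs to be verified):

```python
from typing import Optional

import torch
from transformers import BitsAndBytesConfig


def build_bnb_config(params: dict) -> Optional[BitsAndBytesConfig]:
    """Translate the advertised bnb_4bit_* knobs into a BitsAndBytesConfig."""
    if not params.get("load_in_4bit", False):
        return None  # no 4-bit quantization requested
    compute_dtype = getattr(torch, params.get("bnb_4bit_compute_dtype", "bfloat16"))
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=params.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=params.get("bnb_4bit_use_double_quant", True),
    )

# Plain Transformers loaders accept the result as `quantization_config=...`;
# the Unsloth loading path would need the equivalent wiring.
```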
```python
def lora_sft(model_path: str,
             data_path: str,
             ckpt_output_dir: str,
             backend: str = "unsloth",
             # LoRA-specific parameters
             lora_r: Optional[int] = None,
             lora_alpha: Optional[int] = None,
             lora_dropout: Optional[float] = None,
             target_modules: Optional[List[str]] = None,
             # Training parameters
             num_epochs: Optional[int] = None,
             effective_batch_size: Optional[int] = None,
             micro_batch_size: Optional[int] = None,
             gradient_accumulation_steps: Optional[int] = None,
             learning_rate: Optional[float] = None,
             max_seq_len: Optional[int] = None,
             lr_scheduler: Optional[str] = None,
             warmup_steps: Optional[int] = None,
             # Quantization parameters
             load_in_4bit: Optional[bool] = None,
             load_in_8bit: Optional[bool] = None,
             bnb_4bit_quant_type: Optional[str] = None,
             bnb_4bit_compute_dtype: Optional[str] = None,
             bnb_4bit_use_double_quant: Optional[bool] = None,
             # Optimization parameters
             flash_attention: Optional[bool] = None,
             sample_packing: Optional[bool] = None,
             bf16: Optional[bool] = None,
             fp16: Optional[bool] = None,
             tf32: Optional[bool] = None,
             # Saving and logging
             save_steps: Optional[int] = None,
             eval_steps: Optional[int] = None,
             logging_steps: Optional[int] = None,
             save_total_limit: Optional[int] = None,
             # Weights & Biases
             wandb_project: Optional[str] = None,
             wandb_entity: Optional[str] = None,
             wandb_watch: Optional[str] = None,
             # Early stopping
             early_stopping_patience: Optional[int] = None,
             # Dataset format parameters
             dataset_type: Optional[str] = None,
             field_messages: Optional[str] = None,
             field_instruction: Optional[str] = None,
             field_input: Optional[str] = None,
             field_output: Optional[str] = None,
             # Distributed training parameters
             nproc_per_node: Optional[Union[str, int]] = None,
             nnodes: Optional[int] = None,
             node_rank: Optional[int] = None,
             rdzv_id: Optional[Union[str, int]] = None,
             rdzv_endpoint: Optional[str] = None,
             master_addr: Optional[str] = None,
             master_port: Optional[int] = None,
             # Multi-GPU model splitting
             enable_model_splitting: Optional[bool] = None,
             **kwargs) -> Any:
    """Convenience function to run LoRA + SFT training.

    Args:
        model_path: Path to the model to fine-tune (local path or HuggingFace model ID)
        data_path: Path to the training data (JSON/JSONL format)
        ckpt_output_dir: Directory to save checkpoints and outputs
        backend: Backend implementation to use (default: "unsloth")

    LoRA Parameters:
        lora_r: LoRA rank (default: 16)
        lora_alpha: LoRA alpha parameter (default: 32)
        lora_dropout: LoRA dropout rate (default: 0.1)
        target_modules: List of module names to apply LoRA to (default: auto-detect)

    Training Parameters:
        num_epochs: Number of training epochs (default: 3)
        effective_batch_size: Effective batch size across all GPUs
        micro_batch_size: Batch size per GPU (default: 1)
        gradient_accumulation_steps: Steps to accumulate gradients (default: 1)
        learning_rate: Learning rate (default: 2e-4)
        max_seq_len: Maximum sequence length (default: 2048)
        lr_scheduler: Learning rate scheduler (default: 'cosine')
        warmup_steps: Number of warmup steps (default: 10)

    Quantization Parameters (QLoRA):
        load_in_4bit: Use 4-bit quantization for QLoRA
        load_in_8bit: Use 8-bit quantization
        bnb_4bit_quant_type: 4-bit quantization type (default: 'nf4')
        bnb_4bit_compute_dtype: Compute dtype for 4-bit (default: 'bfloat16')
        bnb_4bit_use_double_quant: Use double quantization (default: True)

    Optimization Parameters:
        flash_attention: Use Flash Attention for memory efficiency (default: True)
        sample_packing: Pack multiple samples per sequence (default: True)
        bf16: Use bfloat16 precision (default: True)
        fp16: Use float16 precision (default: False)
        tf32: Use TensorFloat-32 (default: True)

    Logging and Saving:
        save_steps: Steps between checkpoints (default: 500)
        eval_steps: Steps between evaluations (default: 500)
        logging_steps: Steps between log outputs (default: 10)
        save_total_limit: Maximum number of checkpoints to keep (default: 3)
        wandb_project: Weights & Biases project name
        wandb_entity: Weights & Biases entity name
        wandb_watch: What to watch in W&B ('gradients', 'all', etc.)
        early_stopping_patience: Early stopping patience (epochs)

    Distributed Training:
        nproc_per_node: Number of processes (GPUs) per node
        nnodes: Total number of nodes
        node_rank: Rank of this node (0 to nnodes-1)
        rdzv_id: Unique job ID for rendezvous
        rdzv_endpoint: Master node endpoint for multi-node training
        master_addr: Master node address for distributed training
        master_port: Master node port for distributed training

    Multi-GPU Configuration:
        enable_model_splitting: Enable device_map="balanced" for large models (default: False)
            Use for models that don't fit on a single GPU (e.g., Llama 70B)
            For smaller models, use standard DDP with torchrun instead

    Advanced:
        **kwargs: Additional parameters passed to the backend

    Returns:
        Dictionary containing trained model, tokenizer, and trainer

    Example:
        # Basic LoRA training
        result = lora(
            model_path="microsoft/DialoGPT-medium",
            data_path="./training_data.jsonl",
            ckpt_output_dir="./outputs",
            lora_r=16,
            lora_alpha=32,
            num_epochs=3,
            learning_rate=2e-4
        )

        # QLoRA with 4-bit quantization
        result = lora(
            model_path="meta-llama/Llama-2-7b-hf",
            data_path="./training_data.jsonl",
            ckpt_output_dir="./outputs",
            load_in_4bit=True,
            lora_r=64,
            lora_alpha=128,
            max_seq_len=4096
        )
    """
```
lora_sft docstring example uses the wrong function name and misses some parameters
In the lora_sft docstring:
- The examples call `lora(...)` instead of `lora_sft(...)`, which will fail if copied verbatim.
- Some knobs available on `LoRASFTAlgorithm.train` (e.g., `weight_decay`, `max_grad_norm`, `wandb_run_name`, LoRA advanced options) are only accessible via `**kwargs` and not reflected in the documented signature.
I’d at least fix the function name in the examples, and ideally add a short note that additional LoRA/PEFT and training options can be passed via **kwargs to keep the docs honest.
```diff
-    # Basic LoRA training
-    result = lora(
+    # Basic LoRA+SFT training
+    result = lora_sft(
@@
-    # QLoRA with 4-bit quantization
-    result = lora(
+    # QLoRA with 4-bit quantization
+    result = lora_sft(
```
🤖 Prompt for AI Agents
In src/training_hub/algorithms/lora.py around lines 623 to 771, update the
docstring examples to call lora_sft(...) instead of lora(...), and add one short
sentence stating that additional LoRA/PEFT and training options (for example:
weight_decay, max_grad_norm, wandb_run_name, and advanced LoRA/PEFT knobs) can
be supplied via **kwargs to LoRASFTAlgorithm.train so callers know those
parameters exist even though they aren’t explicit in the signature.
RobotSail left a comment
Partial review
```
@@ -1,4 +1,7 @@
from typing import override
try:
```
Would you mind indicating what the compatibility is targeting? This'll make it easier to remove when 3.11 is eventually dropped.
```suggestion
# python3.11 compatibility, `override` was added to typing in 3.12
try:
```
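For reference, the fallback pattern under discussion looks roughly like this (a sketch; the PR's actual fallback body may differ, e.g., it could import from `typing_extensions` instead of defining a no-op):

```python
# Compatibility shim: typing.override exists only on Python 3.12+,
# so fall back to a no-op decorator on older interpreters such as 3.11.
try:
    from typing import override  # Python >= 3.12
except ImportError:
    def override(func):
        # Stand-in that returns the decorated function unchanged.
        return func
```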
    lora_alpha=32
)
Does wandb need to be installed separately or does it come with one of the new dependencies? Iirc, wandb currently isn't installed automatically and usually we instead recommend users to install it separately.
    sample_packing=True,

    # Weights & Biases logging
    wandb_project="lora_training",
I haven't checked the installation yet, but if wandb isn't automatically installed through one of the new deps like unsloth or xformers, then we'd need to make sure that users install it separately.
| print(f"❌ Custom field mapping training failed: {e}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": |
@Maxusmusti Would it make sense to make this script more interactive using something like typer or argparse? I see that a few of the functions here are commented out, and I'm guessing most of the time people will just run the samples without inspecting the source. It might help users see the full range of functionality if users can toggle these demos via python lora_example.py --qlora.
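A minimal argparse-based sketch of that idea (flag names like `--qlora` and `--distributed` are illustrative; the dispatch assumes it lives inside `lora_example.py`, where the example functions are defined):

```python
import argparse


def main() -> None:
    # Hypothetical CLI wrapper around the example functions in lora_example.py.
    parser = argparse.ArgumentParser(description="training-hub LoRA/QLoRA examples")
    parser.add_argument("--qlora", action="store_true", help="run the 4-bit QLoRA example")
    parser.add_argument("--distributed", action="store_true",
                        help="run the multi-GPU example (launch with torchrun)")
    args = parser.parse_args()

    if args.qlora:
        qlora_4bit_example()
    elif args.distributed:
        distributed_lora_sft_example()
    else:
        basic_lora_example()  # hypothetical name for the default demo


if __name__ == "__main__":
    main()
```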
    # qlora_4bit_example()

    # Uncomment to test distributed training (launch with torchrun for actual multi-GPU)
    # distributed_lora_sft_example()
If we make this script interactive, this could have a dedicated codepath, and we could recommend users to run it in the example guide via torchrun lora_example.py --distributed, blocking execution if we don't detect a torchrun environment.
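One way to implement that guard (a sketch; the `--distributed` flag follows the suggestion above, and the environment-variable check is a heuristic rather than an official API):

```python
import os
import sys


def require_torchrun() -> None:
    """Abort early when the script was not launched via torchrun.

    torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for every worker process,
    so their absence is a reasonable sign the launcher was plain `python`.
    """
    required = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    if not all(var in os.environ for var in required):
        sys.exit(
            "This demo expects a torchrun environment. Launch it with: "
            "torchrun --nproc-per-node=<gpus> lora_example.py --distributed"
        )
```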
```bash
pip install training-hub[cuda]
# or for development
pip install -e .[cuda]
```
@Maxusmusti Would you mind adding this flag to the README? I think in most cases this flag will need to be present, unless the users get lucky with the prebuilt wheels.
```diff
-pip install -e .[cuda]
+pip install -e .[cuda] --no-build-isolation
```
I can't suggest it, but the same change needs to be made on line 106.
Signed-off-by: Mustafa Eyceoz <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/training_hub/__init__.py (1)
4-4: LGTM! The critical import issue has been resolved.
The non-existent `AxolotlLoRABackend` import has been removed, fixing the `ImportError` reported in the previous review. The current imports (`lora_sft`, `LoRASFTAlgorithm`, `UnslothLoRABackend`) match the implementations in `src/training_hub/algorithms/lora.py`, and the `__all__` list correctly exposes these symbols.
Also applies to: 15-15, 20-21
src/training_hub/algorithms/lora.py (5)
129-151: Advanced LoRA parameters are still hardcoded and not applied. The
`use_rslora` (line 147) and `loftq_config` (line 148) parameters are hardcoded, and `use_dora`, `rank_pattern`, `alpha_pattern`, and `init_lora_weights` are not passed at all. These parameters are exposed in `LoRASFTAlgorithm.train` (lines 308-313) and advertised in the public API, but they have no effect. As confirmed in previous reviews, Unsloth's
`FastLanguageModel.get_peft_model` does accept these parameters. Please pass them through from `params`:

```diff
 model = FastLanguageModel.get_peft_model(
     model,
     r=params.get('lora_r', 16),
     target_modules=target_modules,
     lora_alpha=params.get('lora_alpha', 32),
     lora_dropout=params.get('lora_dropout', 0.0),
     bias="none",
     use_gradient_checkpointing="unsloth",
     random_state=params.get('seed', 3407),
-    use_rslora=False,
-    loftq_config=None,
+    use_rslora=params.get('use_rslora', False),
+    use_dora=params.get('use_dora', False),
+    loftq_config=params.get('loftq_config'),
+    rank_pattern=params.get('rank_pattern'),
+    alpha_pattern=params.get('alpha_pattern'),
+    init_lora_weights=params.get('init_lora_weights', True),
 )
```
153-200: Dataset field names are hardcoded; customization parameters are ignored. The
field_messages,field_instruction,field_input, andfield_outputparameters (exposed at lines 346-349 and documented at lines 428-431) are accepted but never used. The code hardcodes'messages'(line 171),'instruction'(lines 186-187),'input'(line 188), and'output'(line 189), breaking datasets with different column names.Read the field names from
paramswith sensible defaults:def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any: """Prepare dataset for training.""" from datasets import load_dataset # Load dataset if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'): dataset = load_dataset('json', data_files=params['data_path'], split='train') else: dataset = load_dataset(params['data_path'], split='train') # Handle different dataset formats dataset_type = params.get('dataset_type', 'chat_template') + field_messages = params.get('field_messages', 'messages') + field_instruction = params.get('field_instruction', 'instruction') + field_input = params.get('field_input', 'input') + field_output = params.get('field_output', 'output') if dataset_type == 'chat_template': # Convert messages format using chat template def format_chat_template(examples): - # examples['messages'] is a list of conversations (batched) + # examples[field_messages] is a list of conversations (batched) texts = [] - for conversation in examples['messages']: + for conversation in examples[field_messages]: text = tokenizer.apply_chat_template( conversation, tokenize=False, add_generation_prompt=False ) texts.append(text) return {"text": texts} dataset = dataset.map(format_chat_template, batched=True) elif dataset_type == 'alpaca': # Convert alpaca format to text def format_alpaca(examples): texts = [] - for i in range(len(examples['instruction'])): - instruction = examples['instruction'][i] - input_text = examples.get('input', [''] * len(examples['instruction']))[i] - output = examples['output'][i] + n = len(examples[field_instruction]) + for i in range(n): + instruction = examples[field_instruction][i] + input_text = examples.get(field_input, [''] * n)[i] + output = examples[field_output][i] if input_text: text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}" else: text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}" texts.append(text) return {"text": texts} dataset = dataset.map(format_alpaca, batched=True) return dataset
203-275:wandb_watchandearly_stopping_patienceparameters are accepted but not used.While
wandb_projectandwandb_run_nameare wired through (lines 271-272), thewandb_watchparameter (defined at line 340, documented at line 423) is accepted but never applied. Similarly,early_stopping_patience(line 343, documented at line 424) is never converted into anEarlyStoppingCallback.Both features have first-class support in the underlying libraries (TRL accepts
EarlyStoppingCallback, W&B supportswandb.watchviaWandbCallback), but the current implementation silently ignores them, potentially surprising users.Consider either implementing these features or documenting them as reserved/not-yet-implemented to set correct expectations.
297-297:max_tokens_per_gpuparameter is exposed but not used.The
max_tokens_per_gpuparameter is defined in thetrainmethod signature (line 297), included inoptional_params(line 467), listed inget_optional_params(line 560), and documented (line 376), but it's never read or used in the backend's batch sizing, truncation, or packing logic.Either wire this parameter through to control batch packing/sizing behavior, or remove it from the API surface to avoid misleading users into thinking it has an effect.
Also applies to: 467-467, 560-560
754-775: Examples calllora(...)instead oflora_sft(...).The docstring examples (lines 756 and 767) call
lora(...)instead oflora_sft(...), which will cause aNameErrorif users copy the code verbatim.Example: - # Basic LoRA training - result = lora( + # Basic LoRA+SFT training + result = lora_sft( model_path="microsoft/DialoGPT-medium", data_path="./training_data.jsonl", ckpt_output_dir="./outputs", lora_r=16, lora_alpha=32, num_epochs=3, learning_rate=2e-4 ) # QLoRA with 4-bit quantization - result = lora( + result = lora_sft( model_path="meta-llama/Llama-2-7b-hf", data_path="./training_data.jsonl", ckpt_output_dir="./outputs", load_in_4bit=True, lora_r=64, lora_alpha=128, max_seq_len=4096 )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- pyproject.toml (2 hunks)
- src/training_hub/__init__.py (1 hunks)
- src/training_hub/algorithms/lora.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/__init__.py (1)
- src/training_hub/algorithms/lora.py (3): lora_sft (628-828), LoRASFTAlgorithm (280-620), UnslothLoRABackend (11-275)

src/training_hub/algorithms/lora.py (3)
- src/training_hub/algorithms/__init__.py (6): Algorithm (6-22), Backend (24-30), AlgorithmRegistry (33-79), register_algorithm (40-44), register_backend (47-51), create_algorithm (82-96)
- src/training_hub/algorithms/sft.py (2): sft (169-248), SFTAlgorithm (40-161)
- src/training_hub/algorithms/peft_extender.py (7): LoRAPEFTExtender (26-90), get_lora_parameters (93-96), apply_lora_defaults (99-102), apply_peft_config (21-23), apply_peft_config (54-90), get_peft_params (16-18), get_peft_params (29-52)
🪛 Ruff (0.14.4)
src/training_hub/algorithms/lora.py
24-27: Avoid specifying long messages outside the exception class
(TRY003)
29-32: Avoid specifying long messages outside the exception class
(TRY003)
34-37: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
src/training_hub/algorithms/lora.py (2)
14-82: LGTM! Import error handling is clear and the torchrun parameter issue has been resolved.
The import error messages provide helpful installation instructions, and the code now correctly uses `algorithm_params` directly without stripping torchrun parameters. The comments clearly explain that torchrun parameters are handled by the launcher, not the Python training code.

84-127: LGTM! Quantization parameters are now properly wired through.
The previously unused `load_in_8bit` parameter is now passed to `FastLanguageModel.from_pretrained` (line 118), and the `bnb_4bit_*` parameters are correctly read from `params` and applied when 4-bit quantization is enabled (lines 106-111). The redundant `os` import issue has also been resolved.

pyproject.toml (1)

11-11: LGTM! License format modernization approved.
The license field change from `license = {text = "Apache-2.0"}` to `license = "Apache-2.0"` aligns with modern PEP 639 conventions and is more concise.
Signed-off-by: Mustafa Eyceoz <[email protected]>
Summary by CodeRabbit
New Features
Examples
Documentation
Chores