
Conversation

@Maxusmusti
Contributor

@Maxusmusti Maxusmusti commented Nov 13, 2025

  • Adding a new LoRASFTAlgorithm algorithm (lora_sft)
  • Adding an external PEFTExtender abstract base (and a LoRAPEFTExtender class) that can be attached to existing algorithms to create parameter-efficient variants
  • Adding Unsloth as a backend implementation for LoRA/QLoRA

Summary by CodeRabbit

  • New Features

    • Added Low-Rank Adaptation (LoRA) + SFT training with 4-bit QLoRA, distributed multi-GPU support, and memory-efficient options.
  • Examples

    • New example workflows for basic LoRA+SFT, QLoRA (4-bit), distributed training, and data-format demos with sample data.
  • Documentation

    • Comprehensive LoRA+SFT guide: quick start, configs, launch instructions, memory tips, and troubleshooting.
  • Chores

    • New LoRA installation extras and detailed installation/CUDA/build guidance.

@Maxusmusti Maxusmusti self-assigned this Nov 13, 2025
@coderabbitai

coderabbitai bot commented Nov 13, 2025

Walkthrough

Adds LoRA + SFT support: new Unsloth backend, LoRA algorithm and PEFT extender modules, example scripts and docs, an optional lora dependency group in pyproject, public API exports for LoRA components, and a small typing import fallback in the memory estimator.

Changes

  • Documentation (README.md, examples/README.md, examples/docs/lora_usage.md): Add LoRA + SFT docs and examples, installation notes (LoRA extras, Unsloth, xformers), QLoRA/4-bit guidance, multi-GPU launch instructions, and troubleshooting.
  • Examples (examples/scripts/lora_example.py): New example script with basic LoRA+SFT, QLoRA (4-bit), and distributed training demos, dataset format helpers, and sample-data creation functions.
  • PEFT Abstraction (src/training_hub/algorithms/peft_extender.py): New PEFTExtender base and LoRAPEFTExtender implementation; functions to return LoRA parameter definitions and apply defaults/metadata.
  • LoRA Core (src/training_hub/algorithms/lora.py): New UnslothLoRABackend, LoRASFTAlgorithm, and lora_sft() factory; model loading (including 4/8-bit quantization), LoRA application, dataset preparation, training-args builder, trainer invocation, and registry/backend registration.
  • Packaging / API (pyproject.toml, src/training_hub/__init__.py): Add [project.optional-dependencies].lora with unsloth, trl, and xformers; normalize the license field; export lora_sft, LoRASFTAlgorithm, and UnslothLoRABackend; ensure estimator exports.
  • Compatibility Fix (src/training_hub/profiling/memory_estimator.py): Wrap the typing.override import in try/except to fall back to typing_extensions.override.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant API as lora_sft()
    participant Alg as LoRASFTAlgorithm
    participant Ext as LoRAPEFTExtender
    participant Backend as UnslothLoRABackend
    participant Unsloth
    participant Data as Dataset Prep
    participant Trainer as SFT Trainer

    User->>API: call(model_path, data_path, params)
    API->>Alg: instantiate(params)
    Alg->>Ext: apply PEFT defaults
    Alg->>Backend: train(params)
    Backend->>Unsloth: load model (opt: 4/8-bit, device_map)
    Unsloth-->>Backend: return model & tokenizer
    Backend->>Backend: apply LoRA config
    Backend->>Data: prepare dataset (messages / Alpaca mapping)
    Data-->>Backend: dataset ready
    Backend->>Trainer: build args & start training
    Trainer-->>Backend: training complete
    Backend->>Backend: save checkpoint & tokenizer
    Backend-->>Alg: return training result
    Alg-->>API: return artifact
    API-->>User: done
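
For orientation, here is a minimal usage sketch of the flow above via the exported lora_sft() entry point; paths and hyperparameters are placeholders borrowed from the docstring examples quoted later in this review.

```python
# A minimal sketch of the call flow in the diagram, not the backend's actual code.
from training_hub import lora_sft

result = lora_sft(
    model_path="microsoft/DialoGPT-medium",
    data_path="./training_data.jsonl",
    ckpt_output_dir="./outputs",
    lora_r=16,
    lora_alpha=32,
    num_epochs=3,
    learning_rate=2e-4,
)
```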

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Areas needing extra attention:
    • src/training_hub/algorithms/lora.py (model loading, quantization flags, device_map/distributed logic, import error handling)
    • src/training_hub/algorithms/peft_extender.py (defaults merging and parameter shapes)
    • pyproject.toml (optional-dependencies changes)
    • Registry and public API exports in src/training_hub/__init__.py

Possibly related PRs

Poem

🐇 Hopping in with LoRA light,
I stitch small ranks through day and night,
Unsloth helps me learn with speed,
Tiny weights do mighty deeds,
Fine-tuned and ready — hop to flight! 🚀

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly and directly summarizes the main change: adding Unsloth support for LoRA/QLoRA fine-tuning with SFT, which aligns with the PR's core objectives and the majority of code changes.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 86.96%, which is sufficient; the required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58c87a3 and 7f0d4c3.

📒 Files selected for processing (1)
  • pyproject.toml (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • pyproject.toml


@Maxusmusti Maxusmusti marked this pull request as ready for review November 14, 2025 16:28
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 11

🧹 Nitpick comments (4)
examples/scripts/lora_example.py (1)

1-1: Consider making the script executable.

The shebang is present but the file is not marked executable. Consider running chmod +x examples/scripts/lora_example.py to match the shebang.

src/training_hub/algorithms/peft_extender.py (1)

106-111: Consider sorting __all__ alphabetically.

While the current logical grouping (base class, implementation, functions) is reasonable, alphabetical sorting would align with common Python conventions and satisfy the linter.

 __all__ = [
-    'PEFTExtender',
     'LoRAPEFTExtender',
+    'PEFTExtender',
+    'apply_lora_defaults',
     'get_lora_parameters',
-    'apply_lora_defaults'
 ]
src/training_hub/algorithms/lora.py (2)

199-236: Effective batch size computation can overshoot the requested value

In _build_training_args, when effective_batch_size is provided you compute:

micro_batch_size = effective_batch_size // (gradient_accumulation_steps * num_gpus)
micro_batch_size = max(1, micro_batch_size)

If effective_batch_size < gradient_accumulation_steps * num_gpus, this yields micro_batch_size = 0 → 1, so the actual effective batch becomes gradient_accumulation_steps * num_gpus, which is larger than the requested effective_batch_size.

This may be acceptable, but it’s worth either:

  • Documenting that effective_batch_size is a lower bound, or
  • Adjusting gradient_accumulation_steps (or warning) when the requested effective batch cannot be met exactly.
-        if effective_batch_size is not None:
-            # Calculate micro_batch_size from effective_batch_size
-            micro_batch_size = effective_batch_size // (gradient_accumulation_steps * num_gpus)
-            micro_batch_size = max(1, micro_batch_size)
+        if effective_batch_size is not None:
+            # Calculate micro_batch_size from effective_batch_size
+            denom = gradient_accumulation_steps * num_gpus
+            micro_batch_size = max(1, effective_batch_size // max(1, denom))
+            if micro_batch_size * denom != effective_batch_size:
+                # Optional: log or warn about mismatch instead of silently overshooting.
+                pass
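
To make the overshoot concrete, a quick worked example with hypothetical numbers:

```python
# Hypothetical numbers: the request cannot be met exactly, so the clamp to 1
# doubles the realized effective batch size.
effective_batch_size = 4
gradient_accumulation_steps = 2
num_gpus = 4

micro_batch_size = max(1, effective_batch_size // (gradient_accumulation_steps * num_gpus))  # 4 // 8 == 0 -> 1
realized = micro_batch_size * gradient_accumulation_steps * num_gpus
print(realized)  # 8, versus the requested 4
```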

540-615: Optional param metadata omits save_model, which the backend reads

UnslothLoRABackend.execute_training checks training_params.get('save_model', True) before saving, but LoRASFTAlgorithm.get_optional_params does not list save_model.

If callers rely on get_optional_params() to drive validation or CLI argument exposure, they won’t discover that save_model is supported.

Consider adding save_model: bool to the optional parameter definitions.

         extended_params = {
@@
             'enable_model_splitting': bool,
+            'save_model': bool,
         }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ce5903a and 586b781.

📒 Files selected for processing (9)
  • README.md (3 hunks)
  • examples/README.md (1 hunks)
  • examples/docs/lora_usage.md (1 hunks)
  • examples/scripts/lora_example.py (1 hunks)
  • pyproject.toml (2 hunks)
  • src/training_hub/__init__.py (1 hunks)
  • src/training_hub/algorithms/lora.py (1 hunks)
  • src/training_hub/algorithms/peft_extender.py (1 hunks)
  • src/training_hub/profiling/memory_estimator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/training_hub/__init__.py (1)
src/training_hub/algorithms/lora.py (3)
  • lora_sft (623-823)
  • LoRASFTAlgorithm (277-615)
  • UnslothLoRABackend (11-272)
src/training_hub/algorithms/lora.py (3)
src/training_hub/algorithms/__init__.py (6)
  • Algorithm (6-22)
  • Backend (24-30)
  • AlgorithmRegistry (33-79)
  • register_algorithm (40-44)
  • register_backend (47-51)
  • create_algorithm (82-96)
src/training_hub/algorithms/sft.py (2)
  • sft (169-248)
  • SFTAlgorithm (40-161)
src/training_hub/algorithms/peft_extender.py (7)
  • LoRAPEFTExtender (26-90)
  • get_lora_parameters (93-96)
  • apply_lora_defaults (99-102)
  • apply_peft_config (21-23)
  • apply_peft_config (54-90)
  • get_peft_params (16-18)
  • get_peft_params (29-52)
examples/scripts/lora_example.py (1)
src/training_hub/algorithms/lora.py (1)
  • lora_sft (623-823)
🪛 LanguageTool
examples/docs/lora_usage.md

[style] ~186-~186: As an alternative to the over-used intensifier ‘very’, consider replacing this phrase.
Context: ...proc_per_node * micro_batch_size - For very large models, try enable_model_splitting=Tru...

(EN_WEAK_ADJECTIVE)

🪛 Ruff (0.14.4)
src/training_hub/algorithms/peft_extender.py

106-111: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

src/training_hub/algorithms/lora.py

24-27: Avoid specifying long messages outside the exception class

(TRY003)


29-32: Avoid specifying long messages outside the exception class

(TRY003)


34-37: Avoid specifying long messages outside the exception class

(TRY003)


43-43: Local variable torchrun_params is assigned to but never used

Remove assignment to unused variable torchrun_params

(F841)


64-64: Local variable dataset_type is assigned to but never used

Remove assignment to unused variable dataset_type

(F841)


96-96: Local variable load_in_8bit is assigned to but never used

Remove assignment to unused variable load_in_8bit

(F841)


100-100: Redefinition of unused os from line 1

Remove definition: os

(F811)


208-208: Redefinition of unused os from line 1

Remove definition: os

(F811)

examples/scripts/lora_example.py

1-1: Shebang is present but file is not executable

(EXE001)


45-45: Do not catch blind exception: Exception

(BLE001)


89-89: Do not catch blind exception: Exception

(BLE001)


102-102: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?

(RUF001)


105-105: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


136-136: Do not catch blind exception: Exception

(BLE001)


221-221: Do not catch blind exception: Exception

(BLE001)


240-240: Do not catch blind exception: Exception

(BLE001)


246-246: Local variable result is assigned to but never used

Remove assignment to unused variable result

(F841)


263-263: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (20)
src/training_hub/profiling/memory_estimator.py (1)

1-4: LGTM! Good compatibility improvement.

The try/except fallback pattern correctly handles Python versions before 3.12; typing.override was only introduced in Python 3.12.
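
For reference, the fallback pattern in question is essentially:

```python
try:
    from typing import override  # available on Python 3.12+
except ImportError:
    from typing_extensions import override  # fallback for Python 3.11 and earlier
```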

examples/README.md (1)

84-112: LGTM! Documentation is clear and comprehensive.

The new LoRA + SFT section is well-structured and highlights the important difference that multi-GPU training requires torchrun (unlike SFT/OSFT). The quick example demonstrates the key parameters effectively.

README.md (3)

10-10: LGTM! Support matrix accurately reflects the new LoRA + SFT capability.


66-88: LGTM! Clear and informative documentation.

The LoRA + SFT section effectively communicates the key features and provides a practical code example with appropriate parameters.


111-120: LGTM! Installation instructions are clear.

The LoRA installation section provides straightforward guidance and helpful context about the included dependencies.

examples/docs/lora_usage.md (5)

1-41: LGTM! Excellent quick start documentation.

The introduction clearly explains LoRA benefits and provides a complete working example. The explicit differentiation between single-GPU (Python) and multi-GPU (torchrun) launch commands is very helpful.


54-75: LGTM! Parameter documentation is thorough and practical.

The core LoRA settings and QLoRA example provide clear guidance on parameter selection with helpful context on when to use different values.


77-102: LGTM! Dataset format examples are clear and accurate.

The JSON examples for both Messages and Alpaca formats are correct and well-presented. The memory benefits section appropriately notes that savings depend on configuration.


141-166: LGTM! Advanced configuration examples are helpful.

The custom target modules and W&B integration examples provide practical guidance for common customization needs.


176-190: No issues found—parameter is valid and properly documented.

The enable_model_splitting parameter is a confirmed, supported option in the LoRA implementation. It appears in the function signatures with documentation describing it as enabling device_map="balanced" for large models (default: False). The troubleshooting guidance on line 186 is accurate and correctly advises using this parameter for very large models.
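
For context, a hedged sketch of what device_map="balanced" splitting looks like at the Transformers level; the backend routes this through Unsloth's loader, so treat this as illustrative only.

```python
# Illustrative only: enable_model_splitting=True corresponds conceptually to
# letting accelerate shard a too-large model across the visible GPUs.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",  # placeholder for a model that does not fit on one GPU
    device_map="balanced",        # requires accelerate to be installed
)
```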

src/training_hub/__init__.py (1)

8-28: LGTM! Public API exports are clean and consistent.

The additions to __all__ properly expose the new LoRA functionality and follow the existing pattern used for SFT and OSFT algorithms.

pyproject.toml (1)

11-11: LGTM! License format simplification.

The change from dictionary to string format is valid and follows the more common convention.

examples/scripts/lora_example.py (5)

11-46: LGTM! Basic example is clear and practical.

The function provides a good starting point with a small model and reasonable default parameters. The broad exception handling is acceptable for example code.


51-90: LGTM! QLoRA example demonstrates appropriate parameter adjustments.

The function correctly shows how to configure LoRA for 4-bit quantization with higher rank and lower learning rate, which are sensible defaults for quantized training.


93-137: LGTM! Distributed training example is well-documented.

The function clearly explains the torchrun requirement and demonstrates appropriate parameters for multi-GPU training. The unused result variable is acceptable in example code.


140-199: LGTM! Sample data creation is helpful and well-structured.

The function creates appropriate sample datasets in both Messages and Alpaca formats, providing useful test data for users trying the examples.


202-299: LGTM! Comprehensive demonstration of dataset formats and usage.

The data_format_examples function clearly shows how to work with different dataset formats, and the main block provides excellent usage instructions and benefits summary. The unused variables and broad exception handling are acceptable for example code.

src/training_hub/algorithms/peft_extender.py (3)

12-23: LGTM! Clean abstraction design.

The PEFTExtender ABC provides a clear interface for adding PEFT functionality to algorithms. The two abstract methods establish a good contract for parameter discovery and configuration.


26-90: LGTM! Comprehensive LoRA implementation.

The LoRAPEFTExtender provides thorough coverage of LoRA parameters including quantization and advanced options. The default values are sensible (especially dropout=0.0 for Unsloth optimization), and the metadata tagging (peft_enabled, peft_type) is helpful for downstream processing.


93-102: LGTM! Useful convenience functions.

The wrapper functions provide a clean API for common operations without requiring users to instantiate the extender directly.
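
As a usage illustration (signatures and return shapes are assumed from the summaries above rather than checked against the module):

```python
# Usage sketch only: the exact signatures and return shapes of these helpers are
# assumed from the review summaries, not verified against the module.
from training_hub.algorithms.peft_extender import apply_lora_defaults, get_lora_parameters

lora_param_defs = get_lora_parameters()       # assumed: returns the LoRA parameter definitions
config = apply_lora_defaults({"lora_r": 8})   # assumed: merges LoRA defaults plus peft_enabled/peft_type metadata
print(config.get("peft_type"), config.get("lora_alpha"))
```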

Comment on lines +149 to +196
    def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any:
        """Prepare dataset for training."""
        from datasets import load_dataset

        # Load dataset
        if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'):
            dataset = load_dataset('json', data_files=params['data_path'], split='train')
        else:
            dataset = load_dataset(params['data_path'], split='train')

        # Handle different dataset formats
        dataset_type = params.get('dataset_type', 'chat_template')

        if dataset_type == 'chat_template':
            # Convert messages format using chat template
            def format_chat_template(examples):
                # examples['messages'] is a list of conversations (batched)
                texts = []
                for conversation in examples['messages']:
                    text = tokenizer.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_chat_template, batched=True)

        elif dataset_type == 'alpaca':
            # Convert alpaca format to text
            def format_alpaca(examples):
                texts = []
                for i in range(len(examples['instruction'])):
                    instruction = examples['instruction'][i]
                    input_text = examples.get('input', [''] * len(examples['instruction']))[i]
                    output = examples['output'][i]

                    if input_text:
                        text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                    else:
                        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_alpaca, batched=True)

        return dataset

⚠️ Potential issue | 🟠 Major

Dataset field customisation is ignored; only hard-coded columns are used

LoRASFTAlgorithm.train exposes dataset_type, field_messages, field_instruction, field_input, and field_output, but _prepare_dataset:

  • Always assumes messages for chat data and instruction / input / output for Alpaca, ignoring the field_* parameters.
  • Only implements two dataset types (chat_template and alpaca) even though docstrings mention other formats like input_output.
  • Recomputes dataset_type here, while the earlier dataset_type variable in execute_training is set but never used.

This will break for datasets that use different column names and makes the field_* parameters misleading.

I’d suggest:

  • Respecting the field_* names from params with sensible defaults, and
  • Explicitly validating/handling only the dataset types you actually support.
-        # Handle different dataset formats
-        dataset_type = params.get('dataset_type', 'chat_template')
+        # Handle different dataset formats
+        dataset_type = params.get('dataset_type', 'chat_template')
+        field_messages = params.get('field_messages', 'messages')
+        field_instruction = params.get('field_instruction', 'instruction')
+        field_input = params.get('field_input', 'input')
+        field_output = params.get('field_output', 'output')
@@
-        if dataset_type == 'chat_template':
+        if dataset_type == 'chat_template':
@@
-                for conversation in examples['messages']:
+                for conversation in examples[field_messages]:
@@
-        elif dataset_type == 'alpaca':
+        elif dataset_type == 'alpaca':
@@
-                for i in range(len(examples['instruction'])):
-                    instruction = examples['instruction'][i]
-                    input_text = examples.get('input', [''] * len(examples['instruction']))[i]
-                    output = examples['output'][i]
+                n = len(examples[field_instruction])
+                for i in range(n):
+                    instruction = examples[field_instruction][i]
+                    input_text = examples.get(field_input, [''] * n)[i]
+                    output = examples[field_output][i]
+                    ...

You might also want to raise a clear error for unsupported dataset_type values instead of silently returning the raw dataset.
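
For instance, a sketch of that kind of guard (the supported type names are taken from the branches above; the helper name is hypothetical):

```python
SUPPORTED_DATASET_TYPES = ("chat_template", "alpaca")


def validate_dataset_type(dataset_type: str) -> None:
    """Fail fast instead of silently returning the unmodified dataset."""
    if dataset_type not in SUPPORTED_DATASET_TYPES:
        raise ValueError(
            f"Unsupported dataset_type {dataset_type!r}; "
            f"expected one of {SUPPORTED_DATASET_TYPES}"
        )
```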

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any:
        """Prepare dataset for training."""
        from datasets import load_dataset

        # Load dataset
        if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'):
            dataset = load_dataset('json', data_files=params['data_path'], split='train')
        else:
            dataset = load_dataset(params['data_path'], split='train')

        # Handle different dataset formats
        dataset_type = params.get('dataset_type', 'chat_template')
        field_messages = params.get('field_messages', 'messages')
        field_instruction = params.get('field_instruction', 'instruction')
        field_input = params.get('field_input', 'input')
        field_output = params.get('field_output', 'output')

        if dataset_type == 'chat_template':
            # Convert messages format using chat template
            def format_chat_template(examples):
                # examples[field_messages] is a list of conversations (batched)
                texts = []
                for conversation in examples[field_messages]:
                    text = tokenizer.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_chat_template, batched=True)

        elif dataset_type == 'alpaca':
            # Convert alpaca format to text
            def format_alpaca(examples):
                texts = []
                n = len(examples[field_instruction])
                for i in range(n):
                    instruction = examples[field_instruction][i]
                    input_text = examples.get(field_input, [''] * n)[i]
                    output = examples[field_output][i]

                    if input_text:
                        text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                    else:
                        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                    texts.append(text)
                return {"text": texts}

            dataset = dataset.map(format_alpaca, batched=True)

        return dataset
🤖 Prompt for AI Agents
In src/training_hub/algorithms/lora.py around lines 149-196, _prepare_dataset
currently hardcodes column names and dataset types (expects 'messages' for chat
and 'instruction'/'input'/'output' for alpaca), ignores
params['field_messages']/['field_instruction']/['field_input']/['field_output'],
and recomputes dataset_type instead of using the value passed in params; update
the function to read dataset_type = params.get('dataset_type', 'chat_template')
once and then select behavior accordingly, use field names from params with
sensible defaults (e.g., field_messages='messages',
field_instruction='instruction', field_input='input', field_output='output')
when extracting columns inside formatters, explicitly handle only supported
dataset_type values (chat_template, alpaca, input_output, etc.) and raise a
clear ValueError for unsupported types, and return the mapped dataset; also
ensure batching logic uses the provided field names and preserves existing
tokenize/add_generation_prompt behavior for chat_template.

Comment on lines +470 to +487
# LoRA parameters (from PEFT extender)
'lora_r': lora_r,
'lora_alpha': lora_alpha,
'lora_dropout': lora_dropout,
'target_modules': target_modules,
'use_rslora': use_rslora,
'use_dora': use_dora,
'init_lora_weights': init_lora_weights,
'rank_pattern': rank_pattern,
'alpha_pattern': alpha_pattern,
'loftq_config': loftq_config,
# Quantization parameters (QLoRA)
'load_in_4bit': load_in_4bit,
'load_in_8bit': load_in_8bit,
'bnb_4bit_quant_type': bnb_4bit_quant_type,
'bnb_4bit_compute_dtype': bnb_4bit_compute_dtype,
'bnb_4bit_use_double_quant': bnb_4bit_use_double_quant,
# Extended training parameters

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Several advertised knobs are currently unused in this backend

The algorithm and PEFT extender expose a broad set of knobs:

  • QLoRA-related: bnb_4bit_quant_type, bnb_4bit_compute_dtype, bnb_4bit_use_double_quant
  • Training/meta: early_stopping_patience, wandb_watch
  • Multi-GPU: enable_model_splitting (used), plus advanced LoRA options discussed earlier.

However, in the Unsloth backend:

  • The bnb_4bit_* fields are never read.
  • early_stopping_patience is never converted into a callback or trainer argument.
  • wandb_watch is never used with W&B APIs.
  • max_tokens_per_gpu is accepted and surfaced via get_optional_params but not used in batch sizing or truncation logic.

This is not strictly a bug, but the API surface suggests these parameters have an effect when they currently don’t.

I’d recommend either wiring them into the backend (where supported by TRL/Unsloth/W&B) or explicitly documenting them as “reserved/unimplemented” for now to reduce surprises.

Also applies to: 576-607


🌐 Web query:

Check which of the following SFT/TRL/Unsloth features have first-class support and the expected integration points: QLoRA `bnb_4bit_*` options, `max_tokens_per_gpu`-style sequence length limits, early stopping callbacks, and W&B `wandb.watch` integration.

💡 Result:

Summary — first-class support / integration points

  • QLoRA bnb_4bit_* options (bnb_4bit_quant_type, bnb_4bit_use_double_quant, bnb_4bit_compute_dtype, etc.): First‑class support via BitsAndBytesConfig used by Transformers/QLoRA workflows (used in AutoModelForCausalLM.from_pretrained and many examples). Use BitsAndBytesConfig / quantization_config to pass bnb_4bit_* flags. [1][2]

  • max_tokens_per_gpu‑style per‑GPU token limits: Supported by several training stacks (Megatron-style / Slime / related training frameworks) as a first‑class runtime parameter for dynamic/batched packing and OOM control (often used with --use-dynamic-batch-size). Not a single HF‑Transformers standard name, but widely supported in production training frameworks. Cite Slime / examples that document max_tokens_per_gpu semantics. [3][4]

  • Early stopping callbacks: First‑class support via Transformers callbacks (e.g., transformers.EarlyStoppingCallback) and TRL’s SFTTrainer accepts standard callbacks — you can pass EarlyStoppingCallback to SFTTrainer. [5]

  • Weights & Biases wandb.watch integration: Supported — Hugging Face provides WandbCallback/W&B integration (Trainer integrations and env vars like WANDB_WATCH); many SFT/TRL workflows also call wandb.init / wandb.watch in trainer setup/examples. Use Transformers’ WandbCallback or call wandb.init/wandb.watch yourself inside trainer code. [6][7]

Sources

  1. QLoRA / BitsAndBytes examples and BitsAndBytesConfig usage. [1]
  2. vLLM / BitsAndBytesConfig source and QLoRA examples showing bnb_4bit_* options. [2]
  3. Slime / Megatron-style docs describing max_tokens_per_gpu behavior for dynamic batching and OOM control. [3]
  4. Training-hub / examples showing max_tokens_per_gpu in training configs. [4]
  5. TRL SFTTrainer example using transformers.EarlyStoppingCallback. [5]
  6. Hugging Face transformers callbacks / WandbCallback documentation (WANDB_WATCH etc.). [6]
  7. Example trainers calling wandb.watch (project examples / community trainers). [7]

If you want, I can open the specific repo/docs you use (SFT/TRL/Unsloth code) and verify the exact API names and where to pass each option.


I need to inspect the actual backend code to verify whether these parameters are truly unused. Let me examine the file and trace parameter flow.


🏁 Script executed:

#!/bin/bash
# Inspect the target lines and surrounding context
echo "=== Lines 470-487 context ===" 
sed -n '460,500p' src/training_hub/algorithms/lora.py

echo -e "\n=== Lines 576-607 context ==="
sed -n '570,620p' src/training_hub/algorithms/lora.py

Length of output: 3553


🏁 Script executed:

#!/bin/bash
# Search for where parameters are initialized/defined
echo "=== Parameter definitions ===" 
rg -n "early_stopping_patience|wandb_watch|bnb_4bit_quant_type|bnb_4bit_compute_dtype|bnb_4bit_use_double_quant|max_tokens_per_gpu" src/training_hub/algorithms/lora.py -C 2

Length of output: 5711


🏁 Script executed:

#!/bin/bash
# Search for Unsloth backend implementation
echo "=== Unsloth backend references ===" 
rg -n "unsloth|Unsloth" src/training_hub/algorithms/lora.py -C 3

Length of output: 4907


🏁 Script executed:

#!/bin/bash
# Search for SFTTrainer initialization and parameter passing
echo "=== SFTTrainer setup ===" 
rg -n "SFTTrainer|from_pretrained" src/training_hub/algorithms/lora.py -C 5

Length of output: 1907


Let me trace where these specific parameters are actually consumed in the backend:


---

🏁 Script executed:

#!/bin/bash
# Examine _build_training_args full implementation
echo "=== _build_training_args implementation ===" 
sed -n '199,280p' src/training_hub/algorithms/lora.py

Length of output: 3636


🏁 Script executed:

#!/bin/bash
# Examine _apply_lora_config full implementation
echo "=== _apply_lora_config implementation ===" 
sed -n '125,180p' src/training_hub/algorithms/lora.py

Length of output: 2376


🏁 Script executed:

#!/bin/bash
# Search for BitsAndBytesConfig usage
echo "=== BitsAndBytesConfig usage ===" 
rg -n "BitsAndBytesConfig" src/training_hub/algorithms/lora.py -C 3

Length of output: 32


🏁 Script executed:

#!/bin/bash
# Search for callback creation/EarlyStoppingCallback/wandb callbacks
echo "=== Callback usage ===" 
rg -n "Callback|callback|wandb|WandbCallback" src/training_hub/algorithms/lora.py -C 2

Length of output: 2846


🏁 Script executed:

#!/bin/bash
# Check what happens with max_tokens_per_gpu, early_stopping_patience, wandb_watch
echo "=== Parameter consumption trace ===" 
rg -n "max_tokens_per_gpu|get\('early_stopping_patience|get\('wandb_watch" src/training_hub/algorithms/lora.py

Length of output: 271


Wire unused parameters into backend or explicitly document as reserved

Verification confirms the review comment is accurate. The backend accepts these parameters but does not use them:

  • bnb_4bit_quant_type, bnb_4bit_compute_dtype, bnb_4bit_use_double_quant: Defined and passed to algorithm_params, but never read. Only load_in_4bit is used in _load_unsloth_model (line 115); no BitsAndBytesConfig is created.

  • early_stopping_patience: Defined and passed, but never converted to a callback or passed to SFTConfig. No EarlyStoppingCallback instantiation anywhere.

  • wandb_watch: Defined and passed, but only wandb_project and wandb_run_name are used in SFTConfig (lines 268–269). No wandb.watch() call or WandbCallback with watch parameter.

  • max_tokens_per_gpu: Defined and passed (line 464), but never read or used in batch sizing/truncation logic.

Additionally, in _apply_lora_config, use_rslora (line 143) and loftq_config (line 144) are hardcoded instead of reading from params, suggesting a broader pattern of parameter passthrough without integration.

These parameters do have first-class support in the underlying libraries (Transformers/TRL/Unsloth/W&B). Either integrate them into the backend (as suggested in the review) or explicitly mark them as reserved/unimplemented to set user expectations correctly.
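
A sketch of what that wiring could look like at the Transformers/TRL level, following the web-query findings above; Unsloth's own loader may already handle quantization internally, so the BitsAndBytesConfig piece is illustrative rather than the backend's actual code path.

```python
import torch
from transformers import BitsAndBytesConfig, EarlyStoppingCallback

# QLoRA knobs mapped onto an explicit quantization config (illustrative; the
# Unsloth loader used by the backend may accept load_in_4bit directly instead).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# early_stopping_patience mapped onto a standard trainer callback; this also
# needs an eval dataset plus load_best_model_at_end/metric_for_best_model in
# the training arguments.
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]

# wandb_watch could be honored by exporting WANDB_WATCH before training starts,
# or by calling wandb.watch(model) after wandb.init().
```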

Comment on lines 623 to 771
def lora_sft(model_path: str,
             data_path: str,
             ckpt_output_dir: str,
             backend: str = "unsloth",
             # LoRA-specific parameters
             lora_r: Optional[int] = None,
             lora_alpha: Optional[int] = None,
             lora_dropout: Optional[float] = None,
             target_modules: Optional[List[str]] = None,
             # Training parameters
             num_epochs: Optional[int] = None,
             effective_batch_size: Optional[int] = None,
             micro_batch_size: Optional[int] = None,
             gradient_accumulation_steps: Optional[int] = None,
             learning_rate: Optional[float] = None,
             max_seq_len: Optional[int] = None,
             lr_scheduler: Optional[str] = None,
             warmup_steps: Optional[int] = None,
             # Quantization parameters
             load_in_4bit: Optional[bool] = None,
             load_in_8bit: Optional[bool] = None,
             bnb_4bit_quant_type: Optional[str] = None,
             bnb_4bit_compute_dtype: Optional[str] = None,
             bnb_4bit_use_double_quant: Optional[bool] = None,
             # Optimization parameters
             flash_attention: Optional[bool] = None,
             sample_packing: Optional[bool] = None,
             bf16: Optional[bool] = None,
             fp16: Optional[bool] = None,
             tf32: Optional[bool] = None,
             # Saving and logging
             save_steps: Optional[int] = None,
             eval_steps: Optional[int] = None,
             logging_steps: Optional[int] = None,
             save_total_limit: Optional[int] = None,
             # Weights & Biases
             wandb_project: Optional[str] = None,
             wandb_entity: Optional[str] = None,
             wandb_watch: Optional[str] = None,
             # Early stopping
             early_stopping_patience: Optional[int] = None,
             # Dataset format parameters
             dataset_type: Optional[str] = None,
             field_messages: Optional[str] = None,
             field_instruction: Optional[str] = None,
             field_input: Optional[str] = None,
             field_output: Optional[str] = None,
             # Distributed training parameters
             nproc_per_node: Optional[Union[str, int]] = None,
             nnodes: Optional[int] = None,
             node_rank: Optional[int] = None,
             rdzv_id: Optional[Union[str, int]] = None,
             rdzv_endpoint: Optional[str] = None,
             master_addr: Optional[str] = None,
             master_port: Optional[int] = None,
             # Multi-GPU model splitting
             enable_model_splitting: Optional[bool] = None,
             **kwargs) -> Any:
    """Convenience function to run LoRA + SFT training.

    Args:
        model_path: Path to the model to fine-tune (local path or HuggingFace model ID)
        data_path: Path to the training data (JSON/JSONL format)
        ckpt_output_dir: Directory to save checkpoints and outputs
        backend: Backend implementation to use (default: "unsloth")

        LoRA Parameters:
        lora_r: LoRA rank (default: 16)
        lora_alpha: LoRA alpha parameter (default: 32)
        lora_dropout: LoRA dropout rate (default: 0.1)
        target_modules: List of module names to apply LoRA to (default: auto-detect)

        Training Parameters:
        num_epochs: Number of training epochs (default: 3)
        effective_batch_size: Effective batch size across all GPUs
        micro_batch_size: Batch size per GPU (default: 1)
        gradient_accumulation_steps: Steps to accumulate gradients (default: 1)
        learning_rate: Learning rate (default: 2e-4)
        max_seq_len: Maximum sequence length (default: 2048)
        lr_scheduler: Learning rate scheduler (default: 'cosine')
        warmup_steps: Number of warmup steps (default: 10)

        Quantization Parameters (QLoRA):
        load_in_4bit: Use 4-bit quantization for QLoRA
        load_in_8bit: Use 8-bit quantization
        bnb_4bit_quant_type: 4-bit quantization type (default: 'nf4')
        bnb_4bit_compute_dtype: Compute dtype for 4-bit (default: 'bfloat16')
        bnb_4bit_use_double_quant: Use double quantization (default: True)

        Optimization Parameters:
        flash_attention: Use Flash Attention for memory efficiency (default: True)
        sample_packing: Pack multiple samples per sequence (default: True)
        bf16: Use bfloat16 precision (default: True)
        fp16: Use float16 precision (default: False)
        tf32: Use TensorFloat-32 (default: True)

        Logging and Saving:
        save_steps: Steps between checkpoints (default: 500)
        eval_steps: Steps between evaluations (default: 500)
        logging_steps: Steps between log outputs (default: 10)
        save_total_limit: Maximum number of checkpoints to keep (default: 3)
        wandb_project: Weights & Biases project name
        wandb_entity: Weights & Biases entity name
        wandb_watch: What to watch in W&B ('gradients', 'all', etc.)
        early_stopping_patience: Early stopping patience (epochs)

        Distributed Training:
        nproc_per_node: Number of processes (GPUs) per node
        nnodes: Total number of nodes
        node_rank: Rank of this node (0 to nnodes-1)
        rdzv_id: Unique job ID for rendezvous
        rdzv_endpoint: Master node endpoint for multi-node training
        master_addr: Master node address for distributed training
        master_port: Master node port for distributed training

        Multi-GPU Configuration:
        enable_model_splitting: Enable device_map="balanced" for large models (default: False)
            Use for models that don't fit on a single GPU (e.g., Llama 70B)
            For smaller models, use standard DDP with torchrun instead

        Advanced:
        **kwargs: Additional parameters passed to the backend

    Returns:
        Dictionary containing trained model, tokenizer, and trainer

    Example:
        # Basic LoRA training
        result = lora(
            model_path="microsoft/DialoGPT-medium",
            data_path="./training_data.jsonl",
            ckpt_output_dir="./outputs",
            lora_r=16,
            lora_alpha=32,
            num_epochs=3,
            learning_rate=2e-4
        )

        # QLoRA with 4-bit quantization
        result = lora(
            model_path="meta-llama/Llama-2-7b-hf",
            data_path="./training_data.jsonl",
            ckpt_output_dir="./outputs",
            load_in_4bit=True,
            lora_r=64,
            lora_alpha=128,
            max_seq_len=4096
        )
    """

⚠️ Potential issue | 🟡 Minor

lora_sft docstring example uses the wrong function name and misses some parameters

In the lora_sft docstring:

  • The examples call lora(...) instead of lora_sft(...), which will fail if copied verbatim.
  • Some knobs available on LoRASFTAlgorithm.train (e.g., weight_decay, max_grad_norm, wandb_run_name, LoRA advanced options) are only accessible via **kwargs and not reflected in the documented signature.

I’d at least fix the function name in the examples, and ideally add a short note that additional LoRA/PEFT and training options can be passed via **kwargs to keep the docs honest.

-        # Basic LoRA training
-        result = lora(
+        # Basic LoRA+SFT training
+        result = lora_sft(
@@
-        # QLoRA with 4-bit quantization
-        result = lora(
+        # QLoRA with 4-bit quantization
+        result = lora_sft(
🤖 Prompt for AI Agents
In src/training_hub/algorithms/lora.py around lines 623 to 771, update the
docstring examples to call lora_sft(...) instead of lora(...), and add one short
sentence stating that additional LoRA/PEFT and training options (for example:
weight_decay, max_grad_norm, wandb_run_name, and advanced LoRA/PEFT knobs) can
be supplied via **kwargs to LoRASFTAlgorithm.train so callers know those
parameters exist even though they aren’t explicit in the signature.

Collaborator

@RobotSail RobotSail left a comment

Partial review

@@ -1,4 +1,7 @@
from typing import override
try:
Collaborator

Would you mind indicating what the compatibility is targeting? This'll make it easier to remove when Python 3.11 support is eventually dropped.

Suggested change
try:
# python3.11 compatibility, `override` was added to typing in 3.12
try:

lora_alpha=32
)
```

Collaborator

Does wandb need to be installed separately, or does it come with one of the new dependencies? IIRC, wandb currently isn't installed automatically; we usually recommend that users install it separately.

sample_packing=True,

# Weights & Biases logging
wandb_project="lora_training",
Collaborator

I haven't checked the installation yet, but if wandb isn't automatically installed through one of the new deps like unsloth or xformers, then we'd need to make sure that users install it separately.
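
One lightweight option, assuming wandb stays out of the declared extras, would be to have the example fail early with a clear message (a sketch; the helper name is hypothetical, the parameter name follows the lora_sft signature):

```python
from typing import Optional


def require_wandb_if_configured(wandb_project: Optional[str]) -> None:
    # Hypothetical helper for the example script: fail early with a clear message
    # rather than hitting a late ImportError inside the trainer.
    if wandb_project is None:
        return
    try:
        import wandb  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "wandb_project was set but wandb is not installed; "
            "install it separately with `pip install wandb`."
        ) from exc
```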

print(f"❌ Custom field mapping training failed: {e}")


if __name__ == "__main__":
Collaborator

@Maxusmusti Would it make sense to make this script more interactive using something like typer or argparse? I see that a few of the functions here are commented out, and I'm guessing most of the time people will just run the samples without inspecting the source. It might help users see the full range of functionality if they can toggle these demos via python lora_example.py --qlora.
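
A rough argparse sketch of that idea; the demo function names are the ones referenced in this thread, and basic_lora_sft_example is a hypothetical name for the default demo:

```python
import argparse


def main() -> None:
    parser = argparse.ArgumentParser(description="training-hub LoRA/QLoRA examples")
    parser.add_argument("--qlora", action="store_true", help="run the 4-bit QLoRA demo")
    parser.add_argument("--distributed", action="store_true",
                        help="run the multi-GPU demo (launch with torchrun)")
    args = parser.parse_args()

    if args.qlora:
        qlora_4bit_example()            # demo referenced elsewhere in this thread
    elif args.distributed:
        distributed_lora_sft_example()  # demo referenced elsewhere in this thread
    else:
        basic_lora_sft_example()        # hypothetical name for the default demo


if __name__ == "__main__":
    main()
```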

# qlora_4bit_example()

# Uncomment to test distributed training (launch with torchrun for actual multi-GPU)
# distributed_lora_sft_example()
Collaborator

If we make this script interactive, this could have a dedicated code path, and in the example guide we could recommend running it via torchrun lora_example.py --distributed, blocking execution if we don't detect a torchrun environment.
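
Detecting a torchrun launch can be done by checking the environment variables torchrun sets for each worker (a sketch):

```python
import os
import sys


def running_under_torchrun() -> bool:
    # torchrun (torch.distributed.run) exports these for every worker process.
    return all(var in os.environ for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK"))


if not running_under_torchrun():
    sys.exit("The --distributed demo must be launched via "
             "`torchrun --nproc_per_node=<N> lora_example.py --distributed`.")
```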

```bash
pip install training-hub[cuda]
# or for development
pip install -e .[cuda]
Collaborator

@Maxusmusti Would you mind adding this flag to the README? I think in most cases this flag will need to be present, unless the users get lucky with the prebuilt wheels.

Suggested change
pip install -e .[cuda]
pip install -e .[cuda] --no-build-isolation

I can't suggest it, but the same change needs to be made on line 106.

Signed-off-by: Mustafa Eyceoz <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/training_hub/__init__.py (1)

4-4: LGTM! The critical import issue has been resolved.

The non-existent AxolotlLoRABackend import has been removed, fixing the ImportError reported in the previous review. The current imports (lora_sft, LoRASFTAlgorithm, UnslothLoRABackend) match the implementations in src/training_hub/algorithms/lora.py, and the __all__ list correctly exposes these symbols.

Also applies to: 15-15, 20-21

src/training_hub/algorithms/lora.py (5)

129-151: Advanced LoRA parameters are still hardcoded and not applied.

The use_rslora (line 147) and loftq_config (line 148) parameters are hardcoded, and use_dora, rank_pattern, alpha_pattern, and init_lora_weights are not passed at all. These parameters are exposed in LoRASFTAlgorithm.train (lines 308-313) and advertised in the public API, but they have no effect.

As confirmed in previous reviews, Unsloth's FastLanguageModel.get_peft_model does accept these parameters. Please pass them through from params:

 model = FastLanguageModel.get_peft_model(
     model,
     r=params.get('lora_r', 16),
     target_modules=target_modules,
     lora_alpha=params.get('lora_alpha', 32),
     lora_dropout=params.get('lora_dropout', 0.0),
     bias="none",
     use_gradient_checkpointing="unsloth",
     random_state=params.get('seed', 3407),
-    use_rslora=False,
-    loftq_config=None,
+    use_rslora=params.get('use_rslora', False),
+    use_dora=params.get('use_dora', False),
+    loftq_config=params.get('loftq_config'),
+    rank_pattern=params.get('rank_pattern'),
+    alpha_pattern=params.get('alpha_pattern'),
+    init_lora_weights=params.get('init_lora_weights', True),
 )

153-200: Dataset field names are hardcoded; customization parameters are ignored.

The field_messages, field_instruction, field_input, and field_output parameters (exposed at lines 346-349 and documented at lines 428-431) are accepted but never used. The code hardcodes 'messages' (line 171), 'instruction' (lines 186-187), 'input' (line 188), and 'output' (line 189), breaking datasets with different column names.

Read the field names from params with sensible defaults:

 def _prepare_dataset(self, params: Dict[str, Any], tokenizer) -> Any:
     """Prepare dataset for training."""
     from datasets import load_dataset
     
     # Load dataset
     if params['data_path'].endswith('.jsonl') or params['data_path'].endswith('.json'):
         dataset = load_dataset('json', data_files=params['data_path'], split='train')
     else:
         dataset = load_dataset(params['data_path'], split='train')
     
     # Handle different dataset formats
     dataset_type = params.get('dataset_type', 'chat_template')
+    field_messages = params.get('field_messages', 'messages')
+    field_instruction = params.get('field_instruction', 'instruction')
+    field_input = params.get('field_input', 'input')
+    field_output = params.get('field_output', 'output')
     
     if dataset_type == 'chat_template':
         # Convert messages format using chat template
         def format_chat_template(examples):
-            # examples['messages'] is a list of conversations (batched)
+            # examples[field_messages] is a list of conversations (batched)
             texts = []
-            for conversation in examples['messages']:
+            for conversation in examples[field_messages]:
                 text = tokenizer.apply_chat_template(
                     conversation,
                     tokenize=False,
                     add_generation_prompt=False
                 )
                 texts.append(text)
             return {"text": texts}
         
         dataset = dataset.map(format_chat_template, batched=True)
     
     elif dataset_type == 'alpaca':
         # Convert alpaca format to text
         def format_alpaca(examples):
             texts = []
-            for i in range(len(examples['instruction'])):
-                instruction = examples['instruction'][i]
-                input_text = examples.get('input', [''] * len(examples['instruction']))[i]
-                output = examples['output'][i]
+            n = len(examples[field_instruction])
+            for i in range(n):
+                instruction = examples[field_instruction][i]
+                input_text = examples.get(field_input, [''] * n)[i]
+                output = examples[field_output][i]
                 
                 if input_text:
                     text = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                 else:
                     text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                 texts.append(text)
             return {"text": texts}
         
         dataset = dataset.map(format_alpaca, batched=True)
     
     return dataset
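
For context, the call this change is meant to enable would look roughly like the following (the column names are made up for illustration; the keyword names match the public API discussed above):

```python
from training_hub import lora_sft

# Hypothetical dataset whose Alpaca-style columns use non-default names.
result = lora_sft(
    model_path="meta-llama/Llama-2-7b-hf",
    data_path="./custom_columns.jsonl",
    ckpt_output_dir="./outputs",
    dataset_type="alpaca",
    field_instruction="prompt",    # instead of "instruction"
    field_input="context",         # instead of "input"
    field_output="completion",     # instead of "output"
)
```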

203-275: wandb_watch and early_stopping_patience parameters are accepted but not used.

While wandb_project and wandb_run_name are wired through (lines 271-272), the wandb_watch parameter (defined at line 340, documented at line 423) is accepted but never applied. Similarly, early_stopping_patience (line 343, documented at line 424) is never converted into an EarlyStoppingCallback.

Both features have first-class support in the underlying libraries (TRL accepts EarlyStoppingCallback, W&B supports wandb.watch via WandbCallback), but the current implementation silently ignores them, potentially surprising users.

Consider either implementing these features or documenting them as reserved/not-yet-implemented to set correct expectations.
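
If you go the implementation route, the early-stopping half is mostly wiring (a sketch assuming TRL's SFTTrainer/SFTConfig as used in this backend; the eval/save settings are required so EarlyStoppingCallback has a metric to monitor, and the paths and step counts are placeholders):

```python
from transformers import EarlyStoppingCallback
from trl import SFTConfig, SFTTrainer


def build_trainer_with_early_stopping(model, train_ds, eval_ds, patience: int) -> SFTTrainer:
    # EarlyStoppingCallback only fires when evaluation runs periodically and
    # load_best_model_at_end / metric_for_best_model are configured.
    args = SFTConfig(
        output_dir="./outputs",            # placeholder
        eval_strategy="steps",             # `evaluation_strategy` on older transformers
        eval_steps=100,                    # placeholder
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    return SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],
    )
```

For wandb_watch, transformers' WandbCallback honors the WANDB_WATCH environment variable ("false", "gradients", or "all"), at least in recent releases, so exporting it before training is one low-touch option if full wandb.watch wiring is out of scope.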


297-297: max_tokens_per_gpu parameter is exposed but not used.

The max_tokens_per_gpu parameter is defined in the train method signature (line 297), included in optional_params (line 467), listed in get_optional_params (line 560), and documented (line 376), but it's never read or used in the backend's batch sizing, truncation, or packing logic.

Either wire this parameter through to control batch packing/sizing behavior, or remove it from the API surface to avoid misleading users into thinking it has an effect.

Also applies to: 467-467, 560-560
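
If the parameter stays, one possible wiring (purely illustrative; the helper and fallback below are not in this PR) is to derive the per-device batch size from the token budget:

```python
def derive_per_device_batch_size(params: dict, default: int = 1) -> int:
    """Map max_tokens_per_gpu onto a per-device batch size (illustrative sketch)."""
    max_tokens = params.get("max_tokens_per_gpu")
    max_seq_len = params.get("max_seq_len", 2048)
    if not max_tokens:
        return default
    # With packing enabled, each packed sample is roughly max_seq_len tokens,
    # so the token budget divided by the sequence length bounds the batch size.
    return max(1, max_tokens // max_seq_len)
```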


754-775: Examples call lora(...) instead of lora_sft(...).

The docstring examples (lines 756 and 767) call lora(...) instead of lora_sft(...), which will cause a NameError if users copy the code verbatim.

     Example:
-        # Basic LoRA training
-        result = lora(
+        # Basic LoRA+SFT training
+        result = lora_sft(
             model_path="microsoft/DialoGPT-medium",
             data_path="./training_data.jsonl",
             ckpt_output_dir="./outputs",
             lora_r=16,
             lora_alpha=32,
             num_epochs=3,
             learning_rate=2e-4
         )
         
         # QLoRA with 4-bit quantization
-        result = lora(
+        result = lora_sft(
             model_path="meta-llama/Llama-2-7b-hf",
             data_path="./training_data.jsonl",
             ckpt_output_dir="./outputs",
             load_in_4bit=True,
             lora_r=64,
             lora_alpha=128,
             max_seq_len=4096
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 586b781 and 58c87a3.

📒 Files selected for processing (3)
  • pyproject.toml (2 hunks)
  • src/training_hub/__init__.py (1 hunks)
  • src/training_hub/algorithms/lora.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/__init__.py (1)
src/training_hub/algorithms/lora.py (3)
  • lora_sft (628-828)
  • LoRASFTAlgorithm (280-620)
  • UnslothLoRABackend (11-275)
src/training_hub/algorithms/lora.py (3)
src/training_hub/algorithms/__init__.py (6)
  • Algorithm (6-22)
  • Backend (24-30)
  • AlgorithmRegistry (33-79)
  • register_algorithm (40-44)
  • register_backend (47-51)
  • create_algorithm (82-96)
src/training_hub/algorithms/sft.py (2)
  • sft (169-248)
  • SFTAlgorithm (40-161)
src/training_hub/algorithms/peft_extender.py (7)
  • LoRAPEFTExtender (26-90)
  • get_lora_parameters (93-96)
  • apply_lora_defaults (99-102)
  • apply_peft_config (21-23)
  • apply_peft_config (54-90)
  • get_peft_params (16-18)
  • get_peft_params (29-52)
🪛 Ruff (0.14.4)
src/training_hub/algorithms/lora.py

24-27: Avoid specifying long messages outside the exception class

(TRY003)


29-32: Avoid specifying long messages outside the exception class

(TRY003)


34-37: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
src/training_hub/algorithms/lora.py (2)

14-82: LGTM! Import error handling is clear and the torchrun parameter issue has been resolved.

The import error messages provide helpful installation instructions, and the code now correctly uses algorithm_params directly without stripping torchrun parameters. The comments clearly explain that torchrun parameters are handled by the launcher, not the Python training code.


84-127: LGTM! Quantization parameters are now properly wired through.

The previously unused load_in_8bit parameter is now passed to FastLanguageModel.from_pretrained (line 118), and the bnb_4bit_* parameters are correctly read from params and applied when 4-bit quantization is enabled (lines 106-111). The redundant os import issue has also been resolved.

pyproject.toml (1)

11-11: License format modernization approved.

The license field change from license = {text = "Apache-2.0"} to license = "Apache-2.0" aligns with modern PEP 639 conventions and is more concise.
