
Commit 10d564c

add osft
1 parent c0e5616 commit 10d564c

File tree

6 files changed: +380 -2 lines changed


pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ dependencies = [
     "packaging>=24.2",
     "wheel>=0.43",
     "instructlab-training>=0.11.1",
+    "rhai-innovation-mini-trainer @ git+https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer.git@a22c81c52660b1304e8fbef8bef548fcae93aba0",  # TODO: update this once we have a release
     "torch>=2.6.0",
     "numba>=0.50",
     "datasets>=2.15.0",
@@ -50,6 +51,7 @@ dynamic = ["version"]
 [project.optional-dependencies]
 cuda = [
     "instructlab-training[cuda]>=0.11.1",
+    "rhai-innovation-mini-trainer[cuda] @ git+https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer.git@a22c81c52660b1304e8fbef8bef548fcae93aba0",  # TODO: update this once we have a release
     "flash-attn>=2.8",
     "einops>=0.8"
 ]

src/training_hub/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 from .algorithms import Algorithm, Backend, AlgorithmRegistry, create_algorithm
 from .algorithms.sft import sft, SFTAlgorithm, InstructLabTrainingSFTBackend
+from .algorithms.osft import OSFTAlgorithm, MiniTrainerOSFTBackend
 from .hub_core import welcome
 
 __all__ = [
@@ -10,5 +11,7 @@
     'sft',
     'SFTAlgorithm',
     'InstructLabTrainingSFTBackend',
+    'OSFTAlgorithm',
+    'MiniTrainerOSFTBackend',
     'welcome'
-]
+]

src/training_hub/algorithms/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,10 @@ def get_required_params(self) -> Dict[str, Type]:
         """Return dictionary of required parameter names and their types."""
         pass
 
+    @abstractmethod
+    def get_optional_params(self) -> Dict[str, Type]:
+        """Return dictionary of optional parameter names and their types."""
+        pass
 
 class Backend(ABC):
     """Base class for all backend implementations."""
@@ -89,4 +93,4 @@ def create_algorithm(algorithm_name: str, backend_name: str = None, **kwargs) ->
     backend_class = AlgorithmRegistry.get_backend(algorithm_name, backend_name)
     backend_instance = backend_class()
 
-    return algorithm_class(backend=backend_instance, **kwargs)
+    return algorithm_class(backend=backend_instance, **kwargs)
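Since get_optional_params is now abstract alongside get_required_params, every Algorithm subclass has to implement both before it can be instantiated through create_algorithm. A minimal sketch of the resulting contract (the class, parameter names, and train body below are illustrative, not part of this commit):

from typing import Dict, Type

from training_hub.algorithms import Algorithm, Backend


class ExampleAlgorithm(Algorithm):
    """Hypothetical subclass illustrating the two abstract methods."""

    def __init__(self, backend: Backend, **kwargs):
        self.backend = backend
        self.kwargs = kwargs

    def train(self, **kwargs):
        return self.backend.execute_training(kwargs)

    def get_required_params(self) -> Dict[str, Type]:
        return {'model_path': str, 'output_dir': str}

    def get_optional_params(self) -> Dict[str, Type]:
        # without this override, instantiation now fails because the
        # method is declared @abstractmethod on the base class
        return {'seed': int}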
src/training_hub/algorithms/osft.py

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
import os
import shutil
from typing import Literal, get_origin, get_args, Union
from itertools import chain
from dataclasses import fields

import datasets
from training_hub.algorithms import Algorithm, Backend, AlgorithmRegistry
from training_hub.utils import format_type_name

_AlgorithmParamsKeyLiteral = Literal['parameters', 'renames']

class OSFTAlgorithm(Algorithm):
    """Orthogonal Subspace Fine-Tuning algorithm."""

    def __init__(self, backend: Backend, **kwargs) -> None:
        self.backend = backend
        self.kwargs = kwargs

    def train(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_tokens_per_gpu: int,
        max_seq_len: int,
        learning_rate: float,
        output_dir: str,
        unfreeze_rank_ratio: float,

        # patterns that we want to match against when selecting
        # modules for OSFT
        target_patterns: list[str] | None = None,

        # settings for training mode
        seed: int | None = None,
        use_liger: bool | None = None,
        unmask_messages: bool | None = None,

        # learning rate scheduler
        lr_scheduler: str | None = None,
        warmup_steps: int | None = None,
        lr_scheduler_kwargs: dict[str, str] | None = None,

        # checkpointing
        checkpoint_at_epoch: bool | None = None,
        save_final_checkpoint: bool | None = None,

        # parameters for the training mode
        epochs: int | None = None,

        # Torchrun parameters for multi-node support
        nproc_per_node: int | None = None,
        nnodes: int | None = None,
        node_rank: int | None = None,
        rdzv_id: int | None = None,
        rdzv_endpoint: str | None = None,
        **kwargs,
    ) -> any:
        """Execute OSFT training using MiniTrainer."""

        required_params = {
            'model_path': model_path,
            'data_path': data_path,
            'batch_size': batch_size,
            'max_tokens_per_gpu': max_tokens_per_gpu,
            'max_seq_len': max_seq_len,
            'learning_rate': learning_rate,
            'output_dir': output_dir,
            'unfreeze_rank_ratio': unfreeze_rank_ratio,
        }

        optional_params = {
            'target_patterns': target_patterns,

            # for data processing
            'unmask_messages': unmask_messages,

            # scheduler params
            'lr_scheduler': lr_scheduler,
            'lr_scheduler_kwargs': lr_scheduler_kwargs,
            'warmup_steps': warmup_steps,

            # checkpointing settings
            'checkpoint_at_epoch': checkpoint_at_epoch,
            'save_final_checkpoint': save_final_checkpoint,

            # mini trainer supports a few different modes, but we fix this one for simplicity;
            # another mode can be selected by overriding it via kwargs
            'training_mode': 'epoch',
            'epochs': epochs,

            'use_liger': use_liger,
            'seed': seed,

            # torchrun params
            'nproc_per_node': nproc_per_node,
            'nnodes': nnodes,
            'node_rank': node_rank,
            'rdzv_id': rdzv_id,
            'rdzv_endpoint': rdzv_endpoint,
        }

        # data_params = {
        #     'data_path': data_path,
        #     'unmask_messages': unmask_messages,
        #     # this should be something like `{output_dir}/_internal`, but we should
        #     # delegate the responsibility for that onto the backend algorithm
        #     # Also, we don't pass this to renames since this is also being used as-is in the
        #     # main backend.
        #     'data_output_path': output_dir
        # }

        # we keep a separate mapping of which parameters will be renamed,
        # so this function can make assertions about algorithm requirements
        # while the backend can more easily use the original arguments without needing
        # to re-map in several places
        renames = {
            'use_liger': 'use_liger_kernels',
            'warmup_steps': 'num_warmup_steps',
            'target_patterns': 'osft_target_patterns',
            'unfreeze_rank_ratio': 'osft_unfreeze_rank_ratio',
            'model_path': 'model_name_or_path',
            'epochs': 'max_epochs',
        }

        # now that we've set everything up, do validation
        for required_param in self.get_required_params().keys():
            if required_param not in required_params:
                raise ValueError(f"error: required parameter not provided: {required_param}")

        all_params = dict(
            **required_params,
            **optional_params,
            **kwargs,
        )

        # validate types of all parameters
        self._validate_param_types(all_params)

        # now we can build the algorithm params
        algorithm_params = dict(
            parameters=all_params,
            renames=renames
        )

        return self.backend.execute_training(algorithm_params)

    def get_required_params(self) -> dict[str, type]:
        """Return dictionary of required parameter names and their types."""
        return {
            'model_path': str,
            'data_path': str,
            'unfreeze_rank_ratio': float,
            'batch_size': int,
            'max_tokens_per_gpu': int,
            'max_seq_len': int,
            'learning_rate': float,
            'output_dir': str,
        }

    def get_optional_params(self) -> dict[str, type]:
        """Return dictionary of optional parameter names and their types."""
        return {
            'target_patterns': list[str],
            'unmask_messages': bool,
            'lr_scheduler': str,
            'lr_scheduler_kwargs': dict[str, str],
            'warmup_steps': int,
            'checkpoint_at_epoch': bool,
            'save_final_checkpoint': bool,
            'training_mode': str,
            'max_epochs': int,
            'use_liger': bool,
            'seed': int,
            'nproc_per_node': int,
            'nnodes': int,
            'node_rank': int,
            'rdzv_id': int,
            'rdzv_endpoint': str,
        }

    def _validate_param_types(self, params: dict[str, any]):
        """Type-check given parameters, handling modern Python typing constructs."""
        required_param_types = self.get_required_params()
        optional_param_types = self.get_optional_params()
        all_param_types = {**required_param_types, **optional_param_types}

        for param, value in params.items():
            # use 'any' here to handle the case when the param is not defined by
            # either optional or required
            param_type = all_param_types.get(param, any)

            # allow optional params to be None
            if param in optional_param_types and value is None:
                continue  # None is allowed for optional params

            if not self._check_type(value, param_type):
                err_msg = (
                    f"error: param '{param}' received unexpected type, "
                    f"expected '{format_type_name(param_type)}' but got '{format_type_name(type(value))}'"
                )
                raise ValueError(err_msg)

    def _check_type(self, value, expected_type) -> bool:
        """Check if value matches expected_type, handling modern typing constructs."""
        # Handle 'any' type (accepts anything)
        if expected_type is any:
            return True

        # Handle basic types that work with isinstance
        try:
            if isinstance(expected_type, type):
                return isinstance(value, expected_type)
        except TypeError:
            pass  # Fall through to handle complex types

        # Handle parameterized generics and unions
        origin = get_origin(expected_type)
        args = get_args(expected_type)

        # Handle Union types (including X | None syntax)
        if origin is Union:
            return any(self._check_type(value, arg) for arg in args)

        # Handle list types
        if origin is list:
            if not isinstance(value, list):
                return False
            if args and value:  # Check element types if specified and list is not empty
                element_type = args[0]
                return all(self._check_type(item, element_type) for item in value)
            return True

        # Handle dict types
        if origin is dict:
            if not isinstance(value, dict):
                return False
            if args and value:  # Check key/value types if specified and dict is not empty
                key_type, val_type = args[0], args[1]
                return all(
                    self._check_type(k, key_type) and self._check_type(v, val_type)
                    for k, v in value.items()
                )
            return True

        # Fallback for basic isinstance check
        try:
            return isinstance(value, expected_type)
        except TypeError:
            # If we can't check the type, assume it's valid
            return True


class MiniTrainerOSFTBackend(Backend):
    """MiniTrainer backend for OSFT algorithm."""

    def execute_training(self, algorithm_params: dict[_AlgorithmParamsKeyLiteral, dict[str, any]]) -> any:
        """Execute OSFT training using MiniTrainer."""
        from mini_trainer import run_training, TrainingArgs, TorchrunArgs, TrainingMode

        # mini trainer doesn't do its own data processing, so we use the one from
        # instructlab training
        from instructlab.training.data_process import process_messages_into_input_ids

        # first we need to process data
        output_dir = algorithm_params['parameters']['output_dir']
        data_output_path = os.path.join(output_dir, '_internal_data_processing')
        os.makedirs(data_output_path, exist_ok=True)

        # if we received unmask then we need to add that
        training_params = algorithm_params['parameters']
        processing_data_path = training_params['data_path']
        unmask_messages = training_params.get('unmask_messages', False)
        if unmask_messages:
            ds = datasets.load_dataset(training_params['data_path'], split='train')
            ds = ds.map(lambda _: {"unmask": True})
            processing_data_path = os.path.join(data_output_path, 'intermediate_data.jsonl')
            ds.to_json(processing_data_path)

        # now we process the data
        process_messages_into_input_ids(
            data_path=processing_data_path,
            data_output_path=data_output_path,
            model_path=training_params['model_path'],
            max_seq_len=training_params['max_seq_len'],
            num_cpu_procs=8,
        )

        # above function will save to this file, so we pass this to the trainer
        processed_data_path = os.path.join(data_output_path, 'data.jsonl')

        # This section converts the parameters we get from the Algorithm into ones which work
        # for this backend (mini-trainer). Since the algorithm renames parameters for simplicity,
        # we map each param back into its original name then place it into the correct dataclass.
        renames = algorithm_params['renames']
        training_params = {renames.get(k, k): v for k, v in algorithm_params['parameters'].items()}
        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
        training_args_fields = {f.name for f in fields(TrainingArgs)}

        # adjust arguments to align with the API definition
        training_args_pre = {k: v for k, v in training_params.items() if k in training_args_fields and v is not None}
        training_args_pre['data_path'] = processed_data_path  # replaces raw data path with processed
        training_args_pre['training_mode'] = TrainingMode(training_args_pre['training_mode'])
        torchrun_args_pre = {k: v for k, v in training_params.items() if k in torchrun_args_fields and v is not None}

        # now we run training
        return run_training(
            torch_args=TorchrunArgs(**torchrun_args_pre),
            train_args=TrainingArgs(**training_args_pre),
        )


AlgorithmRegistry.register_algorithm('osft', OSFTAlgorithm)
AlgorithmRegistry.register_backend('osft', 'mini-trainer', MiniTrainerOSFTBackend)
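For context, a minimal usage sketch of the newly registered algorithm through the existing create_algorithm factory (the paths and hyperparameter values are placeholders; unlike SFT, this commit does not add an osft() convenience function):

from training_hub import create_algorithm

# 'osft' and 'mini-trainer' are the names registered at the bottom of osft.py
osft = create_algorithm('osft', backend_name='mini-trainer')

result = osft.train(
    model_path='/path/to/base-model',            # placeholder path
    data_path='/path/to/messages-data.jsonl',    # placeholder path
    batch_size=8,                                # placeholder hyperparameters below
    max_tokens_per_gpu=32768,
    max_seq_len=4096,
    learning_rate=2e-5,
    output_dir='/path/to/output',
    unfreeze_rank_ratio=0.25,
    epochs=1,
)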

src/training_hub/algorithms/sft.py

Lines changed: 16 additions & 0 deletions
@@ -152,6 +152,22 @@ def get_required_params(self) -> Dict[str, Type]:
             'max_batch_len': int,
         }
 
+    def get_optional_params(self) -> Dict[str, Type]:
+        """Return optional parameters for SFT."""
+        return {
+            'max_tokens_per_gpu': int,
+            'data_output_dir': str,
+            'save_samples': int,
+            'warmup_steps': int,
+            'accelerate_full_state_at_epoch': bool,
+            'checkpoint_at_epoch': bool,
+            'nproc_per_node': int,
+            'nnodes': int,
+            'node_rank': int,
+            'rdzv_id': int,
+            'rdzv_endpoint': str,
+        }
+
 
 # Register the algorithm and backend
 AlgorithmRegistry.register_algorithm('sft', SFTAlgorithm)
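A short, illustrative sketch of how the new method can be used to inspect an algorithm's parameter surface (constructing the instance directly here mirrors what create_algorithm does internally):

from training_hub import SFTAlgorithm, InstructLabTrainingSFTBackend

sft_algo = SFTAlgorithm(backend=InstructLabTrainingSFTBackend())
print(sft_algo.get_required_params())   # required names and types, e.g. 'max_batch_len': int
print(sft_algo.get_optional_params())   # the optional-parameter mapping added in this commit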
