Skip to content

Commit 69c6ec3

Browse files
committed
use str | int for nproc_per_node and rdzv_id
Signed-off-by: Saad Zaher <[email protected]>
1 parent fca6305 commit 69c6ec3

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/training_hub/algorithms/osft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def train(
5858
data_output_dir: str | None = None,
5959

6060
# Torchrun parameters for multi-node support
61-
nproc_per_node: str | None = None,
61+
nproc_per_node: str | int | None = None,
6262
nnodes: int | None = None,
6363
node_rank: int | None = None,
6464
rdzv_id: str | None = None,
@@ -222,7 +222,7 @@ def get_optional_params(self) -> dict[str, type]:
222222
'use_processed_dataset': bool,
223223
'unmask_messages': bool,
224224
'data_output_dir': str,
225-
'nproc_per_node': str,
225+
'nproc_per_node': str | int,
226226
'nnodes': int,
227227
'node_rank': int,
228228
'rdzv_id': str,
@@ -460,7 +460,7 @@ def osft(
460460
save_final_checkpoint: bool | None = None,
461461
num_epochs: int | None = None,
462462
# Torchrun parameters for multi-node support
463-
nproc_per_node: str | None = None,
463+
nproc_per_node: str | int | None = None,
464464
nnodes: int | None = None,
465465
node_rank: int | None = None,
466466
rdzv_id: str | None = None,

src/training_hub/algorithms/sft.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def train(self,
7373
accelerate_full_state_at_epoch: Optional[bool] = None,
7474
checkpoint_at_epoch: Optional[bool] = None,
7575
# Torchrun parameters for multi-node support
76-
nproc_per_node: Optional[int] = None,
76+
nproc_per_node: Optional[str | int] = None,
7777
nnodes: Optional[int] = None,
7878
node_rank: Optional[int] = None,
79-
rdzv_id: Optional[int] = None,
79+
rdzv_id: Optional[str | int] = None,
8080
rdzv_endpoint: Optional[str] = None,
8181
**kwargs) -> Any:
8282
"""Execute SFT training.
@@ -158,10 +158,10 @@ def get_optional_params(self) -> Dict[str, Type]:
158158
'warmup_steps': int,
159159
'accelerate_full_state_at_epoch': bool,
160160
'checkpoint_at_epoch': bool,
161-
'nproc_per_node': int,
161+
'nproc_per_node': str | int,
162162
'nnodes': int,
163163
'node_rank': int,
164-
'rdzv_id': int,
164+
'rdzv_id': str | int,
165165
'rdzv_endpoint': str,
166166
}
167167

@@ -187,10 +187,10 @@ def sft(model_path: str,
187187
accelerate_full_state_at_epoch: Optional[bool] = None,
188188
checkpoint_at_epoch: Optional[bool] = None,
189189
# Torchrun parameters for multi-node support
190-
nproc_per_node: Optional[int] = None,
190+
nproc_per_node: Optional[str | int] = None,
191191
nnodes: Optional[int] = None,
192192
node_rank: Optional[int] = None,
193-
rdzv_id: Optional[int] = None,
193+
rdzv_id: Optional[str | int] = None,
194194
rdzv_endpoint: Optional[str] = None,
195195
**kwargs) -> Any:
196196
"""Convenience function to run SFT training.

0 commit comments

Comments
 (0)