@@ -73,10 +73,10 @@ def train(self,
7373 accelerate_full_state_at_epoch : Optional [bool ] = None ,
7474 checkpoint_at_epoch : Optional [bool ] = None ,
7575 # Torchrun parameters for multi-node support
76- nproc_per_node : Optional [int ] = None ,
76+ nproc_per_node : Optional [str | int ] = None ,
7777 nnodes : Optional [int ] = None ,
7878 node_rank : Optional [int ] = None ,
79- rdzv_id : Optional [int ] = None ,
79+ rdzv_id : Optional [str | int ] = None ,
8080 rdzv_endpoint : Optional [str ] = None ,
8181 ** kwargs ) -> Any :
8282 """Execute SFT training.
@@ -158,10 +158,10 @@ def get_optional_params(self) -> Dict[str, Type]:
158158 'warmup_steps' : int ,
159159 'accelerate_full_state_at_epoch' : bool ,
160160 'checkpoint_at_epoch' : bool ,
161- 'nproc_per_node' : int ,
161+ 'nproc_per_node' : str | int ,
162162 'nnodes' : int ,
163163 'node_rank' : int ,
164- 'rdzv_id' : int ,
164+ 'rdzv_id' : str | int ,
165165 'rdzv_endpoint' : str ,
166166 }
167167
@@ -187,10 +187,10 @@ def sft(model_path: str,
187187 accelerate_full_state_at_epoch : Optional [bool ] = None ,
188188 checkpoint_at_epoch : Optional [bool ] = None ,
189189 # Torchrun parameters for multi-node support
190- nproc_per_node : Optional [int ] = None ,
190+ nproc_per_node : Optional [str | int ] = None ,
191191 nnodes : Optional [int ] = None ,
192192 node_rank : Optional [int ] = None ,
193- rdzv_id : Optional [int ] = None ,
193+ rdzv_id : Optional [str | int ] = None ,
194194 rdzv_endpoint : Optional [str ] = None ,
195195 ** kwargs ) -> Any :
196196 """Convenience function to run SFT training.
0 commit comments