@@ -67,6 +67,7 @@ class TrainingArguments:
         saver_option (Optional[SaverOptions]): Options for checkpoint saver. Defaults to None.
         ckpt_save_arg (Optional[CheckpointSaveArguments]): Arguments for checkpoint save. Defaults to None.
         ckpt_load_arg (Optional[CheckpointLoadArguments]): Arguments for checkpoint load. Defaults to None.
+        mixed_precision (Optional[str]): Mixed precision training mode. Defaults to None. Only "bf16" and "fp16" are supported.
     """

     gradient_accumulation_steps: int = 1
@@ -90,6 +91,7 @@ class TrainingArguments:
     saver_option: Optional[SaverOptions] = None
     ckpt_save_arg: Optional[CheckpointSaveArguments] = None
     ckpt_load_arg: Optional[CheckpointLoadArguments] = None
+    mixed_precision: Optional[str] = None


 class Trainer:
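
For context, a minimal usage sketch of the new field (only the two field names shown come from this diff; any other `TrainingArguments` parameters are assumptions):

    # Hypothetical construction; only "bf16" and "fp16" are accepted,
    # and None (the default) leaves mixed precision disabled.
    args = TrainingArguments(
        gradient_accumulation_steps=4,
        mixed_precision="bf16",
    )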
@@ -197,11 +199,15 @@ def __init__(
         self.dense_lr_scheduler = dense_optimizers[1]
         self.sparse_optimizer = sparse_optimizer
         self.data_to_cuda = data_to_cuda
+        self.mixed_precision = args.mixed_precision
+        if self.mixed_precision is not None:
+            assert self.mixed_precision in ["bf16", "fp16"], "mixed_precision must be 'bf16' or 'fp16'"
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
         init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
         self.accelerator = Accelerator(
             kwargs_handlers=[ddp_kwargs, init_kwargs],
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=self.mixed_precision,
             **kwargs,
         )
         self.gradient_accumulation_steps = args.gradient_accumulation_steps
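
The assertion runs before the Accelerator is built, so an invalid string fails fast instead of surfacing later inside accelerate. A minimal standalone sketch of the Accelerator side (`Accelerator(mixed_precision=...)` is part of the Hugging Face accelerate API; the rest is illustrative):

    from accelerate import Accelerator

    # "bf16"/"fp16" force that precision for autocast regions;
    # None defers to the accelerate launch config/environment.
    accelerator = Accelerator(mixed_precision="bf16")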
@@ -559,7 +565,8 @@ def _train_step(self, data, epoch, metrics):
         self.dense_optimizer.zero_grad()
         if self.sparse_optimizer is not None:
             self.sparse_optimizer.zero_grad()
-        loss = self.model(data)
+        with self.accelerator.autocast():
+            loss = self.model(data)
         metrics.update(epoch=epoch)
         metrics.update(loss=loss)
         metrics.update(get_global_metrics())
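
A sketch of the resulting train-step pattern (`model` and `batch` are illustrative names; `accelerator.autocast()` and `accelerator.backward()` are real accelerate APIs, and autocast() is effectively a no-op when mixed precision is disabled):

    # Forward pass in reduced precision; backward stays outside the
    # context so accelerate can handle loss scaling/unscaling for fp16.
    with accelerator.autocast():
        loss = model(batch)
    accelerator.backward(loss)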