@@ -745,6 +745,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
+    failure = False
     try:
         process = StreamablePopen(
             f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
@@ -755,19 +756,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         print("Training subprocess interrupted by user.")
         interrupt = e
     except Exception as e:
-        print(f"An error occurred: {str(e)}")
+        print("Unexpected exception received during distributed training")
         interrupt = e
     finally:
         if "process" not in locals() or process is None:
             return
-        if process.poll() == 0:
-            print("\033[92mTraining subprocess exited successfully! 🎉\033[0m")
+
+        failure = process.poll() != 0
+        if not failure:
+            print("\033[92mOperation completed successfully! 🎉\033[0m")
         else:
             print(
                 "\033[91mTraining subprocess has not exited yet. Sending SIGTERM.\033[0m"
             )

-            print("Sending interrupt signal to Training subprocess.")
             process.terminate()
             try:
                 print("Waiting for process to exit, 60s...")
@@ -779,8 +781,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
                 process.kill()

         if interrupt:
-            print(f"Error caught from training subprocess.: {interrupt}")
             raise interrupt
+        if failure:
+            raise RuntimeError(
+                "Suffered a failure during distributed training. Please see the training logs for more context."
+            )


 if __name__ == "__main__":
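
For orientation, the control flow this commit arrives at can be sketched in isolation. The following is a minimal sketch, not the module's actual implementation: it assumes plain subprocess.Popen in place of StreamablePopen (defined elsewhere in this codebase), and run_command and its timeout parameter are hypothetical names introduced for illustration.

# Minimal sketch of the failure-handling contract introduced by this commit.
# Hypothetical names: run_command, timeout. Uses subprocess.Popen instead of
# the module's StreamablePopen.
import subprocess


def run_command(command: list[str], timeout: int = 60) -> None:
    process = None
    interrupt: KeyboardInterrupt | Exception | None = None
    failure = False
    try:
        process = subprocess.Popen(command)
        process.wait()
    except KeyboardInterrupt as e:
        print("Subprocess interrupted by user.")
        interrupt = e
    except Exception as e:
        print("Unexpected exception received while running subprocess")
        interrupt = e
    finally:
        if process is None:
            return

        # poll() returns None while the child is still running, so both
        # "still running" and "exited nonzero" count as failure here.
        failure = process.poll() != 0
        if not failure:
            print("Operation completed successfully!")
        else:
            # Make sure the child is gone before deciding how to report it.
            process.terminate()
            try:
                process.wait(timeout)
            except subprocess.TimeoutExpired:
                process.kill()

        # Re-raise a pending interrupt first, then convert a bad exit code
        # into an exception the caller can catch.
        if interrupt:
            raise interrupt
        if failure:
            raise RuntimeError("Subprocess failed; see its logs for more context.")


if __name__ == "__main__":
    run_command(["python", "-c", "print('hello from the child process')"])

Note the ordering at the end of the finally block: a pending KeyboardInterrupt is re-raised before the exit code is turned into a RuntimeError, so Ctrl-C still surfaces as an interrupt even when the child also died with a nonzero status.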