@@ -20,6 +20,14 @@ def get_open_port() -> int:
2020 return s .getsockname ()[1 ]
2121
2222
class AgentFailedError(Exception):
    """Signal that communication within the launcher/agent group failed.

    Raised by ``LauncherAgentGroup._all_gather`` when the underlying
    collective raises ``RuntimeError``, which occurs if the launcher or any
    agent dies and communication times out.
    """
25+
26+
class WorkerFailedError(Exception):
    """Placeholder for the result of a worker process that failed.

    ``AgentStatus.from_result`` builds one instance from the failure message
    of every entry in ``result.failures`` and stores it as that local rank's
    return value (it is stored as a value there, not raised).
    """
29+
30+
2331@dataclass
2432class LauncherAgentGroup :
2533 launcher_hostname : str
@@ -50,11 +58,15 @@ def _deserialize(self, serialized: bytes) -> Any:
5058
def _all_gather(self, obj: Any) -> list:
    """Gather an object from every rank into a list on every rank.

    Args:
        obj: Object contributed by the calling rank. It is serialized with
            ``self._serialize`` before being handed to the collective.

    Returns:
        A list with ``self.world_size`` entries, each passed through
        ``self._deserialize``.

    Raises:
        AgentFailedError: If the collective raises ``RuntimeError``, which
            occurs if the launcher or any agent dies and communication
            times out.
    """
    payload = self._serialize(obj)
    gathered = [b""] * self.world_size
    try:
        # Guard only the collective: a RuntimeError raised while
        # (de)serializing (RecursionError is a RuntimeError subclass) is a
        # local problem and must not be reported as a dead agent.
        dist.all_gather_object(object_list=gathered, obj=payload, group=self.group)
    except RuntimeError as e:
        # occurs if launcher or any agent dies and communication times out
        raise AgentFailedError(
            "all_gather failed: launcher or an agent died, or communication timed out"
        ) from e
    return [self._deserialize(item) for item in gathered]
5870
5971 def sync_payloads (
6072 self ,
@@ -90,25 +102,25 @@ class AgentPayload:
90102
91103
@dataclass
class ExceptionFromWorker:
    """Wrap an exception so it can be carried as a worker's return value.

    ``AgentStatus.from_result`` reports the state ``"failed"`` when any
    return value is an instance of this class.
    """

    # The wrapped exception object.
    # NOTE(review): presumably the exception raised inside the worker's
    # function -- confirm against the code that constructs this wrapper.
    exception: Exception
95107
96108
97109@dataclass
98110class AgentStatus :
99111 state : Literal ["running" , "failed" , "done" ]
100- return_values : list [Any | WorkerException ] = field (
112+ return_values : list [Any | WorkerFailedError | ExceptionFromWorker ] = field (
101113 default_factory = list
102114 ) # indexed by local rank
103115
104116 @classmethod
105117 def from_result (cls , result : RunProcsResult | None ) -> Self :
106118 if result is None :
107119 return cls (state = "running" )
108-
120+ for local_rank , failure in result .failures .items ():
121+ result .return_values [local_rank ] = WorkerFailedError (failure .message )
109122 return_values = list (result .return_values .values ())
110-
111- failed = any (isinstance (v , WorkerException ) for v in return_values )
123+ failed = any (isinstance (v , ExceptionFromWorker ) for v in return_values )
112124 state = "failed" if failed else "done"
113125
114126 return cls (
0 commit comments