@@ -96,6 +96,7 @@ class Launcher:
         ]
     )
     env_file: str | os.PathLike | None = None
+    timeout: int = 600
 
     def run(
         self,
@@ -209,6 +210,7 @@ def run(
             worker_global_ranks=worker_global_ranks,
             worker_log_files=worker_log_files,
             backend=self.backend,
+            timeout=self.timeout,
         )
 
         agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:]  # pyright: ignore[reportAssignmentType]
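The new `timeout` field is forwarded to the agents as part of the launcher payload. On the agent side it would ultimately be handed to `torch.distributed`, which expects a `datetime.timedelta`; a minimal sketch of that pattern (the helper name and its exact call site are assumptions, not part of this diff):

```python
import datetime

import torch.distributed as dist


def init_worker_process_group(
    backend: str, rank: int, world_size: int, timeout: int = 600
) -> None:
    # Hypothetical agent-side helper: wraps the integer seconds carried in
    # the payload into the timedelta that torch.distributed requires.
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=timeout),
    )
```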
@@ -270,6 +272,7 @@ def launch(
         "NCCL*",
     ],
     env_file: str | os.PathLike | None = None,
+    timeout: int = 600,
 ) -> dict[int, Any]:
     """
     Launch a distributed PyTorch function on the specified nodes.
@@ -292,6 +295,8 @@ def launch(
     :type env_vars: list[str], optional
     :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
     :type env_file: str | os.PathLike | None, optional
+    :param timeout: Worker process group timeout in seconds, defaults to 600
+    :type timeout: int, optional
     :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
     :return: A dictionary mapping worker ranks to their output
     :rtype: dict[int, Any]
@@ -304,4 +309,5 @@ def launch(
         log_dir=log_dir,
         env_vars=env_vars,
         env_file=env_file,
+        timeout=timeout,
     ).run(func=func, func_kwargs=func_kwargs)
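Taken together, the change exposes the worker process group timeout at both entry points. A minimal usage sketch, assuming the package is imported as `torchrunx` and using a placeholder `train` function (all other arguments left at their defaults):

```python
import torchrunx


def train() -> int:
    # Placeholder distributed function; real training code would go here.
    return 0


# Via the functional entry point: wait up to 20 minutes instead of 10.
results = torchrunx.launch(func=train, func_kwargs={}, timeout=1200)

# Equivalently, via the Launcher dataclass.
results = torchrunx.Launcher(timeout=1200).run(func=train, func_kwargs={})
```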