@@ -101,13 +101,16 @@ class Launcher:
101
101
def run (
102
102
self ,
103
103
func : Callable ,
104
- func_kwargs : dict [str , Any ],
104
+ func_args : tuple [Any ] = tuple (),
105
+ func_kwargs : dict [str , Any ] = {},
105
106
) -> dict [int , Any ]:
106
107
"""
107
108
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
108
109
109
110
:param func: The distributed function to call on all workers
110
111
:type func: Callable
112
+ :param func_args: Any positional arguments to be provided when calling ``func``
113
+ :type func_args: tuple[Any]
111
114
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
112
115
:type func_kwargs: dict[str, Any]
113
116
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
@@ -204,7 +207,7 @@ def run(
204
207
]
205
208
206
209
payload = LauncherPayload (
207
- fn = partial (func , ** func_kwargs ),
210
+ fn = partial (func , * func_args , * *func_kwargs ),
208
211
hostnames = self .hostnames ,
209
212
worker_world_size = worker_world_size ,
210
213
worker_global_ranks = worker_global_ranks ,
@@ -255,7 +258,8 @@ def run(
255
258
256
259
def launch (
257
260
func : Callable ,
258
- func_kwargs : dict [str , Any ],
261
+ func_args : tuple [Any ] = tuple (),
262
+ func_kwargs : dict [str , Any ] = {},
259
263
hostnames : list [str ] = ["localhost" ],
260
264
workers_per_host : int | list [int ] = 1 ,
261
265
ssh_config_file : str | os .PathLike | None = None ,
@@ -279,6 +283,8 @@ def launch(
279
283
280
284
:param func: The distributed function to call on all workers
281
285
:type func: Callable
286
+ :param func_args: Any positional arguments to be provided when calling ``func``
287
+ :type func_args: tuple[Any]
282
288
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
283
289
:type func_kwargs: dict[str, Any]
284
290
:param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
@@ -310,4 +316,4 @@ def launch(
310
316
env_vars = env_vars ,
311
317
env_file = env_file ,
312
318
timeout = timeout ,
313
- ).run (func = func , func_kwargs = func_kwargs )
319
+ ).run (func = func , func_args = func_args , func_kwargs = func_kwargs )
0 commit comments