You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/torchrunx/launcher.py
+10-4Lines changed: 10 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -103,13 +103,16 @@ class Launcher:
103
103
defrun(
104
104
self,
105
105
func: Callable,
106
-
func_kwargs: dict[str, Any],
106
+
func_args: tuple[Any] =tuple(),
107
+
func_kwargs: dict[str, Any] = {},
107
108
) ->dict[int, Any]:
108
109
"""
109
110
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
110
111
111
112
:param func: The distributed function to call on all workers
112
113
:type func: Callable
114
+
:param func_args: Any positional arguments to be provided when calling ``func``
115
+
:type func_args: tuple[Any]
113
116
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
114
117
:type func_kwargs: dict[str, Any]
115
118
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
@@ -213,7 +216,7 @@ def run(
213
216
]
214
217
215
218
payload=LauncherPayload(
216
-
fn=partial(func, **func_kwargs),
219
+
fn=partial(func, *func_args, **func_kwargs),
217
220
hostnames=self.hostnames,
218
221
worker_world_size=worker_world_size,
219
222
worker_global_ranks=worker_global_ranks,
@@ -264,7 +267,8 @@ def run(
264
267
265
268
deflaunch(
266
269
func: Callable,
267
-
func_kwargs: dict[str, Any],
270
+
func_args: tuple[Any] =tuple(),
271
+
func_kwargs: dict[str, Any] = {},
268
272
auto: bool=False,
269
273
hostnames: list[str] |None= ["localhost"],
270
274
workers_per_host: int|list[int] |None=1,
@@ -289,6 +293,8 @@ def launch(
289
293
290
294
:param func: The distributed function to call on all workers
291
295
:type func: Callable
296
+
:param func_args: Any positional arguments to be provided when calling ``func``
297
+
:type func_args: tuple[Any]
292
298
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
293
299
:type func_kwargs: dict[str, Any]
294
300
:param auto: Automatically determine allocation sizes, supports Slurm allocation. ``hostnames`` and ``workers_per_host`` are automatically assigned if they're set to ``None``, defaults to None
0 commit comments