Skip to content

Commit 61954a0

Browse files
authored
Added func_args to launch
1 parent 1961e45 commit 61954a0

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/torchrunx/launcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,16 @@ class Launcher:
101101
def run(
102102
self,
103103
func: Callable,
104-
func_kwargs: dict[str, Any],
104+
func_args: tuple[Any] = (),
105+
func_kwargs: dict[str, Any] = {},
105106
) -> dict[int, Any]:
106107
"""
107108
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
108109
109110
:param func: The distributed function to call on all workers
110111
:type func: Callable
112+
:param func_args: Any positional arguments to be provided when calling ``func``
113+
:type func_args: tuple[Any]
111114
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
112115
:type func_kwargs: dict[str, Any]
113116
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
@@ -204,7 +207,7 @@ def run(
204207
]
205208

206209
payload = LauncherPayload(
207-
fn=partial(func, **func_kwargs),
210+
fn=partial(func, *func_args, **func_kwargs),
208211
hostnames=self.hostnames,
209212
worker_world_size=worker_world_size,
210213
worker_global_ranks=worker_global_ranks,
@@ -255,7 +258,8 @@ def run(
255258

256259
def launch(
257260
func: Callable,
258-
func_kwargs: dict[str, Any],
261+
func_args: tuple[Any] = (),
262+
func_kwargs: dict[str, Any] = {},
259263
hostnames: list[str] = ["localhost"],
260264
workers_per_host: int | list[int] = 1,
261265
ssh_config_file: str | os.PathLike | None = None,
@@ -279,6 +283,8 @@ def launch(
279283
280284
:param func: The distributed function to call on all workers
281285
:type func: Callable
286+
:param func_args: Any positional arguments to be provided when calling ``func``
287+
:type func_args: tuple[Any]
282288
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
283289
:type func_kwargs: dict[str, Any]
284290
:param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
@@ -310,4 +316,4 @@ def launch(
310316
env_vars=env_vars,
311317
env_file=env_file,
312318
timeout=timeout,
313-
).run(func=func, func_kwargs=func_kwargs)
319+
).run(func=func, func_args=func_args, func_kwargs=func_kwargs)

0 commit comments

Comments
 (0)