Skip to content

Commit 0bd7964

Browse files
authored
Merge pull request #51 from apoorvkh/func-args
Func args
2 parents 1961e45 + 38d1627 commit 0bd7964

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

examples/slurm_poc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
def test_launch():
1212
result = torchrunx.launch(
1313
func=simple_matmul,
14-
func_kwargs={},
1514
hostnames=torchrunx.slurm_hosts(),
1615
workers_per_host=torchrunx.slurm_workers(),
1716
)
@@ -22,7 +21,7 @@ def test_launch():
2221
print("PASS")
2322

2423

25-
def simple_matmul():
24+
def simple_matmul(test):
2625
rank = int(os.environ["RANK"])
2726
local_rank = int(os.environ["LOCAL_RANK"])
2827
device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")
@@ -36,7 +35,7 @@ def simple_matmul():
3635

3736
i = torch.rand((500, 100), device=device) # batch, dim
3837
o = torch.matmul(i, w)
39-
38+
print(test)
4039
dist.all_reduce(o, op=dist.ReduceOp.SUM)
4140
print(i)
4241
return o.detach().cpu()

src/torchrunx/launcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,16 @@ class Launcher:
101101
def run(
102102
self,
103103
func: Callable,
104-
func_kwargs: dict[str, Any],
104+
func_args: tuple[Any] = tuple(),
105+
func_kwargs: dict[str, Any] = {},
105106
) -> dict[int, Any]:
106107
"""
107108
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
108109
109110
:param func: The distributed function to call on all workers
110111
:type func: Callable
112+
:param func_args: Any positional arguments to be provided when calling ``func``
113+
:type func_args: tuple[Any]
111114
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
112115
:type func_kwargs: dict[str, Any]
113116
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
@@ -204,7 +207,7 @@ def run(
204207
]
205208

206209
payload = LauncherPayload(
207-
fn=partial(func, **func_kwargs),
210+
fn=partial(func, *func_args, **func_kwargs),
208211
hostnames=self.hostnames,
209212
worker_world_size=worker_world_size,
210213
worker_global_ranks=worker_global_ranks,
@@ -255,7 +258,8 @@ def run(
255258

256259
def launch(
257260
func: Callable,
258-
func_kwargs: dict[str, Any],
261+
func_args: tuple[Any] = tuple(),
262+
func_kwargs: dict[str, Any] = {},
259263
hostnames: list[str] = ["localhost"],
260264
workers_per_host: int | list[int] = 1,
261265
ssh_config_file: str | os.PathLike | None = None,
@@ -279,6 +283,8 @@ def launch(
279283
280284
:param func: The distributed function to call on all workers
281285
:type func: Callable
286+
:param func_args: Any positional arguments to be provided when calling ``func``
287+
:type func_args: tuple[Any]
282288
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
283289
:type func_kwargs: dict[str, Any]
284290
:param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
@@ -310,4 +316,4 @@ def launch(
310316
env_vars=env_vars,
311317
env_file=env_file,
312318
timeout=timeout,
313-
).run(func=func, func_kwargs=func_kwargs)
319+
).run(func=func, func_args=func_args, func_kwargs=func_kwargs)

0 commit comments

Comments
 (0)