Skip to content

Commit a080dad

Browse files
authored
Merge branch 'main' into automatic
2 parents 74b9709 + 0bd7964 commit a080dad

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

examples/slurm_poc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
def test_launch():
1212
result = torchrunx.launch(
1313
func=simple_matmul,
14-
func_kwargs={},
1514
hostnames=torchrunx.slurm_hosts(),
1615
workers_per_host=torchrunx.slurm_workers(),
1716
)
@@ -22,7 +21,7 @@ def test_launch():
2221
print("PASS")
2322

2423

25-
def simple_matmul():
24+
def simple_matmul(test):
2625
rank = int(os.environ["RANK"])
2726
local_rank = int(os.environ["LOCAL_RANK"])
2827
device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")
@@ -36,7 +35,7 @@ def simple_matmul():
3635

3736
i = torch.rand((500, 100), device=device) # batch, dim
3837
o = torch.matmul(i, w)
39-
38+
print(test)
4039
dist.all_reduce(o, op=dist.ReduceOp.SUM)
4140
print(i)
4241
return o.detach().cpu()

src/torchrunx/launcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,16 @@ class Launcher:
103103
def run(
104104
self,
105105
func: Callable,
106-
func_kwargs: dict[str, Any],
106+
func_args: tuple[Any] = tuple(),
107+
func_kwargs: dict[str, Any] = {},
107108
) -> dict[int, Any]:
108109
"""
109110
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
110111
111112
:param func: The distributed function to call on all workers
112113
:type func: Callable
114+
:param func_args: Any positional arguments to be provided when calling ``func``
115+
:type func_args: tuple[Any]
113116
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
114117
:type func_kwargs: dict[str, Any]
115118
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
@@ -213,7 +216,7 @@ def run(
213216
]
214217

215218
payload = LauncherPayload(
216-
fn=partial(func, **func_kwargs),
219+
fn=partial(func, *func_args, **func_kwargs),
217220
hostnames=self.hostnames,
218221
worker_world_size=worker_world_size,
219222
worker_global_ranks=worker_global_ranks,
@@ -264,7 +267,8 @@ def run(
264267

265268
def launch(
266269
func: Callable,
267-
func_kwargs: dict[str, Any],
270+
func_args: tuple[Any] = tuple(),
271+
func_kwargs: dict[str, Any] = {},
268272
auto: bool = False,
269273
hostnames: list[str] | None = ["localhost"],
270274
workers_per_host: int | list[int] | None = 1,
@@ -289,6 +293,8 @@ def launch(
289293
290294
:param func: The distributed function to call on all workers
291295
:type func: Callable
296+
:param func_args: Any positional arguments to be provided when calling ``func``
297+
:type func_args: tuple[Any]
292298
:param func_kwargs: Any keyword arguments to be provided when calling ``func``
293299
:type func_kwargs: dict[str, Any]
294300
:param auto: Automatically determine allocation sizes, supports Slurm allocation. ``hostnames`` and ``workers_per_host`` are automatically assigned if they're set to ``None``, defaults to None
@@ -323,4 +329,4 @@ def launch(
323329
env_vars=env_vars,
324330
env_file=env_file,
325331
timeout=timeout,
326-
).run(func=func, func_kwargs=func_kwargs)
332+
).run(func=func, func_args=func_args, func_kwargs=func_kwargs)

0 commit comments

Comments (0)