diff --git a/examples/slurm_poc.py b/examples/slurm_poc.py index 3a11ea75..c91bc060 100644 --- a/examples/slurm_poc.py +++ b/examples/slurm_poc.py @@ -11,7 +11,7 @@ def test_launch(): result = torchrunx.launch( func=simple_matmul, - func_kwargs={}, + func_args=("test",), hostnames=torchrunx.slurm_hosts(), workers_per_host=torchrunx.slurm_workers(), ) @@ -22,7 +21,7 @@ def test_launch(): print("PASS") -def simple_matmul(): +def simple_matmul(test): rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu") @@ -36,7 +35,7 @@ def simple_matmul(): i = torch.rand((500, 100), device=device) # batch, dim o = torch.matmul(i, w) - + print(test) dist.all_reduce(o, op=dist.ReduceOp.SUM) print(i) return o.detach().cpu() diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 346f303b..94219ea1 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -101,13 +101,16 @@ class Launcher: def run( self, func: Callable, - func_kwargs: dict[str, Any], + func_args: tuple[Any, ...] = tuple(), + func_kwargs: dict[str, Any] = {}, ) -> dict[int, Any]: """ Launch a distributed PyTorch function on the specified nodes. 
See :mod:`torchrunx.launch` :param func: The distributed function to call on all workers :type func: Callable + :param func_args: Any positional arguments to be provided when calling ``func`` + :type func_args: tuple[Any, ...] :param func_kwargs: Any keyword arguments to be provided when calling ``func`` :type func_kwargs: dict[str, Any] :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func`` @@ -204,7 +207,7 @@ def run( ] payload = LauncherPayload( - fn=partial(func, **func_kwargs), + fn=partial(func, *func_args, **func_kwargs), hostnames=self.hostnames, worker_world_size=worker_world_size, worker_global_ranks=worker_global_ranks, @@ -255,7 +258,8 @@ def run( def launch( func: Callable, - func_kwargs: dict[str, Any], + func_args: tuple[Any, ...] = tuple(), + func_kwargs: dict[str, Any] = {}, hostnames: list[str] = ["localhost"], workers_per_host: int | list[int] = 1, ssh_config_file: str | os.PathLike | None = None, @@ -279,6 +283,8 @@ def launch( :param func: The distributed function to call on all workers :type func: Callable + :param func_args: Any positional arguments to be provided when calling ``func`` + :type func_args: tuple[Any, ...] :param func_kwargs: Any keyword arguments to be provided when calling ``func`` :type func_kwargs: dict[str, Any] :param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"] @@ -310,4 +316,4 @@ def launch( env_vars=env_vars, env_file=env_file, timeout=timeout, - ).run(func=func, func_kwargs=func_kwargs) + ).run(func=func, func_args=func_args, func_kwargs=func_kwargs)