Skip to content

Commit 3a68eb6

Browse files
committed
removed overloading from LaunchResult
1 parent f967218 commit 3a68eb6

File tree

3 files changed

+15
-79
lines changed

3 files changed

+15
-79
lines changed

src/torchrunx/launcher.py

Lines changed: 13 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from multiprocessing import Process
2020
from operator import add
2121
from pathlib import Path
22-
from typing import Any, Callable, Literal, overload
22+
from typing import Any, Callable, Literal
2323

2424
import fabric
2525
import torch.distributed as dist
@@ -267,85 +267,21 @@ class LaunchResult:
267267
hostnames: list[str]
268268
return_values: list[list[Any]]
269269

270-
@overload
271-
def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]:
272-
pass
270+
def by_hostname(self) -> dict[str, list[Any]]:
271+
"""All return values from workers, indexed by host and local rank."""
272+
return dict(zip(self.hostnames, self.return_values))
273273

274-
@overload
275-
def all(self, by: Literal["rank"]) -> list[Any]:
276-
pass
274+
def by_rank(self) -> list[Any]:
275+
"""All return values from workers, indexed by global rank."""
276+
return reduce(add, self.return_values)
277277

278-
def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
279-
"""Get return values from all workers."""
280-
if by == "hostname":
281-
return dict(zip(self.hostnames, self.return_values))
282-
elif by == "rank": # noqa: RET505
283-
return reduce(add, self.return_values)
284-
else:
285-
msg = "Invalid argument for 'by'. Must be 'hostname' or 'rank'."
286-
raise TypeError(msg)
278+
def get(self, hostname: str, rank: int) -> Any:
279+
"""Get return value from worker (indexed by host and local rank)."""
280+
return self.return_values[self.hostnames.index(hostname)][rank]
287281

288-
@overload
289-
def get(self, hostname: None, rank: None) -> dict[str, list[Any]]: ...
290-
291-
@overload
292-
def get(self, hostname: None, rank: int) -> Any: ...
293-
294-
@overload
295-
def get(self, hostname: None, rank: list[int]) -> list[Any]: ...
296-
297-
@overload
298-
def get(self, hostname: str, rank: None) -> list[Any]: ...
299-
300-
@overload
301-
def get(self, hostname: list[str], rank: None) -> dict[str, list[Any]]: ...
302-
303-
@overload
304-
def get(self, hostname: str, rank: int) -> Any: ...
305-
306-
@overload
307-
def get(self, hostname: str, rank: list[int]) -> list[Any]: ...
308-
309-
@overload
310-
def get(self, hostname: list[str], rank: int) -> list[Any]: ...
311-
312-
@overload
313-
def get(self, hostname: list[str], rank: list[int]) -> dict[str, list[Any]]: ...
314-
315-
def get( # noqa: PLR0911
316-
self,
317-
hostname: str | list[str] | None = None,
318-
rank: int | list[int] | None = None,
319-
) -> dict[str, list[Any]] | list[Any] | Any:
320-
"""Get return values from selected workers."""
321-
if hostname is None and isinstance(rank, int):
322-
return self.all(by="rank")[rank]
323-
324-
if hostname is None and isinstance(rank, list):
325-
_values = self.all(by="rank")
326-
return [_values[r] for r in rank]
327-
328-
if isinstance(hostname, str) and rank is None:
329-
self.return_values[self.hostnames.index(hostname)]
330-
331-
if isinstance(hostname, list) and rank is None:
332-
return {h: self.get(hostname=h) for h in hostname}
333-
334-
if isinstance(hostname, str) and isinstance(rank, int):
335-
return self.get(hostname=hostname)[rank]
336-
337-
if isinstance(hostname, str) and isinstance(rank, list):
338-
return self.get(hostname=hostname)[rank]
339-
340-
if isinstance(hostname, list) and isinstance(rank, int):
341-
return [self.get(hostname=h)[rank] for h in hostname]
342-
343-
if isinstance(hostname, list) and isinstance(rank, list):
344-
_values = self.get(hostname=hostname)
345-
return {h: [_values[h][r] for r in rank] for h in hostname}
346-
347-
# remaining case: hostname=None, rank=None
348-
return self.all(by="hostname")
282+
def rank(self, idx: int) -> Any:
283+
"""Get return value from worker (indexed by global rank)."""
284+
return self.by_rank()[idx]
349285

350286

351287
def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:

tests/test_ci.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def dist_func() -> torch.Tensor:
3737
backend="gloo", # log_dir="./test_logs"
3838
)
3939

40-
assert torch.all(r.get(rank=0) == r.get(rank=1))
40+
assert torch.all(r.rank(0) == r.rank(1))
4141

4242

4343
def test_logging() -> None:

tests/test_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_launch() -> None:
1313
workers_per_host="slurm",
1414
)
1515

16-
result_values = result.all(by="rank")
16+
result_values = result.by_rank()
1717

1818
t = True
1919
for i in range(len(result_values)):

0 commit comments

Comments
 (0)