Skip to content

Commit 4b227de

Browse files
committed
updates to LaunchResult getters
1 parent 79fb0a8 commit 4b227de

File tree

2 files changed

+49
-26
lines changed

2 files changed

+49
-26
lines changed

src/torchrunx/launcher.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import subprocess
1111
import sys
1212
from dataclasses import dataclass
13-
from functools import partial
13+
from functools import partial, reduce
1414
from logging import Handler
1515
from multiprocessing import Process
16+
from operator import add
1617
from pathlib import Path
1718
from typing import Any, Callable, Literal, overload
1819

@@ -279,7 +280,7 @@ def run( # noqa: C901, PLR0912
279280

280281
# raises specific exception if any agent fails
281282
for s in agent_statuses:
282-
for value in s.return_values.values():
283+
for value in s.return_values:
283284
if isinstance(value, WorkerException):
284285
raise value.exception
285286

@@ -374,22 +375,44 @@ def launch(
374375

375376
class LaunchResult:
376377
def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
377-
self.results = {
378-
hostname: agent_status.return_values
379-
for hostname, agent_status in zip(hostnames, agent_statuses)
380-
}
378+
self.hostnames: list[str] = hostnames
379+
self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses]
381380

381+
@overload
382382
def all(self) -> dict[str, list[Any]]:
383-
return self.results
384-
385-
# all(by='rank')
386-
387-
# value(rank: int)
383+
pass
388384

389385
@overload
390-
def value(self, hostname: str) -> list[Any]:
391-
return list(self.results[hostname].values())
386+
def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]:
387+
pass
392388

393389
@overload
394-
def value(self, hostname: str, rank: int) -> Any:
395-
return self.results[hostname][rank]
390+
def all(self, by: Literal["rank"]) -> list[Any]:
391+
pass
392+
393+
def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
394+
if by == "hostname":
395+
return dict(zip(self.hostnames, self.return_values))
396+
elif by == "rank": # noqa: RET505
397+
return reduce(add, self.return_values)
398+
399+
msg = "Invalid argument: expected by=('hostname' | 'rank')"
400+
raise TypeError(msg)
401+
402+
def values(self, hostname: str) -> list[Any]:
403+
host_idx = self.hostnames.index(hostname)
404+
return self.return_values[host_idx]
405+
406+
def value(self, rank: int) -> Any:
407+
if rank < 0:
408+
msg = f"Rank {rank} must be larger than 0"
409+
raise ValueError(msg)
410+
411+
for values in self.return_values:
412+
if rank >= len(values):
413+
rank -= len(values)
414+
else:
415+
return values[rank]
416+
417+
msg = f"Rank {rank} larger than world_size"
418+
raise ValueError(msg)

src/torchrunx/utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ def get_open_port() -> int:
2020
return s.getsockname()[1]
2121

2222

23-
@dataclass
24-
class WorkerException:
25-
exception: Exception
26-
27-
2823
@dataclass
2924
class LauncherAgentGroup:
3025
launcher_hostname: str
@@ -94,22 +89,27 @@ class AgentPayload:
9489
process_id: int
9590

9691

92+
@dataclass
93+
class WorkerException:
94+
exception: Exception
95+
96+
9797
@dataclass
9898
class AgentStatus:
9999
state: Literal["running", "failed", "done"]
100-
return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
100+
return_values: list[Any | WorkerException] = field(
101+
default_factory=list
102+
) # indexed by local rank
101103

102104
@classmethod
103105
def from_result(cls, result: RunProcsResult | None) -> Self:
104106
if result is None:
105107
return cls(state="running")
106108

107-
return_values = result.return_values
109+
return_values = list(result.return_values.values())
108110

109-
if any(isinstance(v, WorkerException) for v in return_values.values()):
110-
state = "failed"
111-
else:
112-
state = "done"
111+
failed = any(isinstance(v, WorkerException) for v in return_values)
112+
state = "failed" if failed else "done"
113113

114114
return cls(
115115
state=state,

0 commit comments

Comments
 (0)