|
10 | 10 | import subprocess
|
11 | 11 | import sys
|
12 | 12 | from dataclasses import dataclass
|
13 |
| -from functools import partial |
| 13 | +from functools import partial, reduce |
14 | 14 | from logging import Handler
|
15 | 15 | from multiprocessing import Process
|
| 16 | +from operator import add |
16 | 17 | from pathlib import Path
|
17 | 18 | from typing import Any, Callable, Literal, overload
|
18 | 19 |
|
@@ -279,7 +280,7 @@ def run( # noqa: C901, PLR0912
|
279 | 280 |
|
280 | 281 | # raises specific exception if any agent fails
|
281 | 282 | for s in agent_statuses:
|
282 |
| - for value in s.return_values.values(): |
| 283 | + for value in s.return_values: |
283 | 284 | if isinstance(value, WorkerException):
|
284 | 285 | raise value.exception
|
285 | 286 |
|
@@ -374,22 +375,44 @@ def launch(
|
374 | 375 |
|
375 | 376 | class LaunchResult:
|
376 | 377 | def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
|
377 |
| - self.results = { |
378 |
| - hostname: agent_status.return_values |
379 |
| - for hostname, agent_status in zip(hostnames, agent_statuses) |
380 |
| - } |
| 378 | + self.hostnames: list[str] = hostnames |
| 379 | + self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses] |
381 | 380 |
|
| 381 | + @overload |
382 | 382 | def all(self) -> dict[str, list[Any]]:
|
383 |
| - return self.results |
384 |
| - |
385 |
| - # all(by='rank') |
386 |
| - |
387 |
| - # value(rank: int) |
| 383 | + pass |
388 | 384 |
|
389 | 385 | @overload
|
390 |
| - def value(self, hostname: str) -> list[Any]: |
391 |
| - return list(self.results[hostname].values()) |
| 386 | + def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]: |
| 387 | + pass |
392 | 388 |
|
393 | 389 | @overload
|
394 |
| - def value(self, hostname: str, rank: int) -> Any: |
395 |
| - return self.results[hostname][rank] |
| 390 | + def all(self, by: Literal["rank"]) -> list[Any]: |
| 391 | + pass |
| 392 | + |
| 393 | + def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: |
| 394 | + if by == "hostname": |
| 395 | + return dict(zip(self.hostnames, self.return_values)) |
| 396 | + elif by == "rank": # noqa: RET505 |
| 397 | + return reduce(add, self.return_values) |
| 398 | + |
| 399 | + msg = "Invalid argument: expected by=('hostname' | 'rank')" |
| 400 | + raise TypeError(msg) |
| 401 | + |
| 402 | + def values(self, hostname: str) -> list[Any]: |
| 403 | + host_idx = self.hostnames.index(hostname) |
| 404 | + return self.return_values[host_idx] |
| 405 | + |
| 406 | + def value(self, rank: int) -> Any: |
| 407 | + if rank < 0: |
| 408 | + msg = f"Rank {rank} must be larger than 0" |
| 409 | + raise ValueError(msg) |
| 410 | + |
| 411 | + for values in self.return_values: |
| 412 | + if rank >= len(values): |
| 413 | + rank -= len(values) |
| 414 | + else: |
| 415 | + return values[rank] |
| 416 | + |
| 417 | + msg = f"Rank {rank} larger than world_size" |
| 418 | + raise ValueError(msg) |
0 commit comments