|
19 | 19 | from multiprocessing import Process
|
20 | 20 | from operator import add
|
21 | 21 | from pathlib import Path
|
22 |
| -from typing import Any, Callable, Literal, overload |
| 22 | +from typing import Any, Callable, Literal |
23 | 23 |
|
24 | 24 | import fabric
|
25 | 25 | import torch.distributed as dist
|
@@ -267,85 +267,21 @@ class LaunchResult:
|
267 | 267 | hostnames: list[str]
|
268 | 268 | return_values: list[list[Any]]
|
269 | 269 |
|
270 |
| - @overload |
271 |
| - def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]: |
272 |
| - pass |
| 270 | + def by_hostname(self) -> dict[str, list[Any]]: |
| 271 | + """All return values from workers, indexed by host and local rank.""" |
| 272 | + return dict(zip(self.hostnames, self.return_values)) |
273 | 273 |
|
274 |
| - @overload |
275 |
| - def all(self, by: Literal["rank"]) -> list[Any]: |
276 |
| - pass |
| 274 | + def by_rank(self) -> list[Any]: |
| 275 | + """All return values from workers, indexed by global rank.""" |
| 276 | + return reduce(add, self.return_values) |
277 | 277 |
|
278 |
| - def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: |
279 |
| - """Get return values from all workers.""" |
280 |
| - if by == "hostname": |
281 |
| - return dict(zip(self.hostnames, self.return_values)) |
282 |
| - elif by == "rank": # noqa: RET505 |
283 |
| - return reduce(add, self.return_values) |
284 |
| - else: |
285 |
| - msg = "Invalid argument for 'by'. Must be 'hostname' or 'rank'." |
286 |
| - raise TypeError(msg) |
| 278 | + def get(self, hostname: str, rank: int) -> Any: |
| 279 | + """Get return value from worker (indexed by host and local rank).""" |
| 280 | + return self.return_values[self.hostnames.index(hostname)][rank] |
287 | 281 |
|
288 |
| - @overload |
289 |
| - def get(self, hostname: None, rank: None) -> dict[str, list[Any]]: ... |
290 |
| - |
291 |
| - @overload |
292 |
| - def get(self, hostname: None, rank: int) -> Any: ... |
293 |
| - |
294 |
| - @overload |
295 |
| - def get(self, hostname: None, rank: list[int]) -> list[Any]: ... |
296 |
| - |
297 |
| - @overload |
298 |
| - def get(self, hostname: str, rank: None) -> list[Any]: ... |
299 |
| - |
300 |
| - @overload |
301 |
| - def get(self, hostname: list[str], rank: None) -> dict[str, list[Any]]: ... |
302 |
| - |
303 |
| - @overload |
304 |
| - def get(self, hostname: str, rank: int) -> Any: ... |
305 |
| - |
306 |
| - @overload |
307 |
| - def get(self, hostname: str, rank: list[int]) -> list[Any]: ... |
308 |
| - |
309 |
| - @overload |
310 |
| - def get(self, hostname: list[str], rank: int) -> list[Any]: ... |
311 |
| - |
312 |
| - @overload |
313 |
| - def get(self, hostname: list[str], rank: list[int]) -> dict[str, list[Any]]: ... |
314 |
| - |
315 |
| - def get( # noqa: PLR0911 |
316 |
| - self, |
317 |
| - hostname: str | list[str] | None = None, |
318 |
| - rank: int | list[int] | None = None, |
319 |
| - ) -> dict[str, list[Any]] | list[Any] | Any: |
320 |
| - """Get return values from selected workers.""" |
321 |
| - if hostname is None and isinstance(rank, int): |
322 |
| - return self.all(by="rank")[rank] |
323 |
| - |
324 |
| - if hostname is None and isinstance(rank, list): |
325 |
| - _values = self.all(by="rank") |
326 |
| - return [_values[r] for r in rank] |
327 |
| - |
328 |
| - if isinstance(hostname, str) and rank is None: |
329 |
| - self.return_values[self.hostnames.index(hostname)] |
330 |
| - |
331 |
| - if isinstance(hostname, list) and rank is None: |
332 |
| - return {h: self.get(hostname=h) for h in hostname} |
333 |
| - |
334 |
| - if isinstance(hostname, str) and isinstance(rank, int): |
335 |
| - return self.get(hostname=hostname)[rank] |
336 |
| - |
337 |
| - if isinstance(hostname, str) and isinstance(rank, list): |
338 |
| - return self.get(hostname=hostname)[rank] |
339 |
| - |
340 |
| - if isinstance(hostname, list) and isinstance(rank, int): |
341 |
| - return [self.get(hostname=h)[rank] for h in hostname] |
342 |
| - |
343 |
| - if isinstance(hostname, list) and isinstance(rank, list): |
344 |
| - _values = self.get(hostname=hostname) |
345 |
| - return {h: [_values[h][r] for r in rank] for h in hostname} |
346 |
| - |
347 |
| - # remaining case: hostname=None, rank=None |
348 |
| - return self.all(by="hostname") |
| 282 | + def rank(self, idx: int) -> Any: |
| 283 | + """Get return value from worker (indexed by global rank).""" |
| 284 | + return self.by_rank()[idx] |
349 | 285 |
|
350 | 286 |
|
351 | 287 | def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
|
|
0 commit comments