Skip to content

Commit 81f5e91

Browse files
authored
Merge pull request #60 from apoorvkh/exception-propagation
2 parents b9110d6 + aef0aa7 commit 81f5e91

File tree

5 files changed

+83
-90
lines changed

5 files changed

+83
-90
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
4141
line-length = 100
4242
src = ["src", "tests"]
4343
[tool.ruff.lint]
44-
extend-select = ["I"]
44+
select = ["E", "F", "B", "UP", "I"]
4545

4646
[tool.pyright]
4747
include = ["src", "tests"]

src/torchrunx/agent.py

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import sys
88
import tempfile
99
from dataclasses import dataclass
10-
from typing import Callable, Literal
10+
from typing import Any, Callable, Literal
1111

1212
import cloudpickle
1313
import torch
@@ -20,7 +20,7 @@
2020
AgentPayload,
2121
AgentStatus,
2222
LauncherAgentGroup,
23-
LauncherPayload,
23+
WorkerException,
2424
get_open_port,
2525
)
2626

@@ -48,7 +48,7 @@ def from_bytes(cls, serialized: bytes) -> Self:
4848
return cloudpickle.loads(serialized)
4949

5050

51-
def entrypoint(serialized_worker_args: bytes):
51+
def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
5252
worker_args = WorkerArgs.from_bytes(serialized_worker_args)
5353

5454
logger = logging.getLogger()
@@ -93,13 +93,14 @@ def entrypoint(serialized_worker_args: bytes):
9393

9494
logger.debug(f"executing function: {worker_args.function}")
9595

96-
r = worker_args.function()
97-
98-
# flush streams
99-
sys.stdout.flush()
100-
sys.stderr.flush()
101-
102-
return r
96+
try:
97+
return worker_args.function()
98+
except Exception as e:
99+
logger.error(e)
100+
return WorkerException(exception=e)
101+
finally:
102+
sys.stdout.flush()
103+
sys.stderr.flush()
103104

104105

105106
def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
@@ -111,9 +112,8 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
111112
process_id=os.getpid(),
112113
)
113114

114-
all_payloads = launcher_agent_group.sync_payloads(payload=payload)
115-
launcher_payload: LauncherPayload = all_payloads[0] # pyright: ignore[reportAssignmentType]
116-
main_agent_payload: AgentPayload = all_payloads[1] # pyright: ignore[reportAssignmentType]
115+
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
116+
main_agent_payload = agent_payloads[0]
117117

118118
hostname = launcher_payload.hostnames[agent_rank]
119119
worker_world_size = launcher_payload.worker_world_size
@@ -169,20 +169,19 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
169169
logger.info("starting processes")
170170

171171
try:
172-
status = AgentStatus()
172+
status = None
173173
while True:
174-
if status.is_running():
174+
if status is None or status.state == "running":
175175
status = AgentStatus.from_result(
176176
result=ctx.wait(5), worker_global_ranks=worker_global_ranks
177177
)
178178

179179
agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)
180180

181-
if all(s.is_done() for s in agent_statuses):
181+
if all(s.state == "done" for s in agent_statuses):
182+
break
183+
elif any(s.state == "failed" for s in agent_statuses):
182184
break
183-
184-
if any(s.is_failed() for s in agent_statuses):
185-
raise RuntimeError()
186185
except:
187186
raise
188187
finally:

src/torchrunx/launcher.py

Lines changed: 36 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -9,22 +9,21 @@
99
import subprocess
1010
import sys
1111
from collections import ChainMap
12-
from dataclasses import dataclass, field
12+
from dataclasses import dataclass
1313
from functools import partial
1414
from logging import Handler
1515
from multiprocessing import Process
16-
from typing import Any, Callable, Literal
16+
from typing import Any, Callable, Literal, Sequence
1717

1818
import fabric
1919
import torch.distributed as dist
2020

2121
from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
2222
from .logging_utils import LogRecordSocketReceiver, default_handlers
2323
from .utils import (
24-
AgentPayload,
25-
AgentStatus,
2624
LauncherAgentGroup,
2725
LauncherPayload,
26+
WorkerException,
2827
get_open_port,
2928
)
3029

@@ -59,31 +58,29 @@ def execute_command(
5958

6059
@dataclass
6160
class Launcher:
62-
hostnames: list[str] | Literal["auto", "slurm"] = field(default_factory=lambda: ["localhost"])
63-
workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1
61+
hostnames: list[str] | Literal["auto", "slurm"] = "auto"
62+
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"
6463
ssh_config_file: str | os.PathLike | None = None
6564
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
6665
log_handlers: list[Handler] | Literal["auto"] | None = "auto"
67-
env_vars: list[str] = field(
68-
default_factory=lambda: [
69-
"PATH",
70-
"LD_LIBRARY",
71-
"LIBRARY_PATH",
72-
"PYTHON*",
73-
"CUDA*",
74-
"TORCH*",
75-
"PYTORCH*",
76-
"NCCL*",
77-
]
66+
env_vars: Sequence[str] = (
67+
"PATH",
68+
"LD_LIBRARY",
69+
"LIBRARY_PATH",
70+
"PYTHON*",
71+
"CUDA*",
72+
"TORCH*",
73+
"PYTORCH*",
74+
"NCCL*",
7875
)
7976
env_file: str | os.PathLike | None = None
8077
timeout: int = 600
8178

8279
def run(
8380
self,
8481
func: Callable,
85-
func_args: tuple[Any] = tuple(),
86-
func_kwargs: dict[str, Any] = {},
82+
func_args: tuple[Any] | None = None,
83+
func_kwargs: dict[str, Any] | None = None,
8784
) -> dict[int, Any]:
8885
"""
8986
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
@@ -205,6 +202,11 @@ def run(
205202
host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1])
206203
worker_global_ranks.append(list(host_ranks))
207204

205+
if func_args is None:
206+
func_args = tuple()
207+
if func_kwargs is None:
208+
func_kwargs = dict()
209+
208210
payload = LauncherPayload(
209211
fn=partial(func, *func_args, **func_kwargs),
210212
hostnames=self.hostnames,
@@ -214,30 +216,23 @@ def run(
214216
timeout=self.timeout,
215217
)
216218

217-
agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType]
219+
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
218220
agent_pids = [p.process_id for p in agent_payloads]
219221

220222
# loop to monitor agent statuses (until failed or done)
221223
try:
222224
while True:
223-
agent_statuses = launcher_agent_group.sync_agent_statuses(status=AgentStatus())
225+
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
226+
227+
for s in agent_statuses:
228+
if s.state == "failed":
229+
for value in s.return_values.values():
230+
if isinstance(value, WorkerException):
231+
raise value.exception
224232

225-
if all(s.is_done() for s in agent_statuses):
233+
if all(s.state == "done" for s in agent_statuses):
226234
break
227235

228-
if any(s.is_failed() for s in agent_statuses):
229-
# TODO: cleaner way to print these?
230-
e = ""
231-
for i, s in enumerate(agent_statuses):
232-
if s is not None and s.is_failed():
233-
for k, v in s.failures.items():
234-
e += f"Node {i}, local worker {k} exited with error: "
235-
if isinstance(v.message, str):
236-
e += f"{v.message}\n"
237-
else:
238-
e += f"{v.message['message']}\n"
239-
e += f"{v.message['extraInfo']['py_callstack']}\n\n"
240-
raise RuntimeError(e)
241236
except:
242237
# cleanup: SIGTERM all agents
243238
for agent_pid, agent_hostname in zip(agent_pids, self.hostnames):
@@ -259,14 +254,14 @@ def run(
259254

260255
def launch(
261256
func: Callable,
262-
func_args: tuple[Any] = tuple(),
263-
func_kwargs: dict[str, Any] = {},
264-
hostnames: list[str] | Literal["auto", "slurm"] = ["localhost"],
265-
workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1,
257+
func_args: tuple[Any] | None = None,
258+
func_kwargs: dict[str, Any] | None = None,
259+
hostnames: list[str] | Literal["auto", "slurm"] = "auto",
260+
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
266261
ssh_config_file: str | os.PathLike | None = None,
267262
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
268263
log_handlers: list[Handler] | Literal["auto"] = "auto",
269-
env_vars: list[str] = [
264+
env_vars: Sequence[str] = (
270265
"PATH",
271266
"LD_LIBRARY",
272267
"LIBRARY_PATH",
@@ -275,7 +270,7 @@ def launch(
275270
"TORCH*",
276271
"PYTORCH*",
277272
"NCCL*",
278-
],
273+
),
279274
env_file: str | os.PathLike | None = None,
280275
timeout: int = 600,
281276
) -> dict[int, Any]:

src/torchrunx/utils.py

Lines changed: 26 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import cloudpickle
1010
import torch.distributed as dist
1111
from torch.distributed.elastic.multiprocessing.api import RunProcsResult
12-
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
1312
from typing_extensions import Self
1413

1514

@@ -20,6 +19,11 @@ def get_open_port() -> int:
2019
return port
2120

2221

22+
@dataclass
23+
class WorkerException:
24+
exception: Exception
25+
26+
2327
@dataclass
2428
class LauncherPayload:
2529
fn: Callable
@@ -39,33 +43,25 @@ class AgentPayload:
3943

4044
@dataclass
4145
class AgentStatus:
42-
running: bool = True
43-
failed: bool = False
44-
return_values: dict[int, Any] = field(default_factory=dict)
45-
failures: dict[int, ProcessFailure] = field(default_factory=dict)
46-
stdouts: dict[int, str] = field(default_factory=dict)
47-
stderrs: dict[int, str] = field(default_factory=dict)
46+
state: Literal["running", "failed", "done"]
47+
return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
4848

4949
@classmethod
5050
def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self:
5151
if result is None:
52-
return cls()
52+
return cls(state="running")
5353

54-
return cls(
55-
running=False,
56-
failed=result.is_failed(),
57-
return_values={worker_global_ranks[k]: v for k, v in result.return_values.items()},
58-
failures={worker_global_ranks[k]: v for k, v in result.failures.items()},
59-
)
54+
return_values = result.return_values
6055

61-
def is_running(self) -> bool:
62-
return self.running
56+
if any(isinstance(v, WorkerException) for v in return_values.values()):
57+
state = "failed"
58+
else:
59+
state = "done"
6360

64-
def is_failed(self) -> bool:
65-
return self.failed
66-
67-
def is_done(self) -> bool:
68-
return not self.running and not self.failed
61+
return cls(
62+
state=state,
63+
return_values={worker_global_ranks[k]: v for k, v in return_values.items()},
64+
)
6965

7066

7167
@dataclass
@@ -98,15 +94,18 @@ def _deserialize(self, serialized: bytes) -> Any:
9894
def _all_gather(self, object: Any) -> list:
9995
"""gather object from every rank to list on every rank"""
10096
object_bytes = self._serialize(object)
101-
object_list = [bytes()] * self.world_size
97+
object_list = [b""] * self.world_size
10298
dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
10399
object_list = [self._deserialize(o) for o in object_list]
104100
return object_list
105101

106102
def sync_payloads(
107103
self, payload: LauncherPayload | AgentPayload
108-
) -> list[LauncherPayload | AgentPayload]:
109-
return self._all_gather(object=payload)
110-
111-
def sync_agent_statuses(self, status: AgentStatus) -> list[AgentStatus]:
112-
return self._all_gather(object=status)[1:]
104+
) -> tuple[LauncherPayload, list[AgentPayload]]:
105+
payloads = self._all_gather(object=payload)
106+
launcher_payload = payloads[0]
107+
agent_payloads = payloads[1:]
108+
return launcher_payload, agent_payloads
109+
110+
def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
111+
return self._all_gather(object=status)[1:] # [0] is launcher (status=None)

tests/test_CI.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def dist_func():
6161
assert len(log_files) == 3
6262

6363
for file in log_files:
64-
with open(f"{tmp}/{file}", "r") as f:
64+
with open(f"{tmp}/{file}") as f:
6565
contents = f.read()
6666
print(contents)
6767
if file.endswith("[0].log"):
@@ -79,7 +79,7 @@ def error_func():
7979
tmp = tempfile.mkdtemp()
8080
os.environ["TORCHRUNX_DIR"] = tmp
8181

82-
with pytest.raises(RuntimeError) as excinfo:
82+
with pytest.raises(ValueError) as excinfo:
8383
trx.launch(
8484
func=error_func,
8585
func_kwargs={},

0 commit comments

Comments
 (0)