Skip to content

Commit a9f8c59

Browse files
authored
Merge pull request #80 from apoorvkh/logging-process-fix
Fix for logging server serialization problems
2 parents 677edcb + 82fa176 commit a9f8c59

File tree

3 files changed

+81
-41
lines changed

3 files changed

+81
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "torchrunx"
7-
version = "0.2.0"
7+
version = "0.2.1"
88
authors = [
99
{name = "Apoorv Khandelwal", email = "[email protected]"},
1010
{name = "Peter Curtin", email = "[email protected]"},

src/torchrunx/launcher.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass
1717
from functools import partial, reduce
1818
from logging import Handler
19-
from multiprocessing import Process
19+
from multiprocessing import Event, Process
2020
from operator import add
2121
from pathlib import Path
2222
from typing import Any, Callable, Literal
@@ -34,7 +34,7 @@
3434
ExceptionFromWorker,
3535
WorkerFailedError,
3636
)
37-
from .utils.logging import LogRecordSocketReceiver, default_handlers
37+
from .utils.logging import LoggingServerArgs, start_logging_server
3838

3939

4040
@dataclass
@@ -76,27 +76,32 @@ def run( # noqa: C901, PLR0912
7676

7777
launcher_hostname = socket.getfqdn()
7878
launcher_port = get_open_port()
79+
logging_port = get_open_port()
7980
world_size = len(hostnames) + 1
8081

81-
log_receiver = None
82+
stop_logging_event = None
8283
log_process = None
8384
launcher_agent_group = None
8485
agent_payloads = None
8586

8687
try:
8788
# Start logging server (receives LogRecords from agents/workers)
8889

89-
log_receiver = _build_logging_server(
90+
logging_server_args = LoggingServerArgs(
9091
log_handlers=log_handlers,
91-
launcher_hostname=launcher_hostname,
92+
logging_hostname=launcher_hostname,
93+
logging_port=logging_port,
9294
hostnames=hostnames,
9395
workers_per_host=workers_per_host,
9496
log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
9597
log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001
9698
)
9799

100+
stop_logging_event = Event()
101+
98102
log_process = Process(
99-
target=log_receiver.serve_forever,
103+
target=start_logging_server,
104+
args=(logging_server_args.serialize(), stop_logging_event),
100105
daemon=True,
101106
)
102107

@@ -109,7 +114,7 @@ def run( # noqa: C901, PLR0912
109114
command=_build_launch_command(
110115
launcher_hostname=launcher_hostname,
111116
launcher_port=launcher_port,
112-
logger_port=log_receiver.port,
117+
logger_port=logging_port,
113118
world_size=world_size,
114119
rank=i + 1,
115120
env_vars=(self.default_env_vars + self.extra_env_vars),
@@ -166,11 +171,10 @@ def run( # noqa: C901, PLR0912
166171
if all(s.state == "done" for s in agent_statuses):
167172
break
168173
finally:
169-
if log_receiver is not None:
170-
log_receiver.shutdown()
171-
if log_process is not None:
172-
log_receiver.server_close()
173-
log_process.kill()
174+
if stop_logging_event is not None:
175+
stop_logging_event.set()
176+
if log_process is not None:
177+
log_process.kill()
174178

175179
if launcher_agent_group is not None:
176180
launcher_agent_group.shutdown()
@@ -307,31 +311,6 @@ def _resolve_workers_per_host(
307311
return workers_per_host
308312

309313

310-
def _build_logging_server(
311-
log_handlers: list[Handler] | Literal["auto"] | None,
312-
launcher_hostname: str,
313-
hostnames: list[str],
314-
workers_per_host: list[int],
315-
log_dir: str | os.PathLike,
316-
log_level: int,
317-
) -> LogRecordSocketReceiver:
318-
if log_handlers is None:
319-
log_handlers = []
320-
elif log_handlers == "auto":
321-
log_handlers = default_handlers(
322-
hostnames=hostnames,
323-
workers_per_host=workers_per_host,
324-
log_dir=log_dir,
325-
log_level=log_level,
326-
)
327-
328-
return LogRecordSocketReceiver(
329-
host=launcher_hostname,
330-
port=get_open_port(),
331-
handlers=log_handlers,
332-
)
333-
334-
335314
def _build_launch_command(
336315
launcher_hostname: str,
337316
launcher_port: int,

src/torchrunx/utils/logging.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from __future__ import annotations
44

55
__all__ = [
6-
"LogRecordSocketReceiver",
6+
"LoggingServerArgs",
7+
"start_logging_server",
78
"redirect_stdio_to_logger",
89
"log_records_to_socket",
910
"add_filter_to_handler",
@@ -25,12 +26,14 @@
2526
from logging.handlers import SocketHandler
2627
from pathlib import Path
2728
from socketserver import StreamRequestHandler, ThreadingTCPServer
28-
from typing import TYPE_CHECKING
29+
from typing import TYPE_CHECKING, Literal
2930

31+
import cloudpickle
3032
from typing_extensions import Self
3133

3234
if TYPE_CHECKING:
3335
import os
36+
from multiprocessing.synchronize import Event as EventClass
3437

3538
## Handler utilities
3639

@@ -139,7 +142,7 @@ def default_handlers(
139142
## Launcher utilities
140143

141144

142-
class LogRecordSocketReceiver(ThreadingTCPServer):
145+
class _LogRecordSocketReceiver(ThreadingTCPServer):
143146
"""TCP server for receiving Agent/Worker log records in Launcher.
144147
145148
Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process).
@@ -180,6 +183,64 @@ def shutdown(self) -> None:
180183
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]
181184

182185

186+
@dataclass
class LoggingServerArgs:
    """Settings required to launch a :class:`_LogRecordSocketReceiver`.

    Instances are wrapped via :meth:`serialize` before being handed to the
    process that runs :func:`start_logging_server`.
    """

    # Explicit handlers, "auto" to have default handlers built by the server
    # process, or None for no handlers at all.
    log_handlers: list[Handler] | Literal["auto"] | None
    # Host and port the receiver is constructed with.
    logging_hostname: str
    logging_port: int
    # The remaining fields are forwarded to ``default_handlers`` when
    # ``log_handlers == "auto"``.
    hostnames: list[str]
    workers_per_host: list[int]
    log_dir: str | os.PathLike
    log_level: int

    def serialize(self) -> SerializedLoggingServerArgs:
        """Wrap these arguments in a :class:`SerializedLoggingServerArgs`.

        The wrapper is what gets passed to the new process.
        """
        wrapped = SerializedLoggingServerArgs(args=self)
        return wrapped
201+
202+
203+
class SerializedLoggingServerArgs:
    """Holds a ``cloudpickle`` dump of a :class:`LoggingServerArgs`.

    NOTE(review): presumably this exists so arguments that the default
    pickler cannot handle still reach the logging process -- confirm.
    """

    def __init__(self, args: LoggingServerArgs) -> None:
        """Pickle ``args`` with ``cloudpickle`` and keep the payload in ``self.bytes``."""
        payload = cloudpickle.dumps(args)
        self.bytes = payload

    def deserialize(self) -> LoggingServerArgs:
        """Return the :class:`LoggingServerArgs` rebuilt from ``self.bytes``."""
        restored: LoggingServerArgs = cloudpickle.loads(self.bytes)
        return restored
209+
210+
211+
def start_logging_server(
    serialized_args: SerializedLoggingServerArgs,
    stop_event: EventClass,
) -> None:
    """Serve :class:`_LogRecordSocketReceiver` until stop event triggered.

    Args:
        serialized_args: Serialized :class:`LoggingServerArgs` describing the
            handlers to attach and the host/port to bind.
        stop_event: Event set by the caller to request shutdown. This function
            blocks until it is set.

    The receiver is served from a background thread. ``serve_forever()`` never
    returns on its own, so calling it directly here would make the stop-event
    check and the shutdown calls below unreachable. ``shutdown()`` also has to
    be issued from a different thread than the one running ``serve_forever()``.
    """
    # Function-scope import keeps this block self-contained.
    import threading

    args: LoggingServerArgs = serialized_args.deserialize()

    log_handlers = args.log_handlers
    if log_handlers is None:
        log_handlers = []
    elif log_handlers == "auto":
        log_handlers = default_handlers(
            hostnames=args.hostnames,
            workers_per_host=args.workers_per_host,
            log_dir=args.log_dir,
            log_level=args.log_level,
        )

    log_receiver = _LogRecordSocketReceiver(
        host=args.logging_hostname,
        port=args.logging_port,
        handlers=log_handlers,
    )

    serving_thread = threading.Thread(
        target=log_receiver.serve_forever,
        daemon=True,  # never keep the process alive if shutdown stalls
    )
    serving_thread.start()

    try:
        # Sleep until asked to stop, instead of busy-waiting on is_set().
        stop_event.wait()
    finally:
        log_receiver.shutdown()
        log_receiver.server_close()
242+
243+
183244
## Agent/worker utilities
184245

185246

0 commit comments

Comments
 (0)