Commit 2b7c073

Merge pull request #90 from apoorvkh/capture-stdio-from-fd
Capture stdout/stderr from OS fd instead of Python's `sys.stdout` / `sys.stderr`
2 parents 10e5d33 + 62bd030

3 files changed: +34, -20 lines

README.md

Lines changed: 3 additions & 0 deletions

@@ -69,6 +69,9 @@ def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | N
 We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using **`torchrunx`**!
 
 ```python
+import logging
+logging.basicConfig(level=logging.INFO)
+
 import torchrunx
 
 launcher = torchrunx.Launcher(
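
The added `basicConfig` lines matter because, with no logging configuration at all, the root logger defaults to WARNING and the fallback `lastResort` handler likewise only emits WARNING and above, so INFO-level agent/worker output forwarded to the launcher would presumably be dropped silently. A minimal stdlib-only sketch of that behaviour (no torchrunx involved):

```python
import logging

# Dropped: unconfigured logging defaults to WARNING, so INFO never gets emitted.
logging.getLogger("demo").info("dropped: unconfigured logging defaults to WARNING")

# What the README now adds before importing torchrunx.
logging.basicConfig(level=logging.INFO)

# Shown: the root logger now has a handler and level that accept INFO records.
logging.getLogger("demo").info("shown: root handler and level now accept INFO")
```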

docs/source/usage/logging.md

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Custom Logging
 
-We forward all agent and worker logs (i.e. from {mod}`logging`, {obj}`sys.stdout`, and {obj}`sys.stderr`) to the launcher process.
+We forward all agent and worker logs (i.e. from {mod}`logging`, `stdout`, and `stderr`) to the launcher process.
 
 ## Defaults
 
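
To make that sentence concrete, all three channels it names are captured on the worker side and forwarded to the launcher. A tiny illustration (this `worker` function is hypothetical, not part of the torchrunx API):

```python
import logging
import sys

def worker() -> None:
    # Each of these ends up at the launcher via the forwarded worker logs.
    logging.getLogger(__name__).info("via the logging module")
    print("via stdout")
    print("via stderr", file=sys.stderr)
```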

src/torchrunx/utils/log_streaming.py

Lines changed: 30 additions & 19 deletions

@@ -10,17 +10,17 @@
 ]
 
 import logging
+import os
 import pickle
 import signal
 import struct
 import sys
-from contextlib import redirect_stderr, redirect_stdout
 from dataclasses import dataclass
-from io import StringIO
 from logging import Handler, Logger
 from logging.handlers import SocketHandler
 from multiprocessing.synchronize import Event as EventClass
 from socketserver import StreamRequestHandler, ThreadingTCPServer
+from threading import Thread
 from typing import Callable
 
 import cloudpickle

@@ -129,24 +129,35 @@ def start_logging_server(serialized_args: bytes, stop_event: EventClass) -> None
 
 def redirect_stdio_to_logger(logger: Logger) -> None:
     """Redirect stderr/stdout: send output to logger at every flush."""
-
-    class _LoggingStream(StringIO):
-        def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
-            super().__init__()
-            self.logger = logger
-            self.level = level
-
-        def flush(self) -> None:
-            super().flush()  # At "flush" to avoid logs of partial bytes
-            value = self.getvalue()
-            if value != "":
-                self.logger.log(self.level, value)
-            self.truncate(0)
-            self.seek(0)
-
     logging.captureWarnings(capture=True)
-    redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
-    redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
+
+    def redirect_fd_to_logger(read_fd: int, level: int) -> None:
+        for line in os.fdopen(read_fd):
+            logger.log(level, line.rstrip())
+
+    # create (r, w) pipe and start logging all outputs from r
+    read_out_fd, write_out_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_out_fd, "level": logging.INFO},
+        daemon=True,
+    ).start()
+    # flush buffer before redirecting stdout
+    sys.stdout.flush()
+    # pipe: r <-> stdout instead of r <-> w
+    os.dup2(write_out_fd, sys.stdout.fileno())  # set stdout fd to pipe
+    os.close(write_out_fd)
+
+    # repeat for stderr
+    read_err_fd, write_err_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_err_fd, "level": logging.ERROR},
+        daemon=True,
+    ).start()
+    sys.stderr.flush()
+    os.dup2(write_err_fd, sys.stderr.fileno())
+    os.close(write_err_fd)
 
 
 @dataclass
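
For context, here is a minimal standalone sketch of the same fd-level capture pattern (nothing below is torchrunx code; the logger name and `_pump` helper are made up for illustration). Because `os.dup2` rewires file descriptor 1 itself, output that never touches `sys.stdout`, such as a raw `os.write(1, ...)` from a C extension or an inherited fd in a subprocess, is still captured, which the removed `StringIO`-based `redirect_stdout` approach could not do.

```python
# Minimal sketch of the fd-level capture pattern used in the diff above.
# Standalone demo only; names here are illustrative, not torchrunx APIs.
import logging
import os
import sys
import time
from threading import Thread

logging.basicConfig(level=logging.INFO)  # handler writes to stderr, so no feedback loop
logger = logging.getLogger("fd-capture-demo")


def _pump(read_fd: int, level: int) -> None:
    # Read the pipe line by line until the write end is closed.
    with os.fdopen(read_fd) as stream:
        for line in stream:
            logger.log(level, line.rstrip())


read_fd, write_fd = os.pipe()
Thread(target=_pump, args=(read_fd, logging.INFO), daemon=True).start()

sys.stdout.flush()                      # flush anything buffered before the swap
os.dup2(write_fd, sys.stdout.fileno())  # fd 1 now points at the pipe's write end
os.close(write_fd)                      # the duplicated fd keeps the pipe open

print("hello via print()", flush=True)  # fd 1 -> pipe -> reader thread -> logger
os.write(1, b"raw write to fd 1\n")     # bypasses sys.stdout, still captured

time.sleep(0.5)                         # let the daemon reader drain before exit
```

As in the diff, the reader runs on a `daemon=True` thread so a still-open pipe never blocks interpreter shutdown.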
