Capture stdout/stderr from OS fd instead of Python's sys.stdout / sys.stderr #90


Merged (4 commits) on Mar 17, 2025
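The motivation behind this change: redirecting `sys.stdout` only intercepts writes that go through the Python-level stream object. Output written straight to file descriptor 1 (e.g. by C extensions, `os.write`, or child processes) bypasses it entirely. A minimal sketch of the difference, illustrative and not from this diff:

```python
import io
import os
from contextlib import redirect_stdout

buffer = io.StringIO()
with redirect_stdout(buffer):
    print("via sys.stdout")     # captured: print() writes to the Python object
    os.write(1, b"via fd 1\n")  # not captured: bypasses sys.stdout entirely

print(repr(buffer.getvalue()))  # only 'via sys.stdout\n'
```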
3 changes: 3 additions & 0 deletions README.md
@@ -69,6 +69,9 @@ def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None
We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using **`torchrunx`**!

```python
+import logging
+logging.basicConfig(level=logging.INFO)

import torchrunx

launcher = torchrunx.Launcher(
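The README now configures logging before using torchrunx because, as the `basicConfig` line above suggests, forwarded agent/worker output is re-emitted through the stdlib `logging` module on the launcher, so some root-logger handler must exist for it to appear. An illustrative variant (not part of this diff, and `train.log` is a hypothetical path) that also keeps a copy on disk:

```python
import logging

# Any handler on the root logger receives the forwarded output,
# assuming forwarded records reach the root logger as implied above.
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),           # echo to the console
        logging.FileHandler("train.log"),  # hypothetical log file path
    ],
)
```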
2 changes: 1 addition & 1 deletion docs/source/usage/logging.md
@@ -1,6 +1,6 @@
# Custom Logging

-We forward all agent and worker logs (i.e. from {mod}`logging`, {obj}`sys.stdout`, and {obj}`sys.stderr`) to the launcher process.
+We forward all agent and worker logs (i.e. from {mod}`logging`, `stdout`, and `stderr`) to the launcher process.

## Defaults

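The "forward to the launcher process" part happens over a socket: the imports in `log_streaming.py` below (`SocketHandler`, `StreamRequestHandler`, `ThreadingTCPServer`, `struct`, `pickle`) point at the standard library's socket-logging pattern. A minimal sketch of that pattern, with a hypothetical port; torchrunx's actual wiring lives in the file below:

```python
import logging
import pickle
import struct
from logging.handlers import SocketHandler
from socketserver import StreamRequestHandler, ThreadingTCPServer

class RecordHandler(StreamRequestHandler):
    def handle(self) -> None:
        # SocketHandler sends each record as a 4-byte big-endian length
        # prefix followed by a pickled LogRecord attribute dict.
        while True:
            header = self.connection.recv(4)
            if len(header) < 4:
                return
            (length,) = struct.unpack(">L", header)
            payload = b""
            while len(payload) < length:
                chunk = self.connection.recv(length - len(payload))
                if not chunk:
                    return
                payload += chunk
            record = logging.makeLogRecord(pickle.loads(payload))
            logging.getLogger(record.name).handle(record)

# Launcher side: receive and re-emit records from agents/workers
# (run server.serve_forever() in a background thread).
server = ThreadingTCPServer(("127.0.0.1", 9020), RecordHandler)  # hypothetical port

# Agent/worker side: ship every log record to the launcher.
logging.getLogger().addHandler(SocketHandler("127.0.0.1", 9020))
```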
49 changes: 30 additions & 19 deletions src/torchrunx/utils/log_streaming.py
@@ -10,17 +10,17 @@
]

import logging
+import os
import pickle
import signal
import struct
import sys
-from contextlib import redirect_stderr, redirect_stdout
from dataclasses import dataclass
-from io import StringIO
from logging import Handler, Logger
from logging.handlers import SocketHandler
from multiprocessing.synchronize import Event as EventClass
from socketserver import StreamRequestHandler, ThreadingTCPServer
+from threading import Thread
from typing import Callable

import cloudpickle
@@ -129,24 +129,35 @@ def start_logging_server(serialized_args: bytes, stop_event: EventClass) -> None

def redirect_stdio_to_logger(logger: Logger) -> None:
"""Redirect stderr/stdout: send output to logger at every flush."""

-    class _LoggingStream(StringIO):
-        def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
-            super().__init__()
-            self.logger = logger
-            self.level = level

-        def flush(self) -> None:
-            super().flush()  # At "flush" to avoid logs of partial bytes
-            value = self.getvalue()
-            if value != "":
-                self.logger.log(self.level, value)
-                self.truncate(0)
-                self.seek(0)

    logging.captureWarnings(capture=True)
-    redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
-    redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()

+    def redirect_fd_to_logger(read_fd: int, level: int) -> None:
+        for line in os.fdopen(read_fd):
+            logger.log(level, line.rstrip())

+    # create (r, w) pipe and start logging all outputs from r
+    read_out_fd, write_out_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_out_fd, "level": logging.INFO},
+        daemon=True,
+    ).start()
+    # flush buffer before redirecting stdout
+    sys.stdout.flush()
+    # pipe: r <-> stdout instead of r <-> w
+    os.dup2(write_out_fd, sys.stdout.fileno())  # set stdout fd to pipe
+    os.close(write_out_fd)

+    # repeat for stderr
+    read_err_fd, write_err_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_err_fd, "level": logging.ERROR},
+        daemon=True,
+    ).start()
+    sys.stderr.flush()
+    os.dup2(write_err_fd, sys.stderr.fileno())
+    os.close(write_err_fd)


@dataclass
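For reference, here is the fd-redirection pattern from `redirect_stdio_to_logger` above as a self-contained script. This is a sketch mirroring the diff, not torchrunx's public API:

```python
import logging
import os
import sys
import time
from threading import Thread

# basicConfig's default handler writes to stderr, so logging the captured
# stdout lines does not feed back into the redirected stdout pipe.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("fd-capture")

def pump(read_fd: int, level: int) -> None:
    # Read lines from the pipe's read end and re-emit them via the logger.
    with os.fdopen(read_fd) as stream:
        for line in stream:
            logger.log(level, line.rstrip())

read_fd, write_fd = os.pipe()
Thread(target=pump, args=(read_fd, logging.INFO), daemon=True).start()

sys.stdout.flush()                      # don't lose already-buffered output
os.dup2(write_fd, sys.stdout.fileno())  # fd 1 now points at the pipe's write end
os.close(write_fd)                      # fd 1 keeps the pipe's write end open

print("hello from print()")             # captured by the reader thread
os.write(1, b"hello from os.write\n")   # captured too: the OS-level fd is redirected
sys.stdout.flush()
time.sleep(0.1)                         # give the daemon thread time to drain
```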