diff --git a/README.md b/README.md
index 09a10f9..969d2a5 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,9 @@ def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None:
 We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using **`torchrunx`**!

 ```python
+import logging
+logging.basicConfig(level=logging.INFO)
+
 import torchrunx

 launcher = torchrunx.Launcher(
diff --git a/docs/source/usage/logging.md b/docs/source/usage/logging.md
index 0e76493..2352db1 100644
--- a/docs/source/usage/logging.md
+++ b/docs/source/usage/logging.md
@@ -1,6 +1,6 @@
 # Custom Logging

-We forward all agent and worker logs (i.e. from {mod}`logging`, {obj}`sys.stdout`, and {obj}`sys.stderr`) to the launcher process.
+We forward all agent and worker logs (i.e. from {mod}`logging`, `stdout`, and `stderr`) to the launcher process.

 ## Defaults

diff --git a/src/torchrunx/utils/log_streaming.py b/src/torchrunx/utils/log_streaming.py
index af5ff52..69f8d58 100644
--- a/src/torchrunx/utils/log_streaming.py
+++ b/src/torchrunx/utils/log_streaming.py
@@ -10,17 +10,17 @@
 ]

 import logging
+import os
 import pickle
 import signal
 import struct
 import sys
-from contextlib import redirect_stderr, redirect_stdout
 from dataclasses import dataclass
-from io import StringIO
 from logging import Handler, Logger
 from logging.handlers import SocketHandler
 from multiprocessing.synchronize import Event as EventClass
 from socketserver import StreamRequestHandler, ThreadingTCPServer
+from threading import Thread
 from typing import Callable

 import cloudpickle
@@ -129,24 +129,35 @@ def start_logging_server(serialized_args: bytes, stop_event: EventClass) -> None

 def redirect_stdio_to_logger(logger: Logger) -> None:
     """Redirect stderr/stdout: send output to logger at every flush."""
-
-    class _LoggingStream(StringIO):
-        def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
-            super().__init__()
-            self.logger = logger
-            self.level = level
-
-        def flush(self) -> None:
-            super().flush()  # At "flush" to avoid logs of partial bytes
-            value = self.getvalue()
-            if value != "":
-                self.logger.log(self.level, value)
-                self.truncate(0)
-                self.seek(0)
-
     logging.captureWarnings(capture=True)
-    redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
-    redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
+
+    def redirect_fd_to_logger(read_fd: int, level: int) -> None:
+        for line in os.fdopen(read_fd):
+            logger.log(level, line.rstrip())
+
+    # create (r, w) pipe and start logging all outputs from r
+    read_out_fd, write_out_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_out_fd, "level": logging.INFO},
+        daemon=True,
+    ).start()
+    # flush buffer before redirecting stdout
+    sys.stdout.flush()
+    # pipe: r <-> stdout instead of r <-> w
+    os.dup2(write_out_fd, sys.stdout.fileno())  # set stdout fd to pipe
+    os.close(write_out_fd)
+
+    # repeat for stderr
+    read_err_fd, write_err_fd = os.pipe()
+    Thread(
+        target=redirect_fd_to_logger,
+        kwargs={"read_fd": read_err_fd, "level": logging.ERROR},
+        daemon=True,
+    ).start()
+    sys.stderr.flush()
+    os.dup2(write_err_fd, sys.stderr.fileno())
+    os.close(write_err_fd)


 @dataclass
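Why move to fd-level redirection? The removed `_LoggingStream` approach swapped the `sys.stdout` / `sys.stderr` objects via `contextlib.redirect_stdout`, which only intercepts Python-level writes: anything that writes to file descriptors 1/2 directly (C extensions, launched subprocesses, `os.write`) bypasses the logger. A minimal sketch of that gap, with sample strings that are illustrative assumptions rather than code from the repository:

```python
import contextlib
import io
import os

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print("seen")             # routed through sys.stdout -> captured
    os.write(1, b"unseen\n")  # raw write to fd 1 -> bypasses sys.stdout

print(f"captured: {buf.getvalue()!r}")  # prints: captured: 'seen\n'
```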
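And a self-contained sketch of the pipe-and-`os.dup2` pattern that the new `redirect_stdio_to_logger` adopts, runnable outside the repository. The logger name, the sample writes, and the final sleep are assumptions added for the demo, not part of the patch:

```python
import logging
import os
import sys
import time
from threading import Thread

logging.basicConfig(level=logging.INFO)  # root handler writes to the real stderr
logger = logging.getLogger("demo")  # hypothetical logger name, for illustration

def redirect_fd_to_logger(read_fd: int, level: int) -> None:
    # os.fdopen wraps the raw fd; iteration blocks until each full line arrives
    for line in os.fdopen(read_fd):
        logger.log(level, line.rstrip())

# create a pipe and consume its read end on a background thread
read_fd, write_fd = os.pipe()
Thread(
    target=redirect_fd_to_logger,
    kwargs={"read_fd": read_fd, "level": logging.INFO},
    daemon=True,
).start()

sys.stdout.flush()  # drain Python-level buffers before swapping the descriptor
os.dup2(write_fd, sys.stdout.fileno())  # fd 1 now points at the pipe's write end
os.close(write_fd)  # the duplicate held by fd 1 keeps the pipe open

print("captured via print()", flush=True)  # Python-level write -> pipe -> logger
os.write(1, b"captured via os.write()\n")  # raw fd write -> also captured
time.sleep(0.1)  # let the daemon reader thread drain before the process exits
```

Because the swap happens at the descriptor level, output from C extensions and from child processes that inherit fd 1 is forwarded as well, which the `StringIO` stream could never observe.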