|
10 | 10 | ]
|
11 | 11 |
|
12 | 12 | import logging
|
| 13 | +import os |
13 | 14 | import pickle
|
14 | 15 | import signal
|
15 | 16 | import struct
|
16 | 17 | import sys
|
17 |
| -from contextlib import redirect_stderr, redirect_stdout |
18 | 18 | from dataclasses import dataclass
|
19 |
| -from io import StringIO |
20 | 19 | from logging import Handler, Logger
|
21 | 20 | from logging.handlers import SocketHandler
|
22 | 21 | from multiprocessing.synchronize import Event as EventClass
|
23 | 22 | from socketserver import StreamRequestHandler, ThreadingTCPServer
|
| 23 | +from threading import Thread |
24 | 24 | from typing import Callable
|
25 | 25 |
|
26 | 26 | import cloudpickle
|
@@ -129,24 +129,35 @@ def start_logging_server(serialized_args: bytes, stop_event: EventClass) -> None
|
129 | 129 |
|
def redirect_stdio_to_logger(logger: Logger) -> None:
    """Redirect stderr/stdout: send output to logger at every flush.

    Replaces the OS-level file descriptors of ``sys.stdout`` and
    ``sys.stderr`` with pipes whose read ends are drained by daemon
    threads that forward each line to *logger* (INFO for stdout,
    ERROR for stderr). Because the redirection happens at the fd
    level, output from C extensions and subprocesses inheriting the
    fds is captured too, not just Python-level writes.

    Args:
        logger: Destination logger for all redirected output.
    """
    # Route warnings.warn(...) output through logging as well.
    logging.captureWarnings(capture=True)

    def _redirect_stream_to_logger(stream, level: int) -> None:
        """Swap *stream*'s fd for a pipe drained into *logger* at *level*."""
        read_fd, write_fd = os.pipe()

        def _drain(fd: int) -> None:
            # Iterate the read end line by line until EOF; rstrip drops
            # the trailing newline so the logger adds its own framing.
            for line in os.fdopen(fd):
                logger.log(level, line.rstrip())

        # Daemon thread: must not block interpreter shutdown.
        Thread(target=_drain, args=(read_fd,), daemon=True).start()
        # Flush Python-level buffers BEFORE swapping the fd so pending
        # output still reaches the original destination.
        stream.flush()
        # pipe: r <-> stream fd instead of r <-> w
        os.dup2(write_fd, stream.fileno())
        # The dup'd descriptor keeps the write end open; close the original
        # so EOF is delivered to the drain thread when the stream fd closes.
        os.close(write_fd)

    _redirect_stream_to_logger(sys.stdout, logging.INFO)
    _redirect_stream_to_logger(sys.stderr, logging.ERROR)
150 | 161 |
|
151 | 162 |
|
152 | 163 | @dataclass
|
|
0 commit comments