diff --git a/docs/source/mypy_daemon.rst b/docs/source/mypy_daemon.rst index 9b97a027bcfe..17fa9e757bb7 100644 --- a/docs/source/mypy_daemon.rst +++ b/docs/source/mypy_daemon.rst @@ -22,16 +22,19 @@ you'll find errors sooner. The mypy daemon is experimental. In particular, the command-line interface may change in future mypy releases. -.. note:: - - The mypy daemon currently supports macOS and Linux only. - .. note:: Each mypy daemon process supports one user and one set of source files, and it can only process one type checking request at a time. You can run multiple mypy daemon processes to type check multiple repositories. +.. note:: + + On Windows, due to platform limitations, the mypy daemon does not currently + support a timeout for the server process. The client will still time out if + a connection to the server cannot be made, but the server will wait forever + for a new client connection. + Basic usage *********** @@ -103,5 +106,3 @@ Limitations limitation. This can be defined through the command line or through a :ref:`configuration file `. - -* Windows is not supported. diff --git a/mypy/dmypy.py b/mypy/dmypy.py index 9af5cb9ccd26..84af22ed65bd 100644 --- a/mypy/dmypy.py +++ b/mypy/dmypy.py @@ -7,16 +7,21 @@ """ import argparse +import base64 import json import os +import pickle import signal -import socket +import subprocess import sys import time -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Mapping, Optional, Tuple from mypy.dmypy_util import STATUS_FILE, receive +from mypy.ipc import IPCClient, IPCException +from mypy.dmypy_os import alive, kill + from mypy.version import __version__ # Argument parser. Subparsers are tied to action functions by the @@ -92,7 +97,7 @@ def __init__(self, prog: str) -> None: help="Server shutdown timeout (in seconds)") p.add_argument('flags', metavar='FLAG', nargs='*', type=str, help="Regular mypy flags (precede with --)") - +p.add_argument('--options-data', help=argparse.SUPPRESS) help_parser = p = subparsers.add_parser('help') del p @@ -179,10 +184,9 @@ def restart_server(args: argparse.Namespace, allow_sources: bool = False) -> Non def start_server(args: argparse.Namespace, allow_sources: bool = False) -> None: """Start the server from command arguments and wait for it.""" # Lazy import so this import doesn't slow down other commands. - from mypy.dmypy_server import daemonize, Server, process_start_options - if daemonize(Server(process_start_options(args.flags, allow_sources), - timeout=args.timeout).serve, - args.log_file) != 0: + from mypy.dmypy_server import daemonize, process_start_options + start_options = process_start_options(args.flags, allow_sources) + if daemonize(start_options, timeout=args.timeout, log_file=args.log_file): sys.exit(1) wait_for_server() @@ -201,7 +205,7 @@ def wait_for_server(timeout: float = 5.0) -> None: time.sleep(0.1) continue # If the file's content is bogus or the process is dead, fail. - pid, sockname = check_status(data) + check_status(data) print("Daemon started") return sys.exit("Timed out waiting for daemon to start") @@ -224,7 +228,6 @@ def do_run(args: argparse.Namespace) -> None: if not is_running(): # Bad or missing status file or dead process; good to start. start_server(args, allow_sources=True) - t0 = time.time() response = request('run', version=__version__, args=args.flags) # If the daemon signals that a restart is necessary, do it @@ -273,9 +276,9 @@ def do_stop(args: argparse.Namespace) -> None: @action(kill_parser) def do_kill(args: argparse.Namespace) -> None: """Kill daemon process with SIGKILL.""" - pid, sockname = get_status() + pid, _ = get_status() try: - os.kill(pid, signal.SIGKILL) + kill(pid) except OSError as err: sys.exit(str(err)) else: @@ -363,7 +366,20 @@ def do_daemon(args: argparse.Namespace) -> None: """Serve requests in the foreground.""" # Lazy import so this import doesn't slow down other commands. from mypy.dmypy_server import Server, process_start_options - Server(process_start_options(args.flags, allow_sources=False), timeout=args.timeout).serve() + if args.options_data: + from mypy.options import Options + options_dict, timeout, log_file = pickle.loads(base64.b64decode(args.options_data)) + options_obj = Options() + options = options_obj.apply_changes(options_dict) + if log_file: + sys.stdout = sys.stderr = open(log_file, 'a', buffering=1) + fd = sys.stdout.fileno() + os.dup2(fd, 2) + os.dup2(fd, 1) + else: + options = process_start_options(args.flags, allow_sources=False) + timeout = args.timeout + Server(options, timeout=timeout).serve() @action(help_parser) @@ -375,7 +391,7 @@ def do_help(args: argparse.Namespace) -> None: # Client-side infrastructure. -def request(command: str, *, timeout: Optional[float] = None, +def request(command: str, *, timeout: Optional[int] = None, **kwds: object) -> Dict[str, Any]: """Send a request to the daemon. @@ -384,35 +400,30 @@ def request(command: str, *, timeout: Optional[float] = None, Raise BadStatus if there is something wrong with the status file or if the process whose pid is in the status file has died. - Return {'error': } if a socket operation or receive() + Return {'error': } if an IPC operation or receive() raised OSError. This covers cases such as connection refused or closed prematurely as well as invalid JSON received. """ + response = {} # type: Dict[str, str] args = dict(kwds) args.update(command=command) bdata = json.dumps(args).encode('utf8') - pid, sockname = get_status() - sock = socket.socket(socket.AF_UNIX) - if timeout is not None: - sock.settimeout(timeout) + _, name = get_status() try: - sock.connect(sockname) - sock.sendall(bdata) - sock.shutdown(socket.SHUT_WR) - response = receive(sock) - except OSError as err: + with IPCClient(name, timeout) as client: + client.write(bdata) + response = receive(client) + except (OSError, IPCException) as err: return {'error': str(err)} # TODO: Other errors, e.g. ValueError, UnicodeError else: return response - finally: - sock.close() def get_status() -> Tuple[int, str]: """Read status file and check if the process is alive. - Return (pid, sockname) on success. + Return (pid, connection_name) on success. Raise BadStatus if something's wrong. """ @@ -423,7 +434,7 @@ def get_status() -> Tuple[int, str]: def check_status(data: Dict[str, Any]) -> Tuple[int, str]: """Check if the process is alive. - Return (pid, sockname) on success. + Return (pid, connection_name) on success. Raise BadStatus if something's wrong. """ @@ -432,16 +443,14 @@ def check_status(data: Dict[str, Any]) -> Tuple[int, str]: pid = data['pid'] if not isinstance(pid, int): raise BadStatus("pid field is not an int") - try: - os.kill(pid, 0) - except OSError: + if not alive(pid): raise BadStatus("Daemon has died") - if 'sockname' not in data: - raise BadStatus("Invalid status file (no sockname field)") - sockname = data['sockname'] - if not isinstance(sockname, str): - raise BadStatus("sockname field is not a string") - return pid, sockname + if 'connection_name' not in data: + raise BadStatus("Invalid status file (no connection_name field)") + connection_name = data['connection_name'] + if not isinstance(connection_name, str): + raise BadStatus("connection_name field is not a string") + return pid, connection_name def read_status() -> Dict[str, object]: diff --git a/mypy/dmypy_os.py b/mypy/dmypy_os.py new file mode 100644 index 000000000000..77cf963ad612 --- /dev/null +++ b/mypy/dmypy_os.py @@ -0,0 +1,43 @@ +import sys + +from typing import Any, Callable + +if sys.platform == 'win32': + import ctypes + from ctypes.wintypes import DWORD, HANDLE + import subprocess + + PROCESS_QUERY_LIMITED_INFORMATION = ctypes.c_ulong(0x1000) + + kernel32 = ctypes.windll.kernel32 + OpenProcess = kernel32.OpenProcess # type: Callable[[DWORD, int, int], HANDLE] + GetExitCodeProcess = kernel32.GetExitCodeProcess # type: Callable[[HANDLE, Any], int] +else: + import os + import signal + + +def alive(pid: int) -> bool: + """Is the process alive?""" + if sys.platform == 'win32': + # why can't anything be easy... + status = DWORD() + handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, + 0, + pid) + GetExitCodeProcess(handle, ctypes.byref(status)) + return status.value == 259 # STILL_ACTIVE + else: + try: + os.kill(pid, 0) + except OSError: + return False + return True + + +def kill(pid: int) -> None: + """Kill the process.""" + if sys.platform == 'win32': + subprocess.check_output("taskkill /pid {pid} /f /t".format(pid=pid)) + else: + os.kill(pid, signal.SIGKILL) diff --git a/mypy/dmypy_server.py b/mypy/dmypy_server.py index 212cf7e5c440..5d22437fe1f8 100644 --- a/mypy/dmypy_server.py +++ b/mypy/dmypy_server.py @@ -6,10 +6,13 @@ to enable fine-grained incremental reprocessing of changes. """ +import argparse +import base64 import json import os -import shutil -import socket +import pickle +import random +import subprocess import sys import tempfile import time @@ -23,6 +26,7 @@ from mypy.find_sources import create_source_list, InvalidSourceList from mypy.server.update import FineGrainedBuildManager from mypy.dmypy_util import STATUS_FILE, receive +from mypy.ipc import IPCServer, IPCException from mypy.fscache import FileSystemCache from mypy.fswatcher import FileSystemWatcher, FileData from mypy.modulefinder import BuildSource, compute_search_paths @@ -30,66 +34,92 @@ from mypy.typestate import reset_global_state from mypy.version import __version__ + MYPY = False if MYPY: from typing_extensions import Final - MEM_PROFILE = False # type: Final # If True, dump memory profile after initialization +if sys.platform == 'win32': + def daemonize(options: Options, + timeout: Optional[int] = None, + log_file: Optional[str] = None) -> int: + """Create the daemon process via "dmypy daemon" and pass options via command line -def daemonize(func: Callable[[], None], log_file: Optional[str] = None) -> int: - """Arrange to call func() in a grandchild of the current process. - - Return 0 for success, exit status for failure, negative if - subprocess killed by signal. - """ - # See https://stackoverflow.com/questions/473620/how-do-you-create-a-daemon-in-python - # mypyc doesn't like unreachable code, so trick mypy into thinking the branch is reachable - if sys.platform == 'win32' or bool(): - raise ValueError('Mypy daemon is not supported on Windows yet') - sys.stdout.flush() - sys.stderr.flush() - pid = os.fork() - if pid: - # Parent process: wait for child in case things go bad there. - npid, sts = os.waitpid(pid, 0) - sig = sts & 0xff - if sig: - print("Child killed by signal", sig) - return -sig - sts = sts >> 8 - if sts: - print("Child exit status", sts) - return sts - # Child process: do a bunch of UNIX stuff and then fork a grandchild. - try: - os.setsid() # Detach controlling terminal - os.umask(0o27) - devnull = os.open('/dev/null', os.O_RDWR) - os.dup2(devnull, 0) - os.dup2(devnull, 1) - os.dup2(devnull, 2) - os.close(devnull) + This uses the DETACHED_PROCESS flag to invoke the Server. + See https://docs.microsoft.com/en-us/windows/desktop/procthread/process-creation-flags + + It also pickles the options to be unpickled by mypy. + """ + command = [sys.executable, '-m', 'mypy.dmypy', 'daemon'] + pickeled_options = pickle.dumps((options.snapshot(), timeout, log_file)) + command.append('--options-data="{}"'.format(base64.b64encode(pickeled_options).decode())) + try: + subprocess.Popen(command, creationflags=0x8) # DETACHED_PROCESS + return 0 + except subprocess.CalledProcessError as e: + return e.returncode + +else: + def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> int: + """Arrange to call func() in a grandchild of the current process. + + Return 0 for success, exit status for failure, negative if + subprocess killed by signal. + """ + # See https://stackoverflow.com/questions/473620/how-do-you-create-a-daemon-in-python + sys.stdout.flush() + sys.stderr.flush() pid = os.fork() if pid: - # Child is done, exit to parent. - os._exit(0) - # Grandchild: run the server. - if log_file: - sys.stdout = sys.stderr = open(log_file, 'a', buffering=1) - fd = sys.stdout.fileno() - os.dup2(fd, 2) - os.dup2(fd, 1) - func() - finally: - # Make sure we never get back into the caller. - os._exit(1) + # Parent process: wait for child in case things go bad there. + npid, sts = os.waitpid(pid, 0) + sig = sts & 0xff + if sig: + print("Child killed by signal", sig) + return -sig + sts = sts >> 8 + if sts: + print("Child exit status", sts) + return sts + # Child process: do a bunch of UNIX stuff and then fork a grandchild. + try: + os.setsid() # Detach controlling terminal + os.umask(0o27) + devnull = os.open('/dev/null', os.O_RDWR) + os.dup2(devnull, 0) + os.dup2(devnull, 1) + os.dup2(devnull, 2) + os.close(devnull) + pid = os.fork() + if pid: + # Child is done, exit to parent. + os._exit(0) + # Grandchild: run the server. + if log_file: + sys.stdout = sys.stderr = open(log_file, 'a', buffering=1) + fd = sys.stdout.fileno() + os.dup2(fd, 2) + os.dup2(fd, 1) + func() + finally: + # Make sure we never get back into the caller. + os._exit(1) + + def daemonize(options: Options, + timeout: Optional[int] = None, + log_file: Optional[str] = None) -> int: + """Run the mypy daemon in a grandchild of the current process + Return 0 for success, exit status for failure, negative if + subprocess killed by signal. + """ + return _daemonize_cb(Server(options, timeout).serve, log_file) # Server code. -SOCKET_NAME = 'dmypy.sock' # type: Final +CONNECTION_NAME = 'dmypy.sock' # type: Final def process_start_options(flags: List[str], allow_sources: bool) -> Options: @@ -155,25 +185,13 @@ def serve(self) -> None: """Serve requests, synchronously (no thread or fork).""" command = None try: - sock = self.create_listening_socket() - if self.timeout is not None: - sock.settimeout(self.timeout) - try: - with open(STATUS_FILE, 'w') as f: - json.dump({'pid': os.getpid(), 'sockname': sock.getsockname()}, f) - f.write('\n') # I like my JSON with trailing newline - while True: - try: - conn, addr = sock.accept() - except socket.timeout: - print("Exiting due to inactivity.") - reset_global_state() - sys.exit(0) - try: - data = receive(conn) - except OSError: - conn.close() # Maybe the client hung up - continue + server = IPCServer(CONNECTION_NAME, self.timeout) + with open(STATUS_FILE, 'w') as f: + json.dump({'pid': os.getpid(), 'connection_name': server.connection_name}, f) + f.write('\n') # I like my JSON with a trailing newline + while True: + with server: + data = receive(server) resp = {} # type: Dict[str, Any] if 'command' not in data: resp = {'error': "No command found in request"} @@ -189,40 +207,31 @@ def serve(self) -> None: # If we are crashing, report the crash to the client tb = traceback.format_exception(*sys.exc_info()) resp = {'error': "Daemon crashed!\n" + "".join(tb)} - conn.sendall(json.dumps(resp).encode('utf8')) + server.write(json.dumps(resp).encode('utf8')) raise try: - conn.sendall(json.dumps(resp).encode('utf8')) + server.write(json.dumps(resp).encode('utf8')) except OSError: pass # Maybe the client hung up - conn.close() if command == 'stop': - sock.close() reset_global_state() sys.exit(0) - finally: - # If the final command is something other than a clean - # stop, remove the status file. (We can't just - # simplify the logic and always remove the file, since - # that could cause us to remove a future server's - # status file.) - if command != 'stop': - os.unlink(STATUS_FILE) finally: - shutil.rmtree(self.sock_directory) + # If the final command is something other than a clean + # stop, remove the status file. (We can't just + # simplify the logic and always remove the file, since + # that could cause us to remove a future server's + # status file.) + if command != 'stop': + os.unlink(STATUS_FILE) + try: + server.cleanup() # try to remove the socket dir on Linux + except OSError: + pass exc_info = sys.exc_info() if exc_info[0] and exc_info[0] is not SystemExit: traceback.print_exception(*exc_info) - def create_listening_socket(self) -> socket.socket: - """Create the socket and set it up for listening.""" - self.sock_directory = tempfile.mkdtemp() - sockname = os.path.join(self.sock_directory, SOCKET_NAME) - sock = socket.socket(socket.AF_UNIX) - sock.bind(sockname) - sock.listen(1) - return sock - def run_command(self, command: str, data: Mapping[str, object]) -> Dict[str, object]: """Run a specific command from the registry.""" key = 'cmd_' + command @@ -457,17 +466,7 @@ def cmd_hang(self) -> Dict[str, object]: def get_meminfo() -> Dict[str, Any]: - # See https://stackoverflow.com/questions/938733/total-memory-used-by-python-process - import resource # Since it doesn't exist on Windows. res = {} # type: Dict[str, Any] - rusage = resource.getrusage(resource.RUSAGE_SELF) - # mypyc doesn't like unreachable code, so trick mypy into thinking the branch is reachable - if sys.platform == 'darwin' or bool(): - factor = 1 - else: - factor = 1024 # Linux - res['memory_maxrss_mib'] = rusage.ru_maxrss * factor / MiB - # If we can import psutil, use it for some extra data try: import psutil # type: ignore # It's not in typeshed yet except ImportError: @@ -481,4 +480,17 @@ def get_meminfo() -> Dict[str, Any]: meminfo = process.memory_info() res['memory_rss_mib'] = meminfo.rss / MiB res['memory_vms_mib'] = meminfo.vms / MiB + if sys.platform == 'win32': + res['memory_maxrss_mib'] = meminfo.peak_wset / MiB + else: + # See https://stackoverflow.com/questions/938733/total-memory-used-by-python-process + import resource # Since it doesn't exist on Windows. + rusage = resource.getrusage(resource.RUSAGE_SELF) + # mypyc doesn't like unreachable code, so trick mypy into thinking + # the branch is reachable + if sys.platform == 'darwin' or bool(): + factor = 1 + else: + factor = 1024 # Linux + res['memory_maxrss_mib'] = rusage.ru_maxrss * factor / MiB return res diff --git a/mypy/dmypy_util.py b/mypy/dmypy_util.py index d5aacfd1dc45..012b994cea18 100644 --- a/mypy/dmypy_util.py +++ b/mypy/dmypy_util.py @@ -1,13 +1,14 @@ """Shared code between dmypy.py and dmypy_server.py. -This should be pretty lightweight and not depend on other mypy code. +This should be pretty lightweight and not depend on other mypy code (other than ipc). """ import json -import socket from typing import Any +from mypy.ipc import IPCBase + MYPY = False if MYPY: from typing_extensions import Final @@ -15,20 +16,13 @@ STATUS_FILE = '.dmypy.json' # type: Final -def receive(sock: socket.socket) -> Any: - """Receive JSON data from a socket until EOF. - - Raise a subclass of OSError if there's a socket exception. +def receive(connection: IPCBase) -> Any: + """Receive JSON data from a connection until EOF. Raise OSError if the data received is not valid JSON or if it is not a dict. """ - bdata = bytearray() - while True: - more = sock.recv(100000) - if not more: - break - bdata.extend(more) + bdata = connection.read() if not bdata: raise OSError("No data received") try: diff --git a/mypy/ipc.py b/mypy/ipc.py new file mode 100644 index 000000000000..5bc2c91f6ed1 --- /dev/null +++ b/mypy/ipc.py @@ -0,0 +1,217 @@ +"""Cross platform abstractions for inter-process communication + +On Unix, this uses AF_UNIX sockets. +On Windows, this uses NamedPipes. +""" + +import base64 +import contextlib +import os +import shutil +import sys +import tempfile + +from typing import Iterator, Optional, Callable + +MYPY = False +if MYPY: + from typing import Type + +from types import TracebackType + +if sys.platform == 'win32': + # This may be private, but it is needed for IPC on Windows, and is basically stable + import _winapi + import ctypes + + _IPCHandle = int + + kernel32 = ctypes.windll.kernel32 + DisconnectNamedPipe = kernel32.DisconnectNamedPipe # type: Callable[[_IPCHandle], int] + FlushFileBuffers = kernel32.FlushFileBuffers # type: Callable[[_IPCHandle], int] +else: + import socket + _IPCHandle = socket.socket + + +class IPCException(Exception): + """Exception for IPC issues.""" + pass + + +class IPCBase: + """Base class for communication between the dmypy client and server. + + This contains logic shared between the client and server, such as reading + and writing. + """ + + connection = None # type: _IPCHandle + + def __init__(self, name: str) -> None: + self.READ_SIZE = 100000 + self.name = name + + def read(self) -> bytes: + """Read bytes from an IPC connection until its empty.""" + bdata = bytearray() + while True: + if sys.platform == 'win32': + more, _ = _winapi.ReadFile(self.connection, self.READ_SIZE) + else: + more = self.connection.recv(self.READ_SIZE) + if not more: + break + bdata.extend(more) + return bytes(bdata) + + def write(self, data: bytes) -> None: + """Write bytes to an IPC connection.""" + if sys.platform == 'win32': + try: + _winapi.WriteFile(self.connection, data) + # this empty write is to copy the behavior of socket.sendall, + # which also sends an empty message to signify it is done writing + _winapi.WriteFile(self.connection, b'') + except WindowsError as e: + raise IPCException("Failed to write with error: {}".format(e.winerror)) + else: + self.connection.sendall(data) + self.connection.shutdown(socket.SHUT_WR) + + def close(self) -> None: + if sys.platform == 'win32': + if self.connection != _winapi.NULL: + _winapi.CloseHandle(self.connection) + else: + self.connection.close() + + +class IPCClient(IPCBase): + """The client side of an IPC connection.""" + + def __init__(self, name: str, timeout: Optional[int]) -> None: + super().__init__(name) + if sys.platform == 'win32': + timeout = timeout or 1000 # we need to set a timeout + try: + _winapi.WaitNamedPipe(self.name, timeout) + except FileNotFoundError: + raise IPCException("The NamedPipe at {} was not found.".format(self.name)) + except WindowsError as e: + if e.winerror == _winapi.ERROR_SEM_TIMEOUT: + raise IPCException("Timed out waiting for connection.") + else: + raise + try: + self.connection = _winapi.CreateFile( + self.name, + _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, + 0, + _winapi.NULL, + _winapi.OPEN_EXISTING, + 0, + _winapi.NULL, + ) + except WindowsError as e: + if e.winerror == _winapi.ERROR_PIPE_BUSY: + raise IPCException("The connection is busy.") + else: + raise + _winapi.SetNamedPipeHandleState(self.connection, + _winapi.PIPE_READMODE_MESSAGE, + None, + None) + else: + self.connection = socket.socket(socket.AF_UNIX) + self.connection.settimeout(timeout) + self.connection.connect(name) + + def __enter__(self) -> 'IPCClient': + return self + + def __exit__(self, + exc_ty: 'Optional[Type[BaseException]]' = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> bool: + self.close() + return False + + +class IPCServer(IPCBase): + + BUFFER_SIZE = 2**16 + + def __init__(self, name: str, timeout: Optional[int] = None) -> None: + if sys.platform == 'win32': + name = r'\\.\pipe\{}-{}.pipe'.format(name, base64.b64encode(os.urandom(6))) + super().__init__(name) + if sys.platform == 'win32': + self.connection = _winapi.CreateNamedPipe(self.name, + _winapi.PIPE_ACCESS_DUPLEX + | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE, + _winapi.PIPE_READMODE_MESSAGE + | _winapi.PIPE_TYPE_MESSAGE + | _winapi.PIPE_WAIT + | 0x8, # PIPE_REJECT_REMOTE_CLIENTS + 1, # one instance + self.BUFFER_SIZE, + self.BUFFER_SIZE, + 1000, # Default timeout in milis + 0, # Use default security descriptor + ) + if self.connection == -1: # INVALID_HANDLE_VALUE + err = _winapi.GetLastError() + raise IPCException('Invalid handle to pipe: {err}'.format(err)) + else: + self.sock_directory = tempfile.mkdtemp() + sockfile = os.path.join(self.sock_directory, self.name) + self.sock = socket.socket(socket.AF_UNIX) + self.sock.bind(sockfile) + self.sock.listen(1) + if timeout is not None: + self.sock.settimeout(timeout) + + def __enter__(self) -> 'IPCServer': + if sys.platform == 'win32': + # NOTE: It is theoretically possible that this will hang forever if the + # client never connects, though this can be "solved" by killing the server + try: + _winapi.ConnectNamedPipe(self.connection, _winapi.NULL) + except WindowsError as e: + if e.winerror == _winapi.ERROR_PIPE_CONNECTED: + pass # The client already exists, which is fine. + else: + try: + self.connection, _ = self.sock.accept() + except socket.timeout: + raise IPCException('The socket timed out') + return self + + def __exit__(self, + exc_ty: 'Optional[Type[BaseException]]' = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> bool: + if sys.platform == 'win32': + # Wait for the client to finish reading the last write before disconnecting + if not FlushFileBuffers(self.connection): + raise IPCException("Failed to flush NamedPipe buffer, maybe the client hung up?") + DisconnectNamedPipe(self.connection) + else: + self.close() + return False + + def cleanup(self) -> None: + if sys.platform == 'win32': + self.close() + else: + shutil.rmtree(self.sock_directory) + + @property + def connection_name(self) -> str: + if sys.platform == 'win32': + return self.name + else: + return self.sock.getsockname() diff --git a/mypy/test/testdaemon.py b/mypy/test/testdaemon.py index 29135662c991..3a55b3e01008 100644 --- a/mypy/test/testdaemon.py +++ b/mypy/test/testdaemon.py @@ -31,14 +31,13 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: def test_daemon(testcase: DataDrivenTestCase) -> None: - if sys.platform == 'win32': - return # These tests don't run on Windows yet. assert testcase.old_cwd is not None, "test was not properly set up" for i, step in enumerate(parse_script(testcase.input)): cmd = step[0] expected_lines = step[1:] assert cmd.startswith('$') cmd = cmd[1:].strip() + cmd = cmd.replace('{python}', sys.executable) sts, output = run_cmd(cmd) output_lines = output.splitlines() if sts: diff --git a/mypy/test/testipc.py b/mypy/test/testipc.py new file mode 100644 index 000000000000..660107b78325 --- /dev/null +++ b/mypy/test/testipc.py @@ -0,0 +1,56 @@ +import os +import sys +import time +from unittest import TestCase, main +from multiprocessing import Process, Queue + +from mypy.ipc import IPCClient, IPCServer, IPCException + + +CONNECTION_NAME = 'dmypy-test-ipc.sock' + + +def server(msg: str, q: 'Queue[str]') -> None: + server = IPCServer(CONNECTION_NAME) + q.put(server.connection_name) + data = b'' + while not data: + with server: + server.write(msg.encode()) + data = server.read() + server.cleanup() + + +class IPCTests(TestCase): + def test_transaction_large(self) -> None: + queue = Queue() # type: Queue[str] + msg = 't' * 100001 # longer than the max read size of 100_000 + p = Process(target=server, args=(msg, queue), daemon=True) + p.start() + connection_name = queue.get() + with IPCClient(connection_name, timeout=1) as client: + assert client.read() == msg.encode() + client.write(b'test') + queue.close() + queue.join_thread() + p.join() + + def test_connect_twice(self) -> None: + queue = Queue() # type: Queue[str] + msg = 'this is a test message' + p = Process(target=server, args=(msg, queue), daemon=True) + p.start() + connection_name = queue.get() + with IPCClient(connection_name, timeout=1) as client: + assert client.read() == msg.encode() + client.write(b'') # don't let the server hang up yet, we want to connect again. + + with IPCClient(connection_name, timeout=1) as client: + client.write(b'test') + queue.close() + queue.join_thread() + p.join() + + +if __name__ == '__main__': + main() diff --git a/test-data/unit/daemon.test b/test-data/unit/daemon.test index 7b2a42862725..27579556dfba 100644 --- a/test-data/unit/daemon.test +++ b/test-data/unit/daemon.test @@ -29,15 +29,15 @@ def f(): pass $ dmypy run -- foo.py --follow-imports=error Daemon started $ dmypy run -- foo.py --follow-imports=error -$ echo '[mypy]' >mypy.ini -$ echo 'disallow_untyped_defs = True' >>mypy.ini +$ {python} -c "print('[mypy]')" >mypy.ini +$ {python} -c "print('disallow_untyped_defs = True')" >>mypy.ini $ dmypy run -- foo.py --follow-imports=error Restarting: configuration changed Daemon stopped Daemon started foo.py:1: error: Function is missing a type annotation == Return code: 1 -$ echo 'def f() -> None: pass' >foo.py +$ {python} -c "print('def f() -> None: pass')" >foo.py $ dmypy run -- foo.py --follow-imports=error $ dmypy stop Daemon stopped @@ -47,7 +47,7 @@ def f(): pass [case testDaemonRunRestartPluginVersion] $ dmypy run -- foo.py Daemon started -$ echo ' ' >>plug.py +$ {python} -c "print(' ')" >> plug.py $ dmypy run -- foo.py Restarting: plugins changed Daemon stopped