diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 546361f80b1f47..a80b25ad657107 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,7 @@ import collections import collections.abc import concurrent.futures +import enum import errno import heapq import itertools @@ -272,6 +273,23 @@ async def restore(self): self._proto.resume_writing() +class _ServerState(enum.Enum): + """This tracks the state of Server. + + -[in]->NOT_STARTED -[ss]-> SERVING -[cl]-> CLOSED -[wk]*-> SHUTDOWN + + - in: Server.__init__() + - ss: Server._start_serving() + - cl: Server.close() + - wk: Server._wakeup() *only called if number of clients == 0 + """ + + NOT_STARTED = "not_started" + SERVING = "serving" + CLOSED = "closed" + SHUTDOWN = "shutdown" + + class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, @@ -287,22 +305,33 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout self._ssl_shutdown_timeout = ssl_shutdown_timeout - self._serving = False + self._state = _ServerState.NOT_STARTED self._serving_forever_fut = None def __repr__(self): return f'<{self.__class__.__name__} sockets={self.sockets!r}>' def _attach(self, transport): - assert self._sockets is not None + if self._state != _ServerState.SERVING: + raise RuntimeError("server is not serving, cannot attach transport") self._clients.add(transport) def _detach(self, transport): self._clients.discard(transport) - if len(self._clients) == 0 and self._sockets is None: - self._wakeup() + if self._state == _ServerState.CLOSED and len(self._clients) == 0: + self._shutdown() + + def _shutdown(self): + if self._state == _ServerState.CLOSED: + self._state = _ServerState.SHUTDOWN + elif self._state == _ServerState.SHUTDOWN: + # gh109564: the wakeup method has two possible call-sites, + # through an explicit call Server.close(), or indirectly through + # Server._detach() by the last connected client. + return + else: + raise RuntimeError(f"server {self!r} must be closed before shutdown") - def _wakeup(self): waiters = self._waiters self._waiters = None for waiter in waiters: @@ -310,9 +339,13 @@ def _wakeup(self): waiter.set_result(None) def _start_serving(self): - if self._serving: + if self._state == _ServerState.NOT_STARTED: + self._state = _ServerState.SERVING + elif self._state == _ServerState.SERVING: return - self._serving = True + else: + raise RuntimeError(f'server {self!r} was already started and then closed') + for sock in self._sockets: sock.listen(self._backlog) self._loop._start_serving( @@ -324,7 +357,7 @@ def get_loop(self): return self._loop def is_serving(self): - return self._serving + return self._state == _ServerState.SERVING @property def sockets(self): @@ -333,23 +366,30 @@ def sockets(self): return tuple(trsock.TransportSocket(s) for s in self._sockets) def close(self): - sockets = self._sockets - if sockets is None: + if self._state in {_ServerState.CLOSED, _ServerState.SHUTDOWN}: return - self._sockets = None - for sock in sockets: - self._loop._stop_serving(sock) + prev_state = self._state + try: + self._state = _ServerState.CLOSED - self._serving = False + sockets = self._sockets + if sockets is None: + return + self._sockets = None - if (self._serving_forever_fut is not None and - not self._serving_forever_fut.done()): - self._serving_forever_fut.cancel() - self._serving_forever_fut = None + for sock in sockets: + self._loop._stop_serving(sock) - if len(self._clients) == 0: - self._wakeup() + if (self._serving_forever_fut is not None and + not self._serving_forever_fut.done()): + self._serving_forever_fut.cancel() + self._serving_forever_fut = None + + if len(self._clients) == 0: + self._shutdown() + except: + self._state = prev_state def close_clients(self): for transport in self._clients.copy(): @@ -369,8 +409,6 @@ async def serve_forever(self): if self._serving_forever_fut is not None: raise RuntimeError( f'server {self!r} is already being awaited on serve_forever()') - if self._sockets is None: - raise RuntimeError(f'server {self!r} is closed') self._start_serving() self._serving_forever_fut = self._loop.create_future() @@ -407,7 +445,7 @@ async def wait_closed(self): # from two places: self.close() and self._detach(), but only # when both conditions have become true. To signal that this # has happened, self._wakeup() sets self._waiters to None. - if self._waiters is None: + if self._state == _ServerState.SHUTDOWN: return waiter = self._loop.create_future() self._waiters.append(waiter) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 7eb55bd63ddb73..159360c0d32680 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -62,7 +62,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._closing = False # Set when close() called. self._called_connection_lost = False self._eof_written = False - if self._server is not None: + if self._server is not None and self._server.is_serving(): self._server._attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 22147451fa7ebd..f967dd6a03c99e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -794,8 +794,9 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._closing = False # Set when close() called. self._paused = False # Set when pause_reading() called - if self._server is not None: + if self._server is not None and self._server.is_serving(): self._server._attach(self) + loop._transports[self._sock_fd] = self def __repr__(self): diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py index 32211f4cba32cb..4daa7d2ce70be3 100644 --- a/Lib/test/test_asyncio/test_server.py +++ b/Lib/test/test_asyncio/test_server.py @@ -4,6 +4,7 @@ import time import threading import unittest +from unittest.mock import Mock from test.support import socket_helper from test.test_asyncio import utils as test_utils @@ -65,7 +66,7 @@ async def main(srv): self.assertIsNone(srv._waiters) self.assertFalse(srv.is_serving()) - with self.assertRaisesRegex(RuntimeError, r'is closed'): + with self.assertRaisesRegex(RuntimeError, r'started and then closed'): self.loop.run_until_complete(srv.serve_forever()) @@ -118,7 +119,7 @@ async def main(srv): self.assertIsNone(srv._waiters) self.assertFalse(srv.is_serving()) - with self.assertRaisesRegex(RuntimeError, r'is closed'): + with self.assertRaisesRegex(RuntimeError, r'started and then closed'): self.loop.run_until_complete(srv.serve_forever()) @@ -186,6 +187,8 @@ async def serve(rd, wr): loop.call_soon(srv.close) loop.call_soon(wr.close) await srv.wait_closed() + self.assertTrue(task.done()) + self.assertFalse(srv.is_serving()) async def test_close_clients(self): async def serve(rd, wr): @@ -212,6 +215,9 @@ async def serve(rd, wr): await asyncio.sleep(0) self.assertTrue(task.done()) + with self.assertRaisesRegex(RuntimeError, r'started and then closed'): + await srv.start_serving() + async def test_abort_clients(self): async def serve(rd, wr): fut.set_result((rd, wr)) diff --git a/Misc/NEWS.d/next/Library/2025-03-09-23-10-39.gh-issue-109564.r9rnIB.rst b/Misc/NEWS.d/next/Library/2025-03-09-23-10-39.gh-issue-109564.r9rnIB.rst new file mode 100644 index 00000000000000..70e981a8a5dd5f --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-03-09-23-10-39.gh-issue-109564.r9rnIB.rst @@ -0,0 +1 @@ +Fix race condition in :meth:`asyncio.Server.close`. Patch by Jamie Phan.