diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index abe7d67463..a2c52927a6 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -725,8 +725,14 @@ async def execute_command(self, *args: EncodableT): await self.connect() connection = self.connection - kwargs = {"check_health": not self.subscribed} - await self._execute(connection, connection.send_command, *args, **kwargs) + await self._execute( + connection, + lambda: connection.send_command( + *args, + check_health=not self.subscribed, + disconnect_on_interrupt=False, + ), + ) async def connect(self): """ @@ -753,7 +759,7 @@ async def _disconnect_raise_connect(self, conn, error): raise error await conn.connect() - async def _execute(self, conn, command, *args, **kwargs): + async def _execute(self, conn, command): """ Connect manually upon disconnection. If the Redis server is down, this will fail and raise a ConnectionError as desired. @@ -762,7 +768,7 @@ async def _execute(self, conn, command, *args, **kwargs): patterns we were previously listening to """ return await conn.retry.call_with_retry( - lambda: command(*args, **kwargs), + command, lambda error: self._disconnect_raise_connect(conn, error), ) @@ -781,7 +787,13 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, + lambda: conn.read_response( + timeout=read_timeout, + disconnect_on_interrupt=False, + ), + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it @@ -801,7 +813,10 @@ async def check_health(self): and asyncio.get_event_loop().time() > conn.next_health_check ): await conn.send_command( - "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + "PING", + self.HEALTH_CHECK_MESSAGE, + check_health=False, + disconnect_on_interrupt=False, ) def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4f19153318..816b22ad01 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -731,7 +731,11 @@ async def _send_packed_command(self, command: Iterable[bytes]) -> None: await self._writer.drain() async def send_packed_command( - self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + self, + command: Union[bytes, str, Iterable[bytes]], + check_health: bool = True, + *, + disconnect_on_interrupt: bool = True, ) -> None: if not self.is_connected: await self.connect() @@ -763,14 +767,22 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # On interrupt (e.g. by CancelledError) there's no way to determine + # how much data, if any, was successfully sent, so this socket is unusable + # for subsequent commands (which may concatenate to an unfinished command). + if disconnect_on_interrupt: + await self.disconnect(nowait=True) raise - async def send_command(self, *args: Any, **kwargs: Any) -> None: + async def send_command( + self, *args, check_health=True, disconnect_on_interrupt=True, **kwargs, + ): """Pack and send a command to the Redis server""" await self.send_packed_command( - self.pack_command(*args), check_health=kwargs.get("check_health", True) + self.pack_command(*args), + check_health=check_health, + disconnect_on_interrupt=disconnect_on_interrupt, ) async def can_read_destructive(self): @@ -787,6 +799,8 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_interrupt: bool = True, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout @@ -812,12 +826,13 @@ async def read_response( raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # On interrupt (e.g. by CancelledError) there's no way to determine + # how much data, if any, was successfully read, so this socket is unusable + # for subsequent commands (which may read previous command's response + # as their own). + if disconnect_on_interrupt: + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/redis/client.py b/redis/client.py index ed857c8fba..5ccec2f440 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1462,10 +1462,16 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - kwargs = {"check_health": not self.subscribed} if not self.subscribed: self.clean_health_check_responses() - self._execute(connection, connection.send_command, *args, **kwargs) + self._execute( + connection, + lambda: connection.send_command( + *args, + check_health=not self.subscribed, + disconnect_on_interrupt=True, + ), + ) def clean_health_check_responses(self): """ @@ -1474,7 +1480,7 @@ def clean_health_check_responses(self): ttl = 10 conn = self.connection while self.health_check_response_counter > 0 and ttl > 0: - if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): + if self._execute(conn, lambda: conn.can_read(timeout=conn.socket_timeout)): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): self.health_check_response_counter -= 1 @@ -1496,7 +1502,7 @@ def _disconnect_raise_connect(self, conn, error): raise error conn.connect() - def _execute(self, conn, command, *args, **kwargs): + def _execute(self, conn, command): """ Connect manually upon disconnection. If the Redis server is down, this will fail and raise a ConnectionError as desired. @@ -1505,7 +1511,7 @@ def _execute(self, conn, command, *args, **kwargs): patterns we were previously listening to """ return conn.retry.call_with_retry( - lambda: command(*args, **kwargs), + command, lambda error: self._disconnect_raise_connect(conn, error), ) @@ -1526,7 +1532,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_interrupt=False) response = self._execute(conn, try_read) @@ -1556,7 +1562,12 @@ def check_health(self): ) if conn.health_check_interval and time.time() > conn.next_health_check: - conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) + conn.send_command( + "PING", + self.HEALTH_CHECK_MESSAGE, + check_health=False, + disconnect_on_interrupt=False, + ) self.health_check_response_counter += 1 def _normalize_keys(self, data): diff --git a/redis/cluster.py b/redis/cluster.py index 0b2c4f1387..f2df61f7d1 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1701,7 +1701,7 @@ def execute_command(self, *args, **kwargs): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - self._execute(connection, connection.send_command, *args) + self._execute(connection, lambda: connection.send_command(*args)) def get_redis_connection(self): """ diff --git a/redis/connection.py b/redis/connection.py index 9c5b536f89..ada4b6681f 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -755,7 +755,7 @@ def check_health(self): if self.health_check_interval and time() > self.next_health_check: self.retry.call_with_retry(self._send_ping, self._ping_failed) - def send_packed_command(self, command, check_health=True): + def send_packed_command(self, command, check_health=True, *, disconnect_on_interrupt=True): """Send an already packed command to the Redis server""" if not self._sock: self.connect() @@ -781,11 +781,20 @@ def send_packed_command(self, command, check_health=True): except Exception: self.disconnect() raise + except BaseException: + # On interrupt (e.g. by gevent.Timeout) there's no way to determine + # how much data, if any, was successfully sent, so this socket is unusable + # for subsequent commands (which may concatenate to an unfinished command). + if disconnect_on_interrupt: + self.disconnect() + raise - def send_command(self, *args, **kwargs): + def send_command(self, *args, check_health=True, disconnect_on_interrupt=True, **kwargs): """Pack and send a command to the Redis server""" self.send_packed_command( - self.pack_command(*args), check_health=kwargs.get("check_health", True) + self.pack_command(*args), + check_health=check_health, + disconnect_on_interrupt=disconnect_on_interrupt, ) def can_read(self, timeout=0): @@ -801,7 +810,7 @@ def can_read(self, timeout=0): f"Error while reading from {self.host}:{self.port}: {e.args}" ) - def read_response(self, disable_decoding=False): + def read_response(self, disable_decoding=False, *, disconnect_on_interrupt=True): """Read the response from a previously sent command""" try: hosterr = f"{self.host}:{self.port}" @@ -819,6 +828,14 @@ def read_response(self, disable_decoding=False): except Exception: self.disconnect() raise + except BaseException: + # On interrupt (e.g. by gevent.Timeout) there's no way to determine + # how much data, if any, was successfully read, so this socket is unusable + # for subsequent commands (which may read previous command's response + # as their own). + if disconnect_on_interrupt: + self.disconnect() + raise if self.health_check_interval: self.next_health_check = time() + self.health_check_interval diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6bf0034146..13f767ef03 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -112,3 +112,25 @@ async def test_connect_timeout_error_without_retry(): await conn.connect() assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" + + +@pytest.mark.parametrize('exc_type', [Exception, BaseException]) +async def test_read_response__interrupt_does_not_corrupt(exc_type): + conn = Connection() + + await conn.send_command("GET non_existent_key") + resp = await conn.read_response() + assert resp is None + + with pytest.raises(exc_type): + await conn.send_command("EXISTS non_existent_key") + # due to the interrupt, the integer '0' result of EXISTS will remain on the socket's buffer + with patch.object(socket.socket, "recv", side_effect=exc_type) as mock_recv: + await conn.read_response() + mock_recv.assert_called_once() + + await conn.send_command("GET non_existent_key") + resp = await conn.read_response() + # If working properly, this will get a None. + # If not, it will get a zero (the integer result of the previous EXISTS command). + assert resp is None diff --git a/tests/test_connection.py b/tests/test_connection.py index d9251c31dc..af012c70ec 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -122,3 +122,40 @@ def test_connect_timeout_error_without_retry(self): assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" self.clear(conn) + + @pytest.mark.parametrize('exc_type', [Exception, BaseException]) + def test_read_response__interrupt_does_not_corrupt(self, exc_type): + conn = Connection() + + # A note on BaseException: + # While socket.recv is not supposed to raise BaseException, gevent's version + # of socket (which, when using gevent + redis-py, one would monkey-patch in) + # can raise BaseException on a timer elapse, since `gevent.Timeout` derives + # from BaseException. This design suggests that a timeout should + # not be suppressed but rather allowed to propagate. + # asyncio.exceptions.CancelledError also derives from BaseException + # for same reason. + # + # The notion that one should never `expect:` or `expect BaseException`, + # however, is misguided. It's idiomatic to handle it, to provide + # for exception safety, as long as you re-raise. + # + # with gevent.Timeout(5): + # res = client.exists('my_key') + + conn.send_command("GET non_existent_key") + resp = conn.read_response() + assert resp is None + + with pytest.raises(exc_type): + conn.send_command("EXISTS non_existent_key") + # due to the interrupt, the integer '0' result of EXISTS will remain on the socket's buffer + with patch.object(socket.socket, "recv", side_effect=exc_type) as mock_recv: + _ = conn.read_response() + mock_recv.assert_called_once() + + conn.send_command("GET non_existent_key") + resp = conn.read_response() + # If working properly, this will get a None. + # If not, it will get a zero (the integer result of the previous EXISTS command). + assert resp is None