Skip to content

Don't corrupt connection on interruption #2500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,14 @@ async def execute_command(self, *args: EncodableT):

await self.connect()
connection = self.connection
kwargs = {"check_health": not self.subscribed}
await self._execute(connection, connection.send_command, *args, **kwargs)
await self._execute(
connection,
lambda: connection.send_command(
*args,
check_health=not self.subscribed,
disconnect_on_interrupt=False,
),
)

async def connect(self):
"""
Expand All @@ -753,7 +759,7 @@ async def _disconnect_raise_connect(self, conn, error):
raise error
await conn.connect()

async def _execute(self, conn, command, *args, **kwargs):
async def _execute(self, conn, command):
"""
Connect manually upon disconnection. If the Redis server is down,
this will fail and raise a ConnectionError as desired.
Expand All @@ -762,7 +768,7 @@ async def _execute(self, conn, command, *args, **kwargs):
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
command,
lambda error: self._disconnect_raise_connect(conn, error),
)

Expand All @@ -781,7 +787,13 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn,
lambda: conn.read_response(
timeout=read_timeout,
disconnect_on_interrupt=False,
),
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand All @@ -801,7 +813,10 @@ async def check_health(self):
and asyncio.get_event_loop().time() > conn.next_health_check
):
await conn.send_command(
"PING", self.HEALTH_CHECK_MESSAGE, check_health=False
"PING",
self.HEALTH_CHECK_MESSAGE,
check_health=False,
disconnect_on_interrupt=False,
)

def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
Expand Down
37 changes: 26 additions & 11 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,11 @@ async def _send_packed_command(self, command: Iterable[bytes]) -> None:
await self._writer.drain()

async def send_packed_command(
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
self,
command: Union[bytes, str, Iterable[bytes]],
check_health: bool = True,
*,
disconnect_on_interrupt: bool = True,
) -> None:
if not self.is_connected:
await self.connect()
Expand Down Expand Up @@ -763,14 +767,22 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
await self.disconnect(nowait=True)
except BaseException:
# On interrupt (e.g. by CancelledError) there's no way to determine
# how much data, if any, was successfully sent, so this socket is unusable
# for subsequent commands (which may concatenate to an unfinished command).
if disconnect_on_interrupt:
await self.disconnect(nowait=True)
raise

async def send_command(self, *args: Any, **kwargs: Any) -> None:
async def send_command(
self, *args, check_health=True, disconnect_on_interrupt=True, **kwargs,
):
"""Pack and send a command to the Redis server"""
await self.send_packed_command(
self.pack_command(*args), check_health=kwargs.get("check_health", True)
self.pack_command(*args),
check_health=check_health,
disconnect_on_interrupt=disconnect_on_interrupt,
)

async def can_read_destructive(self):
Expand All @@ -787,6 +799,8 @@ async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
disconnect_on_interrupt: bool = True,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
Expand All @@ -812,12 +826,13 @@ async def read_response(
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
except BaseException:
# On interrupt (e.g. by CancelledError) there's no way to determine
# how much data, if any, was successfully read, so this socket is unusable
# for subsequent commands (which may read previous command's response
# as their own).
if disconnect_on_interrupt:
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down
25 changes: 18 additions & 7 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,10 +1462,16 @@ def execute_command(self, *args):
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)
self._execute(
connection,
lambda: connection.send_command(
*args,
check_health=not self.subscribed,
disconnect_on_interrupt=True,
),
)

def clean_health_check_responses(self):
"""
Expand All @@ -1474,7 +1480,7 @@ def clean_health_check_responses(self):
ttl = 10
conn = self.connection
while self.health_check_response_counter > 0 and ttl > 0:
if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
if self._execute(conn, lambda: conn.can_read(timeout=conn.socket_timeout)):
response = self._execute(conn, conn.read_response)
if self.is_health_check_response(response):
self.health_check_response_counter -= 1
Expand All @@ -1496,7 +1502,7 @@ def _disconnect_raise_connect(self, conn, error):
raise error
conn.connect()

def _execute(self, conn, command, *args, **kwargs):
def _execute(self, conn, command):
"""
Connect manually upon disconnection. If the Redis server is down,
this will fail and raise a ConnectionError as desired.
Expand All @@ -1505,7 +1511,7 @@ def _execute(self, conn, command, *args, **kwargs):
patterns we were previously listening to
"""
return conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
command,
lambda error: self._disconnect_raise_connect(conn, error),
)

Expand All @@ -1526,7 +1532,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(disconnect_on_interrupt=False)

response = self._execute(conn, try_read)

Expand Down Expand Up @@ -1556,7 +1562,12 @@ def check_health(self):
)

if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
conn.send_command(
"PING",
self.HEALTH_CHECK_MESSAGE,
check_health=False,
disconnect_on_interrupt=False,
)
self.health_check_response_counter += 1

def _normalize_keys(self, data):
Expand Down
2 changes: 1 addition & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ def execute_command(self, *args, **kwargs):
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
self._execute(connection, connection.send_command, *args)
self._execute(connection, lambda: connection.send_command(*args))

def get_redis_connection(self):
"""
Expand Down
25 changes: 21 additions & 4 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def check_health(self):
if self.health_check_interval and time() > self.next_health_check:
self.retry.call_with_retry(self._send_ping, self._ping_failed)

def send_packed_command(self, command, check_health=True):
def send_packed_command(self, command, check_health=True, *, disconnect_on_interrupt=True):
"""Send an already packed command to the Redis server"""
if not self._sock:
self.connect()
Expand All @@ -781,11 +781,20 @@ def send_packed_command(self, command, check_health=True):
except Exception:
self.disconnect()
raise
except BaseException:
# On interrupt (e.g. by gevent.Timeout) there's no way to determine
# how much data, if any, was successfully sent, so this socket is unusable
# for subsequent commands (which may concatenate to an unfinished command).
if disconnect_on_interrupt:
self.disconnect()
raise

def send_command(self, *args, **kwargs):
def send_command(self, *args, check_health=True, disconnect_on_interrupt=True, **kwargs):
"""Pack and send a command to the Redis server"""
self.send_packed_command(
self.pack_command(*args), check_health=kwargs.get("check_health", True)
self.pack_command(*args),
check_health=check_health,
disconnect_on_interrupt=disconnect_on_interrupt,
)

def can_read(self, timeout=0):
Expand All @@ -801,7 +810,7 @@ def can_read(self, timeout=0):
f"Error while reading from {self.host}:{self.port}: {e.args}"
)

def read_response(self, disable_decoding=False):
def read_response(self, disable_decoding=False, *, disconnect_on_interrupt=True):
"""Read the response from a previously sent command"""
try:
hosterr = f"{self.host}:{self.port}"
Expand All @@ -819,6 +828,14 @@ def read_response(self, disable_decoding=False):
except Exception:
self.disconnect()
raise
except BaseException:
# On interrupt (e.g. by gevent.Timeout) there's no way to determine
# how much data, if any, was successfully read, so this socket is unusable
# for subsequent commands (which may read previous command's response
# as their own).
if disconnect_on_interrupt:
self.disconnect()
raise

if self.health_check_interval:
self.next_health_check = time() + self.health_check_interval
Expand Down
22 changes: 22 additions & 0 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,25 @@ async def test_connect_timeout_error_without_retry():
await conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"


@pytest.mark.parametrize("exc_type", [Exception, BaseException])
async def test_read_response__interrupt_does_not_corrupt(exc_type):
    """Check that an interrupted read does not leak a stale reply.

    A read is interrupted by making ``socket.socket.recv`` raise ``exc_type``.
    The reply to the interrupted command must not be returned as the reply
    to the next command sent on the same ``Connection``.
    """
    conn = Connection()

    # Baseline round trip on a healthy connection.
    await conn.send_command("GET non_existent_key")
    resp = await conn.read_response()
    assert resp is None

    # Sent outside ``pytest.raises`` so that a failure while sending cannot be
    # mistaken for the simulated read interruption.
    await conn.send_command("EXISTS non_existent_key")
    # Due to the interrupt, the integer '0' result of EXISTS will remain on
    # the socket's buffer.
    with patch.object(socket.socket, "recv", side_effect=exc_type) as mock_recv:
        # Only the read is expected to raise; nothing placed after the raising
        # call inside a ``pytest.raises`` block would ever execute.
        with pytest.raises(exc_type):
            await conn.read_response()
    # Runs after the ``raises`` block so the check is actually evaluated.
    mock_recv.assert_called_once()

    await conn.send_command("GET non_existent_key")
    resp = await conn.read_response()
    # If working properly, this will get a None.
    # If not, it will get a zero (the integer result of the previous EXISTS
    # command).
    assert resp is None
37 changes: 37 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,40 @@ def test_connect_timeout_error_without_retry(self):
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"
self.clear(conn)

@pytest.mark.parametrize("exc_type", [Exception, BaseException])
def test_read_response__interrupt_does_not_corrupt(self, exc_type):
    """Check that an interrupted read does not leak a stale reply.

    A read is interrupted by making ``socket.socket.recv`` raise ``exc_type``.
    The reply to the interrupted command must not be returned as the reply
    to the next command sent on the same ``Connection``.
    """
    conn = Connection()

    # A note on BaseException:
    # While socket.recv is not supposed to raise BaseException, gevent's
    # version of socket (which, when using gevent + redis-py, one would
    # monkey-patch in) can raise BaseException on a timer elapse, since
    # `gevent.Timeout` derives from BaseException. This design suggests that
    # a timeout should not be suppressed but rather allowed to propagate.
    # asyncio.exceptions.CancelledError also derives from BaseException
    # for the same reason.
    #
    # The notion that one should never `except:` or `except BaseException`,
    # however, is misguided. It's idiomatic to handle it, to provide
    # for exception safety, as long as you re-raise.
    #
    #     with gevent.Timeout(5):
    #         res = client.exists('my_key')

    # Baseline round trip on a healthy connection.
    conn.send_command("GET non_existent_key")
    resp = conn.read_response()
    assert resp is None

    # Sent outside ``pytest.raises`` so that a failure while sending cannot be
    # mistaken for the simulated read interruption.
    conn.send_command("EXISTS non_existent_key")
    # Due to the interrupt, the integer '0' result of EXISTS will remain on
    # the socket's buffer.
    with patch.object(socket.socket, "recv", side_effect=exc_type) as mock_recv:
        # Only the read is expected to raise; nothing placed after the raising
        # call inside a ``pytest.raises`` block would ever execute.
        with pytest.raises(exc_type):
            _ = conn.read_response()
    # Runs after the ``raises`` block so the check is actually evaluated.
    mock_recv.assert_called_once()

    conn.send_command("GET non_existent_key")
    resp = conn.read_response()
    # If working properly, this will get a None.
    # If not, it will get a zero (the integer result of the previous EXISTS
    # command).
    assert resp is None