141 changes: 122 additions & 19 deletions redis/asyncio/client.py
@@ -38,6 +38,7 @@
UnixDomainSocketConnection,
)
from redis.asyncio.lock import Lock
from redis.asyncio.observability.recorder import record_error_count
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
@@ -714,7 +715,12 @@ async def _send_command_parse_response(self, conn, command_name, *args, **option
await conn.send_command(*args)
return await self.parse_response(conn, command_name, **options)

async def _close_connection(self, conn: Connection):
async def _close_connection(
self,
conn: Connection,
error: Optional[BaseException] = None,
failure_count: Optional[int] = None,
):
"""
Close the connection before retrying.

@@ -724,7 +730,7 @@ async def _close_connection(self, conn: Connection):
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic (on connection level).
"""
await conn.disconnect()
await conn.disconnect(error=error, failure_count=failure_count)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
@@ -734,15 +740,34 @@ async def execute_command(self, *args, **options):
command_name = args[0]
conn = self.connection or await pool.get_connection()

# Track actual retry attempts for error reporting
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
actual_retry_attempts[0] = failure_count
return self._close_connection(conn, error, failure_count)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
failure_callback,
with_failure_count=True,
)
except Exception as e:
await record_error_count(
server_address=conn.host,
server_port=conn.port,
network_peer_address=conn.host,
network_peer_port=conn.port,
error_type=e,
retry_attempts=actual_retry_attempts[0],
is_internal=False,
)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
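The pattern introduced above recurs at every call site in this file: `failure_callback` is invoked with both the error and the running failure count (because `with_failure_count=True`), it stashes that count in the one-element list `actual_retry_attempts` so the enclosing scope can still read it after the retry loop gives up, and the surrounding `except` block then reports the final error together with the count via `record_error_count`. The sketch below illustrates the retry contract these call sites appear to assume; it is a standalone stand-in, not the code in `redis/asyncio/retry.py` (which is outside this diff), and the helper name and backoff are placeholders.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

async def call_with_retry_sketch(
    do: Callable[[], Awaitable[T]],
    fail: Callable[..., Awaitable[None]],
    retries: int = 3,
    with_failure_count: bool = False,
) -> T:
    """Illustrative only: retry `do`, telling `fail` how many attempts have failed."""
    failures = 0
    while True:
        try:
            return await do()
        except Exception as error:
            failures += 1
            if with_failure_count:
                # Matches the two-argument failure_callback(error, failure_count)
                # used by the call sites in this diff.
                await fail(error, failures)
            else:
                await fail(error)
            if failures > retries:
                raise
            await asyncio.sleep(0.1 * failures)  # placeholder backoff
```

Under a contract like this, `actual_retry_attempts[0]` holds the number of failed attempts at the moment the exception finally propagates out of `call_with_retry`.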
@@ -1009,11 +1034,16 @@ async def connect(self):
)
)

async def _reconnect(self, conn):
async def _reconnect(
self,
conn,
error: Optional[BaseException] = None,
failure_count: Optional[int] = None,
):
"""
Try to reconnect
"""
await conn.disconnect()
await conn.disconnect(error=error, failure_count=failure_count)
await conn.connect()

async def _execute(self, conn, command, *args, **kwargs):
@@ -1024,10 +1054,35 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda _: self._reconnect(conn),
)
if not len(args) == 0:
command_name = args[0]
else:
command_name = None

# Track actual retry attempts for error reporting
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
actual_retry_attempts[0] = failure_count
return self._reconnect(conn, error, failure_count)

try:
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
failure_callback,
with_failure_count=True,
)
except Exception as e:
await record_error_count(
server_address=conn.host,
server_port=conn.port,
network_peer_address=conn.host,
network_peer_port=conn.port,
error_type=e,
retry_attempts=actual_retry_attempts[0],
is_internal=False,
)
raise

async def parse_response(self, block: bool = True, timeout: float = 0):
"""Parse the response from a publish/subscribe command"""
@@ -1432,7 +1487,8 @@ async def _disconnect_reset_raise_on_watching(
self,
conn: Connection,
error: Exception,
):
failure_count: Optional[int] = None,
) -> None:
"""
Close the connection, reset watching state and
raise an exception if we were watching.
@@ -1467,12 +1523,32 @@ async def immediate_execute_command(self, *args, **options):
conn = await self.connection_pool.get_connection()
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
)
# Track actual retry attempts for error reporting
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
actual_retry_attempts[0] = failure_count
return self._disconnect_reset_raise_on_watching(conn, error, failure_count)

try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
failure_callback,
with_failure_count=True,
)
except Exception as e:
await record_error_count(
server_address=conn.host,
server_port=conn.port,
network_peer_address=conn.host,
network_peer_port=conn.port,
error_type=e,
retry_attempts=actual_retry_attempts[0],
is_internal=False,
)
raise

def pipeline_execute_command(self, *args, **options):
"""
@@ -1626,7 +1702,12 @@ async def load_scripts(self):
if not exist:
s.sha = await immediate("SCRIPT LOAD", s.script)

async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
async def _disconnect_raise_on_watching(
self,
conn: Connection,
error: Exception,
failure_count: Optional[int] = None,
):
"""
Close the connection, raise an exception if we were watching.

@@ -1636,7 +1717,7 @@ async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic (on connection level).
"""
await conn.disconnect()
await conn.disconnect(error=error, failure_count=failure_count)
# if we were watching a variable, the watch is no longer valid
# since this connection has died. raise a WatchError, which
# indicates the user should retry this transaction.
@@ -1654,8 +1735,10 @@ async def execute(self, raise_on_error: bool = True) -> List[Any]:
await self.load_scripts()
if self.is_transaction or self.explicit_transaction:
execute = self._execute_transaction
operation_name = "MULTI"
else:
execute = self._execute_pipeline
operation_name = "PIPELINE"

conn = self.connection
if not conn:
@@ -1665,11 +1748,31 @@
self.connection = conn
conn = cast(Connection, conn)

# Track actual retry attempts for error reporting
actual_retry_attempts = [0]
stack_len = len(stack)

def failure_callback(error, failure_count):
actual_retry_attempts[0] = failure_count
return self._disconnect_raise_on_watching(conn, error, failure_count)

try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_on_watching(conn, error),
failure_callback,
with_failure_count=True,
)
except Exception as e:
await record_error_count(
server_address=conn.host,
server_port=conn.port,
network_peer_address=conn.host,
network_peer_port=conn.port,
error_type=e,
retry_attempts=actual_retry_attempts[0],
is_internal=False,
)
raise
finally:
await self.reset()
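For context on how these retry paths are reached, here is an illustrative caller using the public `redis.asyncio` API; host, port, and the retry settings are placeholders. Every command issued through the client runs inside `conn.retry.call_with_retry`, so the policy configured here bounds how high `failure_count` can climb before `record_error_count` reports the final error.

```python
import asyncio

import redis.asyncio as redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff

async def main() -> None:
    # Placeholder connection settings; retries=3 allows up to three failed
    # attempts before the error-recording paths above report the final failure.
    client = redis.Redis(
        host="localhost",
        port=6379,
        retry=Retry(ExponentialWithJitterBackoff(), retries=3),
    )
    try:
        await client.set("greeting", "hello")
        print(await client.get("greeting"))
    finally:
        await client.aclose()

asyncio.run(main())
```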
