diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index fae6875d82..502a0f896a 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -1,5 +1,7 @@ import asyncio +import inspect import random +import socket import weakref from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type @@ -11,8 +13,13 @@ SSLConnection, ) from redis.commands import AsyncSentinelCommands -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -37,11 +44,47 @@ def __repr__(self): async def connect_to(self, address): self.host, self.port = address - await super().connect() - if self.connection_pool.check_connection: - await self.send_command("PING") - if str_if_bytes(await self.read_response()) != "PONG": - raise ConnectionError("PING failed") + + if self.is_connected: + return + try: + await self._connect() + except asyncio.CancelledError: + raise # in 3.7 and earlier, this is an Exception, not BaseException + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + if not self.redis_connect_func: + # Use the default on_connect function + await self.on_connect_check_health( + check_health=self.connection_pool.check_connection + ) + else: + # Use the passed function redis_connect_func + ( + await self.redis_connect_func(self) + if asyncio.iscoroutinefunction(self.redis_connect_func) + else self.redis_connect_func(self) + ) + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] + for ref in self._connect_callbacks: + callback = ref() + task = callback(self) + if task and inspect.isawaitable(task): + await task async def _connect_retry(self): if self._reader: diff --git a/redis/sentinel.py b/redis/sentinel.py index 02aa244ede..d03d15744c 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -1,12 +1,18 @@ import random +import socket import weakref from typing import Optional from redis.client import Redis from redis.commands import SentinelCommands from redis.connection import Connection, ConnectionPool, SSLConnection -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -35,11 +41,39 @@ def __repr__(self): def connect_to(self, address): self.host, self.port = address - super().connect() - if self.connection_pool.check_connection: - self.send_command("PING") - if str_if_bytes(self.read_response()) != "PONG": - raise ConnectionError("PING failed") + + if self._sock: + return + try: + sock = self._connect() + except socket.timeout: + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + + self._sock = sock + try: + if self.redis_connect_func is None: + # Use the default on_connect function + self.on_connect_check_health( + check_health=self.connection_pool.check_connection + ) + else: + # Use the passed function redis_connect_func + self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] + for ref in self._connect_callbacks: + callback = ref() + if callback: + callback(self) def _connect_retry(self): if self._sock: @@ -294,13 +328,16 @@ def discover_master(self, service_name): """ collected_errors = list() for sentinel_no, sentinel in enumerate(self.sentinels): + # print(f"Sentinel: {sentinel_no}") try: masters = sentinel.sentinel_masters() except (ConnectionError, TimeoutError) as e: collected_errors.append(f"{sentinel} - {e!r}") continue state = masters.get(service_name) + # print(f"Found master: {state}") if state and self.check_master_state(state, service_name): + # print("Valid state") # Put this sentinel at the top of the list self.sentinels[0], self.sentinels[sentinel_no] = ( sentinel, @@ -313,6 +350,7 @@ def discover_master(self, service_name): else state["ip"] ) return ip, state["port"] + # print("Ignoring it") error_info = "" if len(collected_errors) > 0: diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index cae4b9581f..ddcda9c179 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -34,4 +34,5 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 await conn.disconnect() diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py new file mode 100644 index 0000000000..6fe5f7cd5b --- /dev/null +++ b/tests/test_sentinel_managed_connection.py @@ -0,0 +1,34 @@ +import socket + +from redis.retry import Retry +from redis.sentinel import SentinelManagedConnection +from redis.backoff import NoBackoff +from unittest import mock + + +def test_connect_retry_on_timeout_error(master_host): + """Test that the _connect function is retried in case of a timeout""" + connection_pool = mock.Mock() + connection_pool.get_master_address = mock.Mock( + return_value=(master_host[0], master_host[1]) + ) + conn = SentinelManagedConnection( + retry_on_timeout=True, + retry=Retry(NoBackoff(), 3), + connection_pool=connection_pool, + ) + origin_connect = conn._connect + conn._connect = mock.Mock() + + def mock_connect(): + # connect only on the last retry + if conn._connect.call_count <= 2: + raise socket.timeout + else: + return origin_connect() + + conn._connect.side_effect = mock_connect + conn.connect() + assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 + conn.disconnect()