Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from enum import Enum
from typing import Any, List, Optional, Union

from redis.event import EventDispatcherInterface, EventDispatcher, \
OnCacheEvictionEvent, OnCacheHitEvent
from redis.observability.attributes import CSCResult, CSCReason


class CacheEntryStatus(Enum):
VALID = "VALID"
Expand Down Expand Up @@ -186,9 +190,6 @@ def set(self, entry: CacheEntry) -> bool:
self._cache[entry.cache_key] = entry
self._eviction_policy.touch(entry.cache_key)

if self._cache_config.is_exceeds_max_size(len(self._cache)):
self._eviction_policy.evict_next()

return True

def get(self, key: CacheKey) -> Union[CacheEntry, None]:
Expand Down Expand Up @@ -247,6 +248,63 @@ def flush(self) -> int:
def is_cachable(self, key: CacheKey) -> bool:
return self._cache_config.is_allowed_to_cache(key.command)

class CacheProxy(CacheInterface):
    """Decorator around a concrete :class:`CacheInterface` implementation.

    Forwards every cache operation to the wrapped cache unchanged, while
    layering event dispatching on top — currently an eviction notification
    when a ``set`` pushes the cache past its configured maximum size.
    """

    def __init__(self, cache: CacheInterface, event_dispatcher: Optional[EventDispatcherInterface] = None):
        self._cache = cache
        # Fall back to a private dispatcher when the caller did not supply one.
        self._event_dispatcher = (
            EventDispatcher() if event_dispatcher is None else event_dispatcher
        )

    @property
    def collection(self) -> OrderedDict:
        # Expose the wrapped cache's backing store as-is.
        return self._cache.collection

    @property
    def config(self) -> CacheConfigurationInterface:
        return self._cache.config

    @property
    def eviction_policy(self) -> EvictionPolicyInterface:
        return self._cache.eviction_policy

    @property
    def size(self) -> int:
        return self._cache.size

    def get(self, key: CacheKey) -> Union[CacheEntry, None]:
        return self._cache.get(key)

    def set(self, entry: CacheEntry) -> bool:
        """Store ``entry`` in the wrapped cache.

        If the insertion leaves the cache over its configured maximum size,
        dispatch a single :class:`OnCacheEvictionEvent` and evict one entry
        via the wrapped cache's eviction policy.
        """
        stored = self._cache.set(entry)

        if self.config.is_exceeds_max_size(self.size):
            # Announce the eviction before it happens so listeners observe
            # the pre-eviction state; one entry is removed per overflowing set.
            self._event_dispatcher.dispatch(
                OnCacheEvictionEvent(
                    count=1,
                    reason=CSCReason.FULL,
                )
            )
            self.eviction_policy.evict_next()

        return stored

    def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
        return self._cache.delete_by_cache_keys(cache_keys)

    def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
        return self._cache.delete_by_redis_keys(redis_keys)

    def flush(self) -> int:
        return self._cache.flush()

    def is_cachable(self, key: CacheKey) -> bool:
        return self._cache.is_cachable(key)


class LRUPolicy(EvictionPolicyInterface):
def __init__(self):
Expand Down Expand Up @@ -422,4 +480,4 @@ def __init__(self, cache_config: Optional[CacheConfig] = None):

def get_cache(self) -> CacheInterface:
cache_class = self._config.get_cache_class()
return cache_class(cache_config=self._config)
return CacheProxy(cache_class(cache_config=self._config))
114 changes: 90 additions & 24 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@
CacheFactory,
CacheFactoryInterface,
CacheInterface,
CacheKey,
CacheKey, CacheProxy,
)

from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
from .auth.token import TokenInterface
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
AfterConnectionCreatedEvent
AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \
OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitializationEvent
from .exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand All @@ -55,7 +56,8 @@
MaintNotificationsPoolHandler, MaintenanceNotification,
)
from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \
DB_CLIENT_CONNECTION_POOL_NAME
DB_CLIENT_CONNECTION_POOL_NAME, CSCReason
from .observability.metrics import CloseReason
from .retry import Retry
from .utils import (
CRYPTOGRAPHY_AVAILABLE,
Expand Down Expand Up @@ -1068,6 +1070,11 @@ def disconnect(self, *args):
pass

if len(args) > 0 and isinstance(args[0], Exception):
if len(args) > 2 and args[2]:
close_reason = CloseReason.HEALTHCHECK_FAILED
else:
close_reason = CloseReason.ERROR

if args[1] <= self.retry.get_retries():
self._event_dispatcher.dispatch(
OnErrorEvent(
Expand All @@ -1078,6 +1085,19 @@ def disconnect(self, *args):
)
)

self._event_dispatcher.dispatch(
AfterConnectionClosedEvent(
close_reason=close_reason,
error=args[0],
)
)
else:
self._event_dispatcher.dispatch(
AfterConnectionClosedEvent(
close_reason=CloseReason.APPLICATION_CLOSE
)
)

def mark_for_reconnect(self):
self._should_reconnect = True

Expand All @@ -1095,7 +1115,7 @@ def _send_ping(self):

def _ping_failed(self, error, failure_count):
"""Function to call when PING fails"""
self.disconnect(error, failure_count)
self.disconnect(error, failure_count, True)

def check_health(self):
"""Check the health of the connection with a PING/PONG"""
Expand Down Expand Up @@ -1380,6 +1400,8 @@ def __init__(
self.retry = self._conn.retry
self.host = self._conn.host
self.port = self._conn.port
self.db = self._conn.db
self._event_dispatcher = self._conn._event_dispatcher
self.credential_provider = conn.credential_provider
self._pool_lock = pool_lock
self._cache = cache
Expand Down Expand Up @@ -1522,17 +1544,28 @@ def read_response(
):
with self._cache_lock:
# Check if command response exists in a cache and it's not in progress.
if (
self._current_command_cache_key is not None
and self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
if self._current_command_cache_key is not None:
if (
self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
)
self._current_command_cache_key = None
self._event_dispatcher.dispatch(
OnCacheHitEvent(
bytes_saved=len(res),
db_namespace=self.db,
)
)
return res
self._event_dispatcher.dispatch(
OnCacheMissEvent(
db_namespace=self.db,
)
)
self._current_command_cache_key = None
return res

response = self._conn.read_response(
disable_decoding=disable_decoding,
Expand Down Expand Up @@ -1696,7 +1729,16 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]
if data[1] is None:
self._cache.flush()
else:
self._cache.delete_by_redis_keys(data[1])
keys_deleted = self._cache.delete_by_redis_keys(data[1])

if len(keys_deleted) > 0:
self._event_dispatcher.dispatch(
OnCacheEvictionEvent(
count=len(data[1]),
reason=CSCReason.INVALIDATION,
db_namespace=self.db,
)
)


class SSLConnection(Connection):
Expand Down Expand Up @@ -2507,6 +2549,10 @@ def __init__(
self.cache = None
self._cache_factory = cache_factory

self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
if self._connection_kwargs.get("protocol") not in [3, "3"]:
raise RedisError("Client caching is only supported with RESP version 3")
Expand All @@ -2520,19 +2566,22 @@ def __init__(
self.cache = cache
else:
if self._cache_factory is not None:
self.cache = self._cache_factory.get_cache()
self.cache = CacheProxy(self._cache_factory.get_cache())
else:
self.cache = CacheFactory(
self._connection_kwargs.get("cache_config")
).get_cache()

self._event_dispatcher.dispatch(
OnCacheInitializationEvent(
cache_items_callback=lambda: self.cache.size,
db_namespace=self._connection_kwargs.get("db"),
)
)

connection_kwargs.pop("cache", None)
connection_kwargs.pop("cache_config", None)

self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

# a lock to protect the critical section in _checkpid().
# this lock is acquired when the process id changes, such as
# after a fork. during this time, multiple threads in the child
Expand Down Expand Up @@ -2648,6 +2697,8 @@ def _checkpid(self) -> None:
def get_connection(self, command_name=None, *keys, **options) -> "Connection":
"Get a connection from the pool"

# Start timing for observability
start_time_acquired = time.monotonic()
self._checkpid()
is_created = False

Expand All @@ -2656,7 +2707,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
connection = self._available_connections.pop()
except IndexError:
# Start timing for observability
start_time = time.monotonic()
start_time_created = time.monotonic()

connection = self.make_connection()
is_created = True
Expand Down Expand Up @@ -2691,9 +2742,16 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
duration_seconds=time.monotonic() - start_time_created,
)
)

self._event_dispatcher.dispatch(
AfterConnectionAcquiredEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time_acquired,
)
)
return connection

def get_encoder(self) -> Encoder:
Expand Down Expand Up @@ -2957,6 +3015,7 @@ def get_connection(self, command_name=None, *keys, **options):
create new connections when we need to, i.e.: the actual number of
connections will only increase in response to demand.
"""
start_time_acquired = time.monotonic()
# Make sure we haven't changed process.
self._checkpid()
is_created = False
Expand All @@ -2979,7 +3038,7 @@ def get_connection(self, command_name=None, *keys, **options):
# a new connection to add to the pool.
if connection is None:
# Start timing for observability
start_time = time.monotonic()
start_time_created = time.monotonic()
connection = self.make_connection()
is_created = True
finally:
Expand Down Expand Up @@ -3014,10 +3073,17 @@ def get_connection(self, command_name=None, *keys, **options):
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
duration_seconds=time.monotonic() - start_time_created,
)
)

self._event_dispatcher.dispatch(
AfterConnectionAcquiredEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time_acquired,
)
)

return connection

def release(self, connection):
Expand Down
Loading