Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from enum import Enum
from typing import Any, List, Optional, Union

from redis.event import EventDispatcherInterface, EventDispatcher, \
OnCacheEvictionEvent, OnCacheHitEvent
from redis.observability.attributes import CSCResult, CSCReason


class CacheEntryStatus(Enum):
VALID = "VALID"
Expand Down Expand Up @@ -186,9 +190,6 @@ def set(self, entry: CacheEntry) -> bool:
self._cache[entry.cache_key] = entry
self._eviction_policy.touch(entry.cache_key)

if self._cache_config.is_exceeds_max_size(len(self._cache)):
self._eviction_policy.evict_next()

return True

def get(self, key: CacheKey) -> Union[CacheEntry, None]:
Expand Down Expand Up @@ -247,6 +248,63 @@ def flush(self) -> int:
def is_cachable(self, key: CacheKey) -> bool:
return self._cache_config.is_allowed_to_cache(key.command)

class CacheProxy(CacheInterface):
    """
    Proxy object that wraps cache implementations to enable additional
    logic on top of them.

    Every CacheInterface operation is delegated to the wrapped cache.
    Only `set` adds behavior: when the wrapped cache exceeds its
    configured maximum size, one entry is evicted via the eviction
    policy and an OnCacheEvictionEvent (reason=FULL) is dispatched.
    """
    def __init__(self, cache: CacheInterface, event_dispatcher: Optional[EventDispatcherInterface] = None):
        """
        :param cache: Concrete cache implementation to wrap.
        :param event_dispatcher: Dispatcher used to publish cache events.
            A fresh EventDispatcher is created when none is provided.
        """
        # Guard against double wrapping: some call sites wrap the result
        # of CacheFactory.get_cache(), which is already a CacheProxy.
        # Nested proxies would run eviction twice per overflow and emit
        # duplicate eviction events.
        if isinstance(cache, CacheProxy):
            cache = cache._cache
        self._cache = cache

        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher

    @property
    def collection(self) -> OrderedDict:
        return self._cache.collection

    @property
    def config(self) -> CacheConfigurationInterface:
        return self._cache.config

    @property
    def eviction_policy(self) -> EvictionPolicyInterface:
        return self._cache.eviction_policy

    @property
    def size(self) -> int:
        return self._cache.size

    def get(self, key: CacheKey) -> Union[CacheEntry, None]:
        return self._cache.get(key)

    def set(self, entry: CacheEntry) -> bool:
        is_set = self._cache.set(entry)

        # Enforce the size limit here (rather than inside the wrapped
        # cache) so evictions are observable as events.
        if self.config.is_exceeds_max_size(self.size):
            self.eviction_policy.evict_next()
            # Dispatch after evicting so the event describes an eviction
            # that has actually happened.
            self._event_dispatcher.dispatch(
                OnCacheEvictionEvent(
                    count=1,
                    reason=CSCReason.FULL,
                )
            )

        return is_set

    def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
        return self._cache.delete_by_cache_keys(cache_keys)

    def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
        return self._cache.delete_by_redis_keys(redis_keys)

    def flush(self) -> int:
        return self._cache.flush()

    def is_cachable(self, key: CacheKey) -> bool:
        return self._cache.is_cachable(key)


class LRUPolicy(EvictionPolicyInterface):
def __init__(self):
Expand Down Expand Up @@ -422,4 +480,4 @@ def __init__(self, cache_config: Optional[CacheConfig] = None):

def get_cache(self) -> CacheInterface:
cache_class = self._config.get_cache_class()
return cache_class(cache_config=self._config)
return CacheProxy(cache_class(cache_config=self._config))
68 changes: 49 additions & 19 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@
CacheFactory,
CacheFactoryInterface,
CacheInterface,
CacheKey,
CacheKey, CacheProxy,
)

from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
from .auth.token import TokenInterface
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent
AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \
OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitializationEvent
from .exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand All @@ -55,7 +56,7 @@
MaintNotificationsPoolHandler, MaintenanceNotification,
)
from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \
DB_CLIENT_CONNECTION_POOL_NAME
DB_CLIENT_CONNECTION_POOL_NAME, CSCReason
from .observability.metrics import CloseReason
from .retry import Retry
from .utils import (
Expand Down Expand Up @@ -1403,6 +1404,8 @@ def __init__(
self.retry = self._conn.retry
self.host = self._conn.host
self.port = self._conn.port
self.db = self._conn.db
self._event_dispatcher = self._conn._event_dispatcher
self.credential_provider = conn.credential_provider
self._pool_lock = pool_lock
self._cache = cache
Expand Down Expand Up @@ -1545,17 +1548,28 @@ def read_response(
):
with self._cache_lock:
# Check if command response exists in a cache and it's not in progress.
if (
self._current_command_cache_key is not None
and self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
if self._current_command_cache_key is not None:
if (
self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
)
self._current_command_cache_key = None
self._event_dispatcher.dispatch(
OnCacheHitEvent(
bytes_saved=len(res),
db_namespace=self.db,
)
)
return res
self._event_dispatcher.dispatch(
OnCacheMissEvent(
db_namespace=self.db,
)
)
self._current_command_cache_key = None
return res

response = self._conn.read_response(
disable_decoding=disable_decoding,
Expand Down Expand Up @@ -1719,7 +1733,16 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]
if data[1] is None:
self._cache.flush()
else:
self._cache.delete_by_redis_keys(data[1])
keys_deleted = self._cache.delete_by_redis_keys(data[1])

if len(keys_deleted) > 0:
self._event_dispatcher.dispatch(
OnCacheEvictionEvent(
count=len(keys_deleted),
reason=CSCReason.INVALIDATION,
db_namespace=self.db,
)
)


class SSLConnection(Connection):
Expand Down Expand Up @@ -2530,6 +2553,10 @@ def __init__(
self.cache = None
self._cache_factory = cache_factory

self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
if self._connection_kwargs.get("protocol") not in [3, "3"]:
raise RedisError("Client caching is only supported with RESP version 3")
Expand All @@ -2543,19 +2570,22 @@ def __init__(
self.cache = cache
else:
if self._cache_factory is not None:
self.cache = self._cache_factory.get_cache()
self.cache = CacheProxy(self._cache_factory.get_cache())
else:
self.cache = CacheFactory(
self._connection_kwargs.get("cache_config")
).get_cache()

self._event_dispatcher.dispatch(
OnCacheInitializationEvent(
cache_items_callback=lambda: self.cache.size,
db_namespace=self._connection_kwargs.get("db"),
)
)

connection_kwargs.pop("cache", None)
connection_kwargs.pop("cache_config", None)

self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

# a lock to protect the critical section in _checkpid().
# this lock is acquired when the process id changes, such as
# after a fork. during this time, multiple threads in the child
Expand Down
102 changes: 95 additions & 7 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
from redis.observability.attributes import PubSubDirection
from redis.observability.attributes import PubSubDirection, CSCResult, CSCReason
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \
record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \
record_pubsub_message, record_streaming_lag, record_connection_wait_time, record_connection_use_time, \
record_connection_closed
record_connection_closed, record_csc_request, init_csc_items, record_csc_eviction, record_csc_network_saved, \
register_pools_connection_count, register_csc_items_callback
from redis.utils import str_if_bytes


Expand Down Expand Up @@ -103,10 +104,6 @@ def __init__(
AsyncAfterConnectionReleasedEvent: [
AsyncReAuthConnectionListener(),
],
OnErrorEvent: [ExportErrorCountMetric()],
OnMaintenanceNotificationEvent: [
ExportMaintenanceNotificationCountMetric(),
],
AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()],
AfterConnectionTimeoutUpdatedEvent: [
ExportConnectionRelaxedTimeoutMetric(),
Expand All @@ -126,6 +123,14 @@ def __init__(
OnStreamMessageReceivedEvent: [
ExportStreamingLagMetric(),
],
OnErrorEvent: [ExportErrorCountMetric()],
OnMaintenanceNotificationEvent: [
ExportMaintenanceNotificationCountMetric(),
],
OnCacheInitializationEvent: [InitializeCSCItemsObservability()],
OnCacheEvictionEvent: [ExportCSCEvictionMetric()],
OnCacheHitEvent: [ExportCSCNetworkSavedMetric(), ExportCSCRequestMetric()],
OnCacheMissEvent: [ExportCSCRequestMetric()],
}

self._lock = threading.Lock()
Expand Down Expand Up @@ -420,6 +425,38 @@ class AfterConnectionClosedEvent:
close_reason: "CloseReason"
error: Optional[Exception] = None

@dataclass
class OnCacheHitEvent:
    """
    Event fired whenever a client-side cache (CSC) hit occurs.

    Consumed by the listeners that export the CSC request and
    network-bytes-saved metrics.
    """
    # Size in bytes of the cached response served instead of going to
    # the network (len() of the cached value at the dispatch site).
    bytes_saved: int
    # Redis logical database the command targeted; None when unknown.
    db_namespace: Optional[int] = None

@dataclass
class OnCacheMissEvent:
    """
    Event fired whenever a client-side cache (CSC) miss occurs.

    Consumed by the listener that exports the CSC request metric.
    """
    # Redis logical database the command targeted; None when unknown.
    db_namespace: Optional[int] = None

@dataclass
class OnCacheInitializationEvent:
    """
    Event fired after the client-side cache is initialized.

    Used to set up the cache-items gauge and register its callback.
    """
    # Zero-argument callable returning the current number of cached
    # items (dispatch site passes `lambda: cache.size`).
    cache_items_callback: Callable
    # Redis logical database the cache serves; None when unknown.
    db_namespace: Optional[int] = None

@dataclass
class OnCacheEvictionEvent:
    """
    Event fired after entries are evicted from the client-side cache.

    Consumed by the listener that exports the CSC eviction metric.
    """
    # Number of entries removed from the cache.
    count: int
    # Why entries were removed (e.g. FULL on overflow, INVALIDATION on
    # server-pushed key invalidation).
    reason: CSCReason
    # Redis logical database the cache serves; None when unknown.
    db_namespace: Optional[int] = None

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    """Async marker subclass of OnCommandsFailEvent; adds no fields or behavior."""
    pass

Expand Down Expand Up @@ -651,7 +688,11 @@ class InitializeConnectionCountObservability(EventListenerInterface):
Listener that initializes connection count observability.
"""
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
init_connection_count(event.connection_pools)
# Initialize gauge only once, subsequent calls won't have an affect.
init_connection_count()

# Register pools for connection count observability.
register_pools_connection_count(event.connection_pools)

class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface):
"""
Expand Down Expand Up @@ -747,3 +788,50 @@ def listen(self, event: AfterConnectionClosedEvent):
close_reason=event.close_reason,
error_type=event.error,
)

class ExportCSCRequestMetric(EventListenerInterface):
    """
    Listener that records a client-side-cache request metric for every
    cache hit or miss event it receives.
    """

    def listen(self, event: Union[OnCacheHitEvent, OnCacheMissEvent]):
        # Hit events map to HIT; any other event type here is a miss.
        outcome = (
            CSCResult.HIT if isinstance(event, OnCacheHitEvent) else CSCResult.MISS
        )
        record_csc_request(db_namespace=event.db_namespace, result=outcome)

class InitializeCSCItemsObservability(EventListenerInterface):
    """
    Listener that initializes CSC items observability.

    Sets up the cache-items gauge and registers the callback that
    reports the current number of cached entries.
    """
    def listen(self, event: OnCacheInitializationEvent):
        # Initialize gauge only once, subsequent calls won't have an effect.
        init_csc_items()

        # Register cache items callback for CSC items observability.
        register_csc_items_callback(event.cache_items_callback, event.db_namespace)

class ExportCSCEvictionMetric(EventListenerInterface):
    """
    Listener that records the client-side-cache eviction metric from
    eviction events.
    """

    def listen(self, event: OnCacheEvictionEvent):
        # Forward the event's payload straight into the recorder.
        evicted = event.count
        why = event.reason
        record_csc_eviction(count=evicted, reason=why, db_namespace=event.db_namespace)

class ExportCSCNetworkSavedMetric(EventListenerInterface):
    """
    Listener that records how many bytes of network traffic a
    client-side-cache hit saved.
    """

    def listen(self, event: OnCacheHitEvent):
        saved = event.bytes_saved
        record_csc_network_saved(bytes_saved=saved, db_namespace=event.db_namespace)
Loading