diff --git a/redis/cache.py b/redis/cache.py
index 949ad3ddf9..70b22ea027 100644
--- a/redis/cache.py
+++ b/redis/cache.py
@@ -4,6 +4,9 @@
 from enum import Enum
 from typing import Any, List, Optional, Union
 
+from redis.event import EventDispatcher, EventDispatcherInterface, OnCacheEvictionEvent
+from redis.observability.attributes import CSCReason
+
 
 class CacheEntryStatus(Enum):
     VALID = "VALID"
@@ -186,9 +190,6 @@ def set(self, entry: CacheEntry) -> bool:
         self._cache[entry.cache_key] = entry
         self._eviction_policy.touch(entry.cache_key)
 
-        if self._cache_config.is_exceeds_max_size(len(self._cache)):
-            self._eviction_policy.evict_next()
-
         return True
 
     def get(self, key: CacheKey) -> Union[CacheEntry, None]:
@@ -247,6 +248,63 @@ def flush(self) -> int:
 
     def is_cachable(self, key: CacheKey) -> bool:
         return self._cache_config.is_allowed_to_cache(key.command)
 
+
+class CacheProxy(CacheInterface):
+    """
+    Proxy that wraps a cache implementation to layer additional
+    logic (such as eviction events) on top of it.
+    """
+    def __init__(self, cache: CacheInterface, event_dispatcher: Optional[EventDispatcherInterface] = None):
+        self._cache = cache
+
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
+
+    @property
+    def collection(self) -> OrderedDict:
+        return self._cache.collection
+
+    @property
+    def config(self) -> CacheConfigurationInterface:
+        return self._cache.config
+
+    @property
+    def eviction_policy(self) -> EvictionPolicyInterface:
+        return self._cache.eviction_policy
+
+    @property
+    def size(self) -> int:
+        return self._cache.size
+
+    def get(self, key: CacheKey) -> Union[CacheEntry, None]:
+        return self._cache.get(key)
+
+    def set(self, entry: CacheEntry) -> bool:
+        is_set = self._cache.set(entry)
+
+        if self.config.is_exceeds_max_size(self.size):
+            self._event_dispatcher.dispatch(
+                OnCacheEvictionEvent(
+                    count=1,
+                    reason=CSCReason.FULL,
+                )
+            )
+            self.eviction_policy.evict_next()
+
+        return is_set
+
+    def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
+        return self._cache.delete_by_cache_keys(cache_keys)
+
+    def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
+        return self._cache.delete_by_redis_keys(redis_keys)
+
+    def flush(self) -> int:
+        return self._cache.flush()
+
+    def is_cachable(self, key: CacheKey) -> bool:
+        return self._cache.is_cachable(key)
+
 
 class LRUPolicy(EvictionPolicyInterface):
     def __init__(self):
@@ -422,4 +480,4 @@ def __init__(self, cache_config: Optional[CacheConfig] = None):
 
     def get_cache(self) -> CacheInterface:
         cache_class = self._config.get_cache_class()
-        return cache_class(cache_config=self._config)
+        return CacheProxy(cache_class(cache_config=self._config))
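Note: a minimal sketch of the proxy's contract under this patch (the wiring here is illustrative, not part of the diff) — reads and writes delegate transparently to the wrapped cache, and a set() that pushes the cache past max_size dispatches OnCacheEvictionEvent before evicting one entry:

    from redis.cache import CacheConfig, CacheProxy, DefaultCache
    from redis.event import EventDispatcher

    # Delegation: size, get, set, flush, etc. all read through to the
    # wrapped DefaultCache; only eviction bookkeeping is added on top.
    proxy = CacheProxy(DefaultCache(CacheConfig(max_size=2)), EventDispatcher())
    assert proxy.size == 0  # size is read through to the wrapped cache
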
diff --git a/redis/connection.py b/redis/connection.py
index 7c2df663b6..ec5ecf6b15 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -28,7 +28,8 @@
     CacheFactory,
     CacheFactoryInterface,
     CacheInterface,
     CacheKey,
+    CacheProxy,
 )
 from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
@@ -36,7 +36,8 @@
 from .backoff import NoBackoff
 from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
 from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
-    AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent
+    AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \
+    OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitializationEvent
 from .exceptions import (
     AuthenticationError,
     AuthenticationWrongNumberOfArgsError,
@@ -55,7 +56,7 @@
     MaintNotificationsPoolHandler,
     MaintenanceNotification,
 )
 from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \
-    DB_CLIENT_CONNECTION_POOL_NAME
+    DB_CLIENT_CONNECTION_POOL_NAME, CSCReason
 from .observability.metrics import CloseReason
 from .retry import Retry
 from .utils import (
@@ -1403,6 +1404,8 @@ def __init__(
         self.retry = self._conn.retry
         self.host = self._conn.host
         self.port = self._conn.port
+        self.db = self._conn.db
+        self._event_dispatcher = self._conn._event_dispatcher
         self.credential_provider = conn.credential_provider
         self._pool_lock = pool_lock
         self._cache = cache
@@ -1545,17 +1548,28 @@ def read_response(
     ):
         with self._cache_lock:
             # Check if command response exists in a cache and it's not in progress.
-            if (
-                self._current_command_cache_key is not None
-                and self._cache.get(self._current_command_cache_key) is not None
-                and self._cache.get(self._current_command_cache_key).status
-                != CacheEntryStatus.IN_PROGRESS
-            ):
-                res = copy.deepcopy(
-                    self._cache.get(self._current_command_cache_key).cache_value
-                )
-                self._current_command_cache_key = None
-                return res
+            if self._current_command_cache_key is not None:
+                entry = self._cache.get(self._current_command_cache_key)
+
+                if (
+                    entry is not None
+                    and entry.status != CacheEntryStatus.IN_PROGRESS
+                ):
+                    res = copy.deepcopy(entry.cache_value)
+                    self._current_command_cache_key = None
+                    self._event_dispatcher.dispatch(
+                        OnCacheHitEvent(
+                            bytes_saved=len(res),
+                            db_namespace=self.db,
+                        )
+                    )
+                    return res
+
+                self._event_dispatcher.dispatch(
+                    OnCacheMissEvent(
+                        db_namespace=self.db,
+                    )
+                )
 
         response = self._conn.read_response(
             disable_decoding=disable_decoding,
@@ -1719,7 +1733,17 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
         if data[1] is None:
             self._cache.flush()
         else:
-            self._cache.delete_by_redis_keys(data[1])
+            results = self._cache.delete_by_redis_keys(data[1])
+            deleted_count = sum(1 for was_deleted in results if was_deleted)
+
+            if deleted_count > 0:
+                self._event_dispatcher.dispatch(
+                    OnCacheEvictionEvent(
+                        count=deleted_count,
+                        reason=CSCReason.INVALIDATION,
+                        db_namespace=self.db,
+                    )
+                )
 
 
 class SSLConnection(Connection):
@@ -2530,6 +2553,10 @@ def __init__(
         self.cache = None
         self._cache_factory = cache_factory
 
+        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
+        if self._event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+
         if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
             if self._connection_kwargs.get("protocol") not in [3, "3"]:
                 raise RedisError("Client caching is only supported with RESP version 3")
@@ -2543,19 +2570,22 @@ def __init__(
                 self.cache = cache
             else:
                 if self._cache_factory is not None:
-                    self.cache = self._cache_factory.get_cache()
+                    self.cache = CacheProxy(self._cache_factory.get_cache(), self._event_dispatcher)
                 else:
                     self.cache = CacheFactory(
                         self._connection_kwargs.get("cache_config")
                     ).get_cache()
 
+            self._event_dispatcher.dispatch(
+                OnCacheInitializationEvent(
+                    cache_items_callback=lambda: self.cache.size,
+                    db_namespace=self._connection_kwargs.get("db"),
+                )
+            )
+
         connection_kwargs.pop("cache", None)
         connection_kwargs.pop("cache_config", None)
 
-        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
-        if self._event_dispatcher is None:
-            self._event_dispatcher = EventDispatcher()
-
         # a lock to protect the critical section in _checkpid().
         # this lock is acquired when the process id changes, such as
         # after a fork. during this time, multiple threads in the child
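Note: the read path above now emits exactly one hit or miss event per cached lookup. A hedged sketch of consuming those events (the logger class is hypothetical; register_listeners and dispatch are the dispatcher APIs exercised by the tests below):

    from redis.event import EventDispatcher, EventListenerInterface, OnCacheHitEvent

    class CacheHitLogger(EventListenerInterface):
        # Called synchronously by the dispatcher for every OnCacheHitEvent.
        def listen(self, event: OnCacheHitEvent):
            print(f"cache hit saved {event.bytes_saved} bytes (db={event.db_namespace})")

    dispatcher = EventDispatcher()
    dispatcher.register_listeners({OnCacheHitEvent: [CacheHitLogger()]})
    dispatcher.dispatch(OnCacheHitEvent(bytes_saved=3, db_namespace=0))
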
diff --git a/redis/event.py b/redis/event.py
index c8d07d78c1..3310efe425 100644
--- a/redis/event.py
+++ b/redis/event.py
@@ -8,11 +8,12 @@
 from redis.auth.token import TokenInterface
 from redis.credentials import CredentialProvider, StreamingCredentialProvider
-from redis.observability.attributes import PubSubDirection
+from redis.observability.attributes import PubSubDirection, CSCResult, CSCReason
 from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \
     record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \
     record_pubsub_message, record_streaming_lag, record_connection_wait_time, record_connection_use_time, \
-    record_connection_closed
+    record_connection_closed, record_csc_request, init_csc_items, record_csc_eviction, record_csc_network_saved, \
+    register_pools_connection_count, register_csc_items_callback
 from redis.utils import str_if_bytes
@@ -103,10 +104,6 @@ def __init__(
             AsyncAfterConnectionReleasedEvent: [
                 AsyncReAuthConnectionListener(),
             ],
-            OnErrorEvent: [ExportErrorCountMetric()],
-            OnMaintenanceNotificationEvent: [
-                ExportMaintenanceNotificationCountMetric(),
-            ],
             AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()],
             AfterConnectionTimeoutUpdatedEvent: [
                 ExportConnectionRelaxedTimeoutMetric(),
@@ -126,6 +123,14 @@ def __init__(
             OnStreamMessageReceivedEvent: [
                 ExportStreamingLagMetric(),
             ],
+            OnErrorEvent: [ExportErrorCountMetric()],
+            OnMaintenanceNotificationEvent: [
+                ExportMaintenanceNotificationCountMetric(),
+            ],
+            OnCacheInitializationEvent: [InitializeCSCItemsObservability()],
+            OnCacheEvictionEvent: [ExportCSCEvictionMetric()],
+            OnCacheHitEvent: [ExportCSCNetworkSavedMetric(), ExportCSCRequestMetric()],
+            OnCacheMissEvent: [ExportCSCRequestMetric()],
         }
 
         self._lock = threading.Lock()
@@ -420,6 +425,38 @@ class AfterConnectionClosedEvent:
     close_reason: "CloseReason"
     error: Optional[Exception] = None
 
+@dataclass
+class OnCacheHitEvent:
+    """
+    Event fired whenever a cache hit occurs.
+    """
+    bytes_saved: int
+    db_namespace: Optional[int] = None
+
+@dataclass
+class OnCacheMissEvent:
+    """
+    Event fired whenever a cache miss occurs.
+    """
+    db_namespace: Optional[int] = None
+
+@dataclass
+class OnCacheInitializationEvent:
+    """
+    Event fired after the cache is initialized.
+    """
+    cache_items_callback: Callable
+    db_namespace: Optional[int] = None
+
+@dataclass
+class OnCacheEvictionEvent:
+    """
+    Event fired after a cache eviction.
+    """
+    count: int
+    reason: CSCReason
+    db_namespace: Optional[int] = None
+
 
 class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
     pass
@@ -651,7 +688,11 @@ class InitializeConnectionCountObservability(EventListenerInterface):
     """
     Listener that initializes connection count observability.
     """
     def listen(self, event: AfterPooledConnectionsInstantiationEvent):
-        init_connection_count(event.connection_pools)
+        # Initialize the gauge only once; subsequent calls have no effect.
+        init_connection_count()
+
+        # Register pools for connection count observability.
+        register_pools_connection_count(event.connection_pools)
 
 class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface):
     """
@@ -747,3 +788,50 @@ def listen(self, event: AfterConnectionClosedEvent):
             close_reason=event.close_reason,
             error_type=event.error,
         )
+
+class ExportCSCRequestMetric(EventListenerInterface):
+    """
+    Listener that exports the CSC request metric.
+    """
+    def listen(self, event: Union[OnCacheHitEvent, OnCacheMissEvent]):
+        if isinstance(event, OnCacheHitEvent):
+            result = CSCResult.HIT
+        else:
+            result = CSCResult.MISS
+
+        record_csc_request(
+            db_namespace=event.db_namespace,
+            result=result,
+        )
+
+class InitializeCSCItemsObservability(EventListenerInterface):
+    """
+    Listener that initializes CSC items observability.
+    """
+    def listen(self, event: OnCacheInitializationEvent):
+        # Initialize the gauge only once; subsequent calls have no effect.
+        init_csc_items()
+
+        # Register the cache items callback for CSC items observability.
+        register_csc_items_callback(event.cache_items_callback, event.db_namespace)
+
+class ExportCSCEvictionMetric(EventListenerInterface):
+    """
+    Listener that exports the CSC eviction metric.
+    """
+    def listen(self, event: OnCacheEvictionEvent):
+        record_csc_eviction(
+            count=event.count,
+            reason=event.reason,
+            db_namespace=event.db_namespace,
+        )
+
+class ExportCSCNetworkSavedMetric(EventListenerInterface):
+    """
+    Listener that exports the CSC network-saved metric.
+    """
+    def listen(self, event: OnCacheHitEvent):
+        record_csc_network_saved(
+            bytes_saved=event.bytes_saved,
+            db_namespace=event.db_namespace,
+        )
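Note: because every EventDispatcher starts from the default listener map above, CSC events reach the metric exporters with no extra wiring. A small sketch:

    from redis.event import EventDispatcher, OnCacheMissEvent

    dispatcher = EventDispatcher()
    # Routed to ExportCSCRequestMetric by the default listener map,
    # which records a CSC request with result="miss".
    dispatcher.dispatch(OnCacheMissEvent(db_namespace=0))
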
diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py
index 9493ff73da..5bd392d86e 100644
--- a/redis/observability/attributes.py
+++ b/redis/observability/attributes.py
@@ -48,6 +48,8 @@
 REDIS_CLIENT_STREAM_NAME = "redis.client.stream.name"
 REDIS_CLIENT_CONSUMER_GROUP = "redis.client.consumer_group"
 REDIS_CLIENT_CONSUMER_NAME = "redis.client.consumer_name"
+REDIS_CLIENT_CSC_RESULT = "redis.client.csc.result"
+REDIS_CLIENT_CSC_REASON = "redis.client.csc.reason"
 
 class ConnectionState(Enum):
     IDLE = "idle"
@@ -57,6 +59,14 @@ class PubSubDirection(Enum):
     PUBLISH = "publish"
     RECEIVE = "receive"
 
+class CSCResult(Enum):
+    HIT = "hit"
+    MISS = "miss"
+
+class CSCReason(Enum):
+    FULL = "full"
+    INVALIDATION = "invalidation"
+
 
 class AttributeBuilder:
     """
@@ -274,6 +284,33 @@ def build_streaming_attributes(
 
         return attrs
 
+    @staticmethod
+    def build_csc_attributes(
+        db_namespace: Optional[int] = None,
+        result: Optional[CSCResult] = None,
+        reason: Optional[CSCReason] = None,
+    ) -> Dict[str, Any]:
+        """
+        Build attributes for a Client Side Caching (CSC) operation.
+
+        Args:
+            db_namespace: Redis database index
+            result: CSC result ('hit' or 'miss')
+            reason: Reason for a CSC eviction ('full' or 'invalidation')
+
+        Returns:
+            Dictionary of CSC attributes
+        """
+        attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes(db_namespace=db_namespace)
+
+        if result is not None:
+            attrs[REDIS_CLIENT_CSC_RESULT] = result.value
+
+        if reason is not None:
+            attrs[REDIS_CLIENT_CSC_REASON] = reason.value
+
+        return attrs
+
     @staticmethod
     def build_pool_name(
         server_address: str,
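Note: a quick sanity check of the new builder (the key name comes from the constants added above; build_base_attributes is assumed to contribute the db namespace attribute):

    from redis.observability.attributes import AttributeBuilder, CSCResult

    attrs = AttributeBuilder.build_csc_attributes(db_namespace=0, result=CSCResult.HIT)
    assert attrs["redis.client.csc.result"] == "hit"
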
diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py
index 65706aa29d..eacc72c9ce 100644
--- a/redis/observability/metrics.py
+++ b/redis/observability/metrics.py
@@ -11,7 +11,7 @@
 from typing import Any, Dict, Optional, Callable, List
 
 from redis.observability.attributes import AttributeBuilder, ConnectionState, REDIS_CLIENT_CONNECTION_NOTIFICATION, \
-    REDIS_CLIENT_CONNECTION_CLOSE_REASON, ERROR_TYPE, PubSubDirection
+    REDIS_CLIENT_CONNECTION_CLOSE_REASON, ERROR_TYPE, PubSubDirection, CSCResult, CSCReason
 from redis.observability.config import OTelConfig, MetricGroup
 
 logger = logging.getLogger(__name__)
@@ -93,6 +93,9 @@ def __init__(self, meter: Meter, config: OTelConfig):
         if MetricGroup.STREAMING in self.config.metric_groups:
             self._init_streaming_metrics()
 
+        if MetricGroup.CSC in self.config.metric_groups:
+            self._init_csc_metrics()
+
         logger.info("RedisMetricsCollector initialized")
 
     def _init_resiliency_metrics(self) -> None:
@@ -180,6 +183,26 @@ def _init_streaming_metrics(self) -> None:
             description="End-to-end lag per message, showing how stale are the messages when the application starts processing them."
         )
 
+    def _init_csc_metrics(self) -> None:
+        """Initialize Client Side Caching (CSC) metric instruments."""
+        self.csc_requests = self.meter.create_counter(
+            name="redis.client.csc.requests",
+            unit="{request}",
+            description="The total number of requests to the cache",
+        )
+
+        self.csc_evictions = self.meter.create_counter(
+            name="redis.client.csc.evictions",
+            unit="{eviction}",
+            description="The total number of cache evictions",
+        )
+
+        self.csc_network_saved = self.meter.create_counter(
+            name="redis.client.csc.network_saved",
+            unit="By",
+            description="The total number of bytes saved by using CSC",
+        )
+
     # Resiliency metric recording methods
 
     def record_error_count(
@@ -274,7 +297,8 @@ def init_connection_count(
         Args:
             callback: Callback function to retrieve connection count
         """
-        if not MetricGroup.CONNECTION_BASIC in self.config.metric_groups:
+        if MetricGroup.CONNECTION_BASIC not in self.config.metric_groups \
+                or getattr(self, "connection_count", None) is not None:
             return
 
         self.connection_count = self.meter.create_observable_gauge(
@@ -284,6 +308,27 @@ def init_connection_count(
             callbacks=[callback],
         )
 
+    def init_csc_items(
+        self,
+        callback: Callable,
+    ) -> None:
+        """
+        Initialize an observable gauge for the CSC items metric.
+
+        Args:
+            callback: Callback function to retrieve the CSC items count
+        """
+        if MetricGroup.CSC not in self.config.metric_groups \
+                or getattr(self, "csc_items", None) is not None:
+            return
+
+        self.csc_items = self.meter.create_observable_gauge(
+            name="redis.client.csc.items",
+            unit="{item}",
+            description="The total number of cached responses currently stored",
+            callbacks=[callback],
+        )
+
     def record_connection_timeout(self, pool_name: str) -> None:
         """
         Record a connection timeout event.
@@ -534,6 +579,64 @@ def record_streaming_lag(
         )
         self.stream_lag.record(lag_seconds, attributes=attrs)
 
+    # CSC metric recording methods
+
+    def record_csc_request(
+        self,
+        db_namespace: Optional[int] = None,
+        result: Optional[CSCResult] = None,
+    ) -> None:
+        """
+        Record a Client Side Caching (CSC) request.
+
+        Args:
+            db_namespace: Redis database index
+            result: CSC result ('hit' or 'miss')
+        """
+        if not hasattr(self, "csc_requests"):
+            return
+
+        attrs = self.attr_builder.build_csc_attributes(result=result, db_namespace=db_namespace)
+        self.csc_requests.add(1, attributes=attrs)
+
+    def record_csc_eviction(
+        self,
+        count: int,
+        db_namespace: Optional[int] = None,
+        reason: Optional[CSCReason] = None,
+    ) -> None:
+        """
+        Record a Client Side Caching (CSC) eviction.
+
+        Args:
+            count: Number of evictions
+            db_namespace: Redis database index
+            reason: Reason for the eviction
+        """
+        if not hasattr(self, "csc_evictions"):
+            return
+
+        attrs = self.attr_builder.build_csc_attributes(reason=reason, db_namespace=db_namespace)
+        self.csc_evictions.add(count, attributes=attrs)
+
+    def record_csc_network_saved(
+        self,
+        bytes_saved: int,
+        db_namespace: Optional[int] = None,
+    ) -> None:
+        """
+        Record the number of bytes saved by using Client Side Caching (CSC).
+
+        Args:
+            bytes_saved: Number of bytes saved
+            db_namespace: Redis database index
+        """
+        if not hasattr(self, "csc_network_saved"):
+            return
+
+        attrs = self.attr_builder.build_csc_attributes(db_namespace=db_namespace)
+        self.csc_network_saved.add(bytes_saved, attributes=attrs)
+
     # Utility methods
 
     @staticmethod
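Note: with the corrected guards, the gauge initializers are idempotent. A sketch with a stub meter (assumes the OpenTelemetry API is importable so the collector constructs):

    from unittest.mock import MagicMock

    from redis.observability.config import MetricGroup, OTelConfig
    from redis.observability.metrics import RedisMetricsCollector

    collector = RedisMetricsCollector(MagicMock(), OTelConfig(metric_groups=[MetricGroup.CSC]))
    collector.init_csc_items(callback=lambda options: [])
    collector.init_csc_items(callback=lambda options: [])  # second call is a no-op
    assert collector.meter.create_observable_gauge.call_count == 1
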
""" global _metrics_collector @@ -141,23 +142,48 @@ def init_connection_count( if _metrics_collector is None: return - # Lazy import - from opentelemetry.metrics import Observation - - def connection_count_callback(__): + def observable_callback(__): + observables_registry = get_observables_registry_instance() + callbacks = observables_registry.get(CONNECTION_COUNT_REGISTRY_KEY) observations = [] - for pool in connection_pools: - for count, attributes in pool.get_connection_count(): - observations.append(Observation(count, attributes)) + + for callback in callbacks: + observations.extend(callback()) + return observations # try: _metrics_collector.init_connection_count( - callback=connection_count_callback, + callback=observable_callback, ) # except Exception: # pass +def register_pools_connection_count( + connection_pools: List["ConnectionPoolInterface"], +) -> None: + """ + Add connection pools to connection count observable registry. + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # Lazy import + from opentelemetry.metrics import Observation + + def connection_count_callback(): + observations = [] + for connection_pool in connection_pools: + for count, attributes in connection_pool.get_connection_count(): + observations.append(Observation(count, attributes=attributes)) + return observations + + observables_registry = get_observables_registry_instance() + observables_registry.register(CONNECTION_COUNT_REGISTRY_KEY, connection_count_callback) def record_connection_timeout( pool_name: str, @@ -486,6 +512,125 @@ def record_maint_notification_count( # except Exception: # pass +def record_csc_request( + db_namespace: Optional[int] = None, + result: Optional[CSCResult] = None, +): + """ + Record a Client Side Caching (CSC) request. + + Args: + db_namespace: Redis database index + result: CSC result ('hit' or 'miss') + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_request( + db_namespace=db_namespace, + result=result, + ) + +def init_csc_items() -> None: + """ + Initialize observable gauge for CSC items metric. + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + def observable_callback(__): + observables_registry = get_observables_registry_instance() + callbacks = observables_registry.get(CSC_ITEMS_REGISTRY_KEY) + observations = [] + + for callback in callbacks: + observations.extend(callback()) + + return observations + + _metrics_collector.init_csc_items( + callback=observable_callback, + ) + +def register_csc_items_callback( + callback: Callable, + db_namespace: Optional[int] = None, +) -> None: + """ + Adds given callback to CSC items observable registry. 
+ """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # Lazy import + from opentelemetry.metrics import Observation + + def csc_items_callback(): + return [Observation(callback(), attributes=AttributeBuilder.build_csc_attributes(db_namespace=db_namespace))] + + observables_registry = get_observables_registry_instance() + observables_registry.register(CSC_ITEMS_REGISTRY_KEY, csc_items_callback) + +def record_csc_eviction( + count: int, + db_namespace: Optional[int] = None, + reason: Optional[CSCReason] = None, +) -> None: + """ + Record a Client Side Caching (CSC) eviction. + + Args: + count: Number of evictions + db_namespace: Redis database index + reason: Reason for eviction + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_eviction( + count=count, + db_namespace=db_namespace, + reason=reason, + ) + +def record_csc_network_saved( + bytes_saved: int, + db_namespace: Optional[int] = None, +) -> None: + """ + Record the number of bytes saved by using Client Side Caching (CSC). + + Args: + bytes_saved: Number of bytes saved + db_namespace: Redis database index + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_network_saved( + bytes_saved=bytes_saved, + db_namespace=db_namespace, + ) def _get_or_create_collector() -> Optional[RedisMetricsCollector]: """ diff --git a/redis/observability/registry.py b/redis/observability/registry.py new file mode 100644 index 0000000000..4bdc56ff54 --- /dev/null +++ b/redis/observability/registry.py @@ -0,0 +1,64 @@ +import threading +from typing import Dict, List, Callable, Optional, Any + +from opentelemetry.metrics import Observation + + +class ObservablesRegistry: + """ + Global registry for storing callbacks for observable metrics. + """ + def __init__(self, registry: Dict[str, List[Callable[[], List[Observation]]]] = None): + self._registry = registry or {} + self._lock = threading.Lock() + + def register(self, name: str, callback: Callable[[], List[Observation]]) -> None: + """ + Register a callback for an observable metric. + """ + with self._lock: + self._registry.setdefault(name, []).append(callback) + + def get(self, name: str) -> List[Callable[[], List[Observation]]]: + """ + Get all callbacks for an observable metric. + """ + with self._lock: + return self._registry.get(name, []) + + def clear(self) -> None: + """ + Clear the registry. + """ + with self._lock: + self._registry.clear() + + def __len__(self) -> int: + """ + Get the number of registered callbacks. + """ + return len(self._registry) + +# Global singleton instance +_observables_registry_instance: Optional[ObservablesRegistry] = None + +def get_observables_registry_instance() -> ObservablesRegistry: + """ + Get the global observables registry singleton instance. + + This is the Pythonic way to get the singleton instance. 
diff --git a/redis/observability/registry.py b/redis/observability/registry.py
new file mode 100644
index 0000000000..4bdc56ff54
--- /dev/null
+++ b/redis/observability/registry.py
@@ -0,0 +1,64 @@
+import threading
+from typing import TYPE_CHECKING, Dict, List, Callable, Optional
+
+if TYPE_CHECKING:
+    from opentelemetry.metrics import Observation
+
+
+class ObservablesRegistry:
+    """
+    Global registry for storing callbacks for observable metrics.
+    """
+    def __init__(self, registry: Optional[Dict[str, List[Callable[[], List["Observation"]]]]] = None):
+        self._registry = registry or {}
+        self._lock = threading.Lock()
+
+    def register(self, name: str, callback: Callable[[], List["Observation"]]) -> None:
+        """
+        Register a callback for an observable metric.
+        """
+        with self._lock:
+            self._registry.setdefault(name, []).append(callback)
+
+    def get(self, name: str) -> List[Callable[[], List["Observation"]]]:
+        """
+        Get all callbacks for an observable metric.
+        """
+        with self._lock:
+            return self._registry.get(name, [])
+
+    def clear(self) -> None:
+        """
+        Clear the registry.
+        """
+        with self._lock:
+            self._registry.clear()
+
+    def __len__(self) -> int:
+        """
+        Get the number of metric names that have registered callbacks.
+        """
+        return len(self._registry)
+
+
+# Global singleton instance
+_observables_registry_instance: Optional[ObservablesRegistry] = None
+
+
+def get_observables_registry_instance() -> ObservablesRegistry:
+    """
+    Get the global observables registry singleton instance.
+
+    Returns:
+        The global ObservablesRegistry singleton
+
+    Example:
+        >>> registry = get_observables_registry_instance()
+        >>> registry.register('my_metric', my_callback)
+    """
+    global _observables_registry_instance
+
+    if _observables_registry_instance is None:
+        _observables_registry_instance = ObservablesRegistry()
+
+    return _observables_registry_instance
\ No newline at end of file
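Note: the registry is deliberately process-global, so test code (as below) clears it between runs. A small hedged sketch of that pattern:

    from redis.observability.registry import get_observables_registry_instance

    registry = get_observables_registry_instance()
    registry.register("csc_items", lambda: [])  # callbacks return lists of Observations
    assert len(registry.get("csc_items")) == 1
    registry.clear()  # tests clear the global registry to stay isolated
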
CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + + def test_initialization_emits_cache_initialisation_event(self, mock_cache): + """Test that CacheProxy emits OnCacheInitialisationEvent on initialization.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheInitializationEvent: [listener], + }) + + CacheProxy(mock_cache, event_dispatcher) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, OnCacheInitializationEvent) + assert callable(event.cache_items_callback) + + def test_initialization_event_callback_returns_cache_size( + self, mock_cache, mock_connection + ): + """Test that the cache_items_callback returns the current cache size.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheInitializationEvent: [listener], + }) + + proxy = CacheProxy(mock_cache, event_dispatcher) + + event = listener.listen.call_args[0][0] + assert event.cache_items_callback() == 0 + + # Add an entry and verify callback reflects new size + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert event.cache_items_callback() == 1 + + def test_initialization_creates_default_event_dispatcher_when_none_provided( + self, mock_cache + ): + """Test that CacheProxy creates a default EventDispatcher when none is provided.""" + # Should not raise an error + proxy = CacheProxy(mock_cache) + assert proxy is not None + + def test_set_emits_eviction_event_when_cache_exceeds_max_size( + self, mock_connection + ): + """Test that OnCacheEvictionEvent is emitted when cache exceeds max size.""" + # Create a cache with max_size=2 + cache = DefaultCache(CacheConfig(max_size=2)) + event_dispatcher = EventDispatcher() + eviction_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheEvictionEvent: [eviction_listener], + }) + + proxy = CacheProxy(cache, event_dispatcher) + + # Add 2 entries (at max capacity) + for i in range(2): + cache_key = CacheKey( + command="GET", redis_keys=(f"key{i}",), redis_args=("GET", f"key{i}") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=f"value{i}".encode(), + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # No eviction event yet + eviction_listener.listen.assert_not_called() + + # Add a 3rd entry, which should trigger eviction + cache_key = CacheKey( + command="GET", redis_keys=("key3",), redis_args=("GET", "key3") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"value3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # Eviction event should be emitted + eviction_listener.listen.assert_called_once() + event = eviction_listener.listen.call_args[0][0] + assert isinstance(event, OnCacheEvictionEvent) + assert event.count == 1 + assert event.reason == CSCReason.FULL + + def test_set_does_not_emit_eviction_event_when_under_max_size( + self, mock_cache, mock_connection + ): + """Test that OnCacheEvictionEvent is NOT emitted when cache is under max size.""" + event_dispatcher = EventDispatcher() + eviction_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheEvictionEvent: 
[eviction_listener], + }) + + proxy = CacheProxy(mock_cache, event_dispatcher) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + eviction_listener.listen.assert_not_called() + + def test_collection_property_delegates_to_underlying_cache(self, mock_cache): + """Test that collection property returns the underlying cache's collection.""" + proxy = CacheProxy(mock_cache) + assert proxy.collection is mock_cache.collection + + def test_config_property_delegates_to_underlying_cache(self, mock_cache): + """Test that config property returns the underlying cache's config.""" + proxy = CacheProxy(mock_cache) + assert proxy.config is mock_cache.config + + def test_eviction_policy_property_delegates_to_underlying_cache(self, mock_cache): + """Test that eviction_policy property returns the underlying cache's eviction_policy.""" + proxy = CacheProxy(mock_cache) + assert proxy.eviction_policy is mock_cache.eviction_policy + + def test_size_property_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that size property returns the underlying cache's size.""" + proxy = CacheProxy(mock_cache) + assert proxy.size == 0 + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert proxy.size == 1 + + def test_get_delegates_to_underlying_cache(self, mock_cache, mock_connection): + """Test that get method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + entry = CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + proxy.set(entry) + + result = proxy.get(cache_key) + assert result is not None + assert result.cache_value == b"bar" + + def test_delete_by_cache_keys_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that delete_by_cache_keys method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + result = proxy.delete_by_cache_keys([cache_key]) + assert result == [True] + assert proxy.get(cache_key) is None + + def test_delete_by_redis_keys_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that delete_by_redis_keys method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=(b"foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + result = proxy.delete_by_redis_keys([b"foo"]) + assert result == [True] + assert proxy.get(cache_key) is None + + def test_flush_delegates_to_underlying_cache(self, mock_cache, mock_connection): + """Test that flush method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + for i in range(3): + cache_key = CacheKey( + command="GET", redis_keys=(f"key{i}",), redis_args=("GET", 
f"key{i}") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=f"value{i}".encode(), + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert proxy.size == 3 + result = proxy.flush() + assert result == 3 + assert proxy.size == 0 + + def test_is_cachable_delegates_to_underlying_cache(self, mock_cache): + """Test that is_cachable method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + # GET is cachable by default + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=()) + assert proxy.is_cachable(cache_key) is True + + # SET is not cachable + cache_key = CacheKey(command="SET", redis_keys=("foo",), redis_args=()) + assert proxy.is_cachable(cache_key) is False diff --git a/tests/test_connection.py b/tests/test_connection.py index 441528ca6b..735ef19d27 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,7 +7,7 @@ from errno import ECONNREFUSED from typing import Any from unittest import mock -from unittest.mock import call, patch +from unittest.mock import call, patch, MagicMock, Mock import pytest import redis @@ -28,10 +28,17 @@ Connection, SSLConnection, UnixDomainSocketConnection, - parse_url, + parse_url, BlockingConnectionPool, ) from redis.credentials import UsernamePasswordCredentialProvider +from redis.event import ( + EventDispatcher, + EventListenerInterface, + OnCacheHitEvent, + OnCacheMissEvent, +) from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError +from redis.observability.attributes import DB_CLIENT_CONNECTION_POOL_NAME, DB_CLIENT_CONNECTION_STATE, ConnectionState from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -446,7 +453,9 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = EventDispatcher() proxy_connection = CacheProxyConnection( mock_connection, cache, threading.RLock() @@ -463,7 +472,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = EventDispatcher() mock_cache.is_cachable.return_value = True mock_cache.get.side_effect = [ @@ -573,7 +584,9 @@ def test_triggers_invalidation_processing_on_another_connection( mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = Mock(spec=EventDispatcher) another_conn = copy.deepcopy(mock_connection) another_conn.can_read.side_effect = [True, False] @@ -598,3 +611,364 @@ def test_triggers_invalidation_processing_on_another_connection( assert proxy_connection.read_response() == b"bar" assert another_conn.can_read.call_count == 2 another_conn.read_response.assert_called_once() + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_cache_hit_event_emitted_on_cached_response(self, mock_connection): + """Test that OnCacheHitEvent is emitted when returning a cached response.""" + cache = 
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 441528ca6b..735ef19d27 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -7,7 +7,7 @@
 from errno import ECONNREFUSED
 from typing import Any
 from unittest import mock
-from unittest.mock import call, patch
+from unittest.mock import call, patch, MagicMock, Mock
 
 import pytest
 
 import redis
@@ -28,10 +28,18 @@
     Connection,
     SSLConnection,
     UnixDomainSocketConnection,
     parse_url,
+    BlockingConnectionPool,
 )
 from redis.credentials import UsernamePasswordCredentialProvider
+from redis.event import (
+    EventDispatcher,
+    EventListenerInterface,
+    OnCacheHitEvent,
+    OnCacheMissEvent,
+)
 from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError
+from redis.observability.attributes import DB_CLIENT_CONNECTION_POOL_NAME, DB_CLIENT_CONNECTION_STATE, ConnectionState
 from redis.retry import Retry
 from redis.utils import HIREDIS_AVAILABLE
@@ -446,7 +453,9 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf):
         mock_connection.retry = "mock"
         mock_connection.host = "mock"
         mock_connection.port = "mock"
+        mock_connection.db = 0
         mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection._event_dispatcher = EventDispatcher()
 
         proxy_connection = CacheProxyConnection(
             mock_connection, cache, threading.RLock()
@@ -463,7 +472,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
         mock_connection.retry = "mock"
         mock_connection.host = "mock"
         mock_connection.port = "mock"
+        mock_connection.db = 0
         mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection._event_dispatcher = EventDispatcher()
 
         mock_cache.is_cachable.return_value = True
         mock_cache.get.side_effect = [
@@ -573,7 +584,9 @@ def test_triggers_invalidation_processing_on_another_connection(
         mock_connection.retry = "mock"
         mock_connection.host = "mock"
         mock_connection.port = "mock"
+        mock_connection.db = 0
         mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection._event_dispatcher = Mock(spec=EventDispatcher)
 
         another_conn = copy.deepcopy(mock_connection)
         another_conn.can_read.side_effect = [True, False]
@@ -598,3 +611,364 @@ def test_triggers_invalidation_processing_on_another_connection(
         assert proxy_connection.read_response() == b"bar"
         assert another_conn.can_read.call_count == 2
         another_conn.read_response.assert_called_once()
+
+    @pytest.mark.skipif(
+        platform.python_implementation() == "PyPy",
+        reason="Pypy doesn't support side_effect",
+    )
+    def test_cache_hit_event_emitted_on_cached_response(self, mock_connection):
+        """Test that OnCacheHitEvent is emitted when returning a cached response."""
+        cache = DefaultCache(CacheConfig(max_size=10))
+        event_dispatcher = EventDispatcher()
+        cache_hit_listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            OnCacheHitEvent: [cache_hit_listener],
+        })
+
+        mock_connection.retry = "mock"
+        mock_connection.host = "mock"
+        mock_connection.port = "mock"
+        mock_connection.db = 0
+        mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection.can_read.return_value = False
+        mock_connection._event_dispatcher = event_dispatcher
+
+        cache_key = CacheKey(
+            command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
+        )
+        cache.set(
+            CacheEntry(
+                cache_key=cache_key,
+                cache_value=b"bar",
+                status=CacheEntryStatus.VALID,
+                connection_ref=mock_connection,
+            )
+        )
+
+        proxy_connection = CacheProxyConnection(
+            mock_connection, cache, threading.RLock()
+        )
+        # Manually set the cache key to simulate send_command having been called
+        proxy_connection._current_command_cache_key = cache_key
+
+        result = proxy_connection.read_response()
+
+        assert result == b"bar"
+        cache_hit_listener.listen.assert_called_once()
+        event = cache_hit_listener.listen.call_args[0][0]
+        assert isinstance(event, OnCacheHitEvent)
+        assert event.bytes_saved == len(b"bar")
+        assert event.db_namespace == 0
+
+    @pytest.mark.skipif(
+        platform.python_implementation() == "PyPy",
+        reason="Pypy doesn't support side_effect",
+    )
+    def test_cache_miss_event_emitted_on_uncached_response(self, mock_connection):
+        """Test that OnCacheMissEvent is emitted when a cache miss occurs."""
+        cache = DefaultCache(CacheConfig(max_size=10))
+        event_dispatcher = EventDispatcher()
+        cache_miss_listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            OnCacheMissEvent: [cache_miss_listener],
+        })
+
+        mock_connection.retry = "mock"
+        mock_connection.host = "mock"
+        mock_connection.port = "mock"
+        mock_connection.db = 0
+        mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection.can_read.return_value = False
+        mock_connection.send_command.return_value = None
+        mock_connection.read_response.return_value = b"bar"
+        mock_connection._event_dispatcher = event_dispatcher
+
+        proxy_connection = CacheProxyConnection(
+            mock_connection, cache, threading.RLock()
+        )
+        proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
+        result = proxy_connection.read_response()
+
+        assert result == b"bar"
+        cache_miss_listener.listen.assert_called_once()
+        event = cache_miss_listener.listen.call_args[0][0]
+        assert isinstance(event, OnCacheMissEvent)
+        assert event.db_namespace == 0
+
+    @pytest.mark.skipif(
+        platform.python_implementation() == "PyPy",
+        reason="Pypy doesn't support side_effect",
+    )
+    def test_cache_miss_event_not_emitted_for_non_cachable_command(self, mock_connection):
+        """Test that OnCacheMissEvent is NOT emitted for non-cachable commands."""
+        cache = DefaultCache(CacheConfig(max_size=10))
+        event_dispatcher = EventDispatcher()
+        cache_miss_listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            OnCacheMissEvent: [cache_miss_listener],
+        })
+
+        mock_connection.retry = "mock"
+        mock_connection.host = "mock"
+        mock_connection.port = "mock"
+        mock_connection.db = 0
+        mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection.can_read.return_value = False
+        mock_connection.send_command.return_value = None
+        mock_connection.read_response.return_value = b"OK"
+        mock_connection._event_dispatcher = event_dispatcher
+
+        proxy_connection = CacheProxyConnection(
+            mock_connection, cache, threading.RLock()
+        )
+        # SET is not cachable
+        proxy_connection.send_command(*["SET", "foo", "bar"])
+        result = proxy_connection.read_response()
+
+        assert result == b"OK"
+        cache_miss_listener.listen.assert_not_called()
+
+    @pytest.mark.skipif(
+        platform.python_implementation() == "PyPy",
+        reason="Pypy doesn't support side_effect",
+    )
+    def test_cache_hit_not_emitted_for_in_progress_entry(self, mock_connection):
+        """Test that OnCacheHitEvent is NOT emitted when cache entry is IN_PROGRESS."""
+        cache = DefaultCache(CacheConfig(max_size=10))
+        event_dispatcher = EventDispatcher()
+        cache_hit_listener = MagicMock(spec=EventListenerInterface)
+        cache_miss_listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            OnCacheHitEvent: [cache_hit_listener],
+            OnCacheMissEvent: [cache_miss_listener],
+        })
+
+        mock_connection.retry = "mock"
+        mock_connection.host = "mock"
+        mock_connection.port = "mock"
+        mock_connection.db = 0
+        mock_connection.credential_provider = UsernamePasswordCredentialProvider()
+        mock_connection.can_read.return_value = False
+        mock_connection.read_response.return_value = b"bar"
+        mock_connection._event_dispatcher = event_dispatcher
+
+        cache_key = CacheKey(
+            command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
+        )
+        # Set entry with IN_PROGRESS status
+        cache.set(
+            CacheEntry(
+                cache_key=cache_key,
+                cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
+                status=CacheEntryStatus.IN_PROGRESS,
+                connection_ref=mock_connection,
+            )
+        )
+
+        proxy_connection = CacheProxyConnection(
+            mock_connection, cache, threading.RLock()
+        )
+        proxy_connection._current_command_cache_key = cache_key
+
+        result = proxy_connection.read_response()
+
+        assert result == b"bar"
+        # Cache hit should NOT be emitted for IN_PROGRESS entry
+        cache_hit_listener.listen.assert_not_called()
+        # Cache miss should be emitted instead
+        cache_miss_listener.listen.assert_called_once()
+
+
+class TestConnectionPoolGetConnectionCount:
+    """Tests for ConnectionPool.get_connection_count() method."""
+
+    def test_get_connection_count_returns_idle_and_used_counts(self):
+        """Test that get_connection_count returns both idle and used connection counts."""
+        pool = ConnectionPool(max_connections=10)
+
+        # Initially, no connections exist
+        counts = pool.get_connection_count()
+        assert len(counts) == 2
+
+        # Check idle connections count
+        idle_count, idle_attrs = counts[0]
+        assert idle_count == 0
+        assert DB_CLIENT_CONNECTION_POOL_NAME in idle_attrs
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+
+        # Check used connections count
+        used_count, used_attrs = counts[1]
+        assert used_count == 0
+        assert DB_CLIENT_CONNECTION_POOL_NAME in used_attrs
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_with_connections_in_use(self):
+        """Test get_connection_count when connections are in use."""
+        pool = ConnectionPool(max_connections=10)
+
+        # Create mock connections
+        mock_conn1 = MagicMock()
+        mock_conn1.pid = pool.pid
+
+        mock_conn2 = MagicMock()
+        mock_conn2.pid = pool.pid
+
+        # Simulate connections in use
+        pool._in_use_connections.add(mock_conn1)
+        pool._in_use_connections.add(mock_conn2)
+
+        counts = pool.get_connection_count()
+
+        idle_count, idle_attrs = counts[0]
+        used_count, used_attrs = counts[1]
+
+        assert idle_count == 0
+        assert used_count == 2
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_with_available_connections(self):
+        """Test get_connection_count when connections are available (idle)."""
+        pool = ConnectionPool(max_connections=10)
+
+        # Create mock connections
+        mock_conn1 = MagicMock()
+        mock_conn1.pid = pool.pid
+
+        mock_conn2 = MagicMock()
+        mock_conn2.pid = pool.pid
+
+        mock_conn3 = MagicMock()
+        mock_conn3.pid = pool.pid
+
+        # Simulate available connections
+        pool._available_connections.append(mock_conn1)
+        pool._available_connections.append(mock_conn2)
+        pool._available_connections.append(mock_conn3)
+
+        counts = pool.get_connection_count()
+
+        idle_count, idle_attrs = counts[0]
+        used_count, used_attrs = counts[1]
+
+        assert idle_count == 3
+        assert used_count == 0
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_mixed_connections(self):
+        """Test get_connection_count with both idle and used connections."""
+        pool = ConnectionPool(max_connections=10)
+
+        # Create mock connections
+        mock_idle = MagicMock()
+        mock_idle.pid = pool.pid
+
+        mock_used1 = MagicMock()
+        mock_used1.pid = pool.pid
+
+        mock_used2 = MagicMock()
+        mock_used2.pid = pool.pid
+
+        # Simulate mixed state
+        pool._available_connections.append(mock_idle)
+        pool._in_use_connections.add(mock_used1)
+        pool._in_use_connections.add(mock_used2)
+
+        counts = pool.get_connection_count()
+
+        idle_count, _ = counts[0]
+        used_count, _ = counts[1]
+
+        assert idle_count == 1
+        assert used_count == 2
+
+        pool.disconnect()
+
+    def test_get_connection_count_includes_pool_name_in_attributes(self):
+        """Test that get_connection_count includes pool name in attributes."""
+        pool = ConnectionPool(max_connections=10)
+
+        counts = pool.get_connection_count()
+
+        idle_count, idle_attrs = counts[0]
+        used_count, used_attrs = counts[1]
+
+        # Both should have the pool name
+        assert DB_CLIENT_CONNECTION_POOL_NAME in idle_attrs
+        assert DB_CLIENT_CONNECTION_POOL_NAME in used_attrs
+
+        # Pool name should be the repr of the pool
+        assert repr(pool) in idle_attrs[DB_CLIENT_CONNECTION_POOL_NAME]
+        assert repr(pool) in used_attrs[DB_CLIENT_CONNECTION_POOL_NAME]
+
+        pool.disconnect()
+
+
+class TestBlockingConnectionPoolGetConnectionCount:
+    """Tests for BlockingConnectionPool.get_connection_count() method."""
+
+    def test_get_connection_count_returns_idle_and_used_counts(self):
+        """Test that BlockingConnectionPool.get_connection_count returns both counts."""
+        pool = BlockingConnectionPool(max_connections=10)
+
+        # Initially, no connections exist
+        counts = pool.get_connection_count()
+        assert len(counts) == 2
+
+        idle_count, idle_attrs = counts[0]
+        used_count, used_attrs = counts[1]
+
+        assert idle_count == 0
+        assert used_count == 0
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_with_connections_in_queue(self):
+        """Test get_connection_count when connections are in the queue (idle)."""
+        pool = BlockingConnectionPool(max_connections=10)
+
+        # Create mock connections and add to queue
+        mock_conn1 = MagicMock()
+        mock_conn1.pid = pool.pid
+
+        mock_conn2 = MagicMock()
+        mock_conn2.pid = pool.pid
+
+        # Add connections to the pool's internal list and queue
+        pool._connections.append(mock_conn1)
+        pool._connections.append(mock_conn2)
+
+        # Clear the queue and add our connections
+        while not pool.pool.empty():
+            try:
+                pool.pool.get_nowait()
+            except Exception:
+                break
+
+        pool.pool.put_nowait(mock_conn1)
+        pool.pool.put_nowait(mock_conn2)
+
+        counts = pool.get_connection_count()
+
+        idle_count, _ = counts[0]
+        used_count, _ = counts[1]
+
+        assert idle_count == 2
+        assert used_count == 0
+
+        pool.disconnect()
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index b23a505711..e24f3112a5 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -19,7 +19,7 @@
     EventDispatcher,
     EventListenerInterface,
 )
-from redis.observability.metrics import CloseReason
+from redis.connection import CloseReason
 from redis.utils import SSL_AVAILABLE
 
 from .conftest import (
that returns our mock observable gauge.""" + meter = MagicMock() + meter.create_observable_gauge.return_value = mock_observable_gauge + return meter + + @pytest.fixture + def mock_config_with_connection_basic(self): + """Create a config with CONNECTION_BASIC enabled.""" + return OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC]) + + @pytest.fixture + def setup_connection_count_recorder( + self, mock_meter_with_observable, mock_config_with_connection_basic + ): + """Setup recorder with mocked meter for connection count tests.""" + recorder.reset_collector() + get_observables_registry_instance().clear() + + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector( + mock_meter_with_observable, mock_config_with_connection_basic + ) + + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + yield mock_meter_with_observable + + recorder.reset_collector() + get_observables_registry_instance().clear() + + def test_init_connection_count_creates_observable_gauge( + self, setup_connection_count_recorder + ): + """Test that init_connection_count creates an observable gauge.""" + mock_meter = setup_connection_count_recorder + + recorder.init_connection_count() + + mock_meter.create_observable_gauge.assert_called_once() + call_kwargs = mock_meter.create_observable_gauge.call_args[1] + assert call_kwargs['name'] == 'db.client.connection.count' + + def test_init_connection_count_callback_aggregates_registry_callbacks( + self, setup_connection_count_recorder + ): + """Test that the observable callback aggregates all registered pool callbacks.""" + mock_meter = setup_connection_count_recorder + + recorder.init_connection_count() + + # Get the callback that was passed to create_observable_gauge + call_args = mock_meter.create_observable_gauge.call_args + observable_callback = call_args[1]['callbacks'][0] + + # Register some mock pool callbacks + from opentelemetry.metrics import Observation + + mock_observation1 = Observation(5, attributes={"pool": "pool1"}) + mock_observation2 = Observation(3, attributes={"pool": "pool2"}) + + pool_callback1 = MagicMock(return_value=[mock_observation1]) + pool_callback2 = MagicMock(return_value=[mock_observation2]) + + registry = get_observables_registry_instance() + registry.register(recorder.CONNECTION_COUNT_REGISTRY_KEY, pool_callback1) + registry.register(recorder.CONNECTION_COUNT_REGISTRY_KEY, pool_callback2) + + # Call the observable callback + observations = observable_callback(None) + + # Verify both pool callbacks were called and observations aggregated + pool_callback1.assert_called_once() + pool_callback2.assert_called_once() + assert len(observations) == 2 + assert mock_observation1 in observations + assert mock_observation2 in observations + + def test_register_pools_connection_count_adds_callback_to_registry( + self, setup_connection_count_recorder + ): + """Test that register_pools_connection_count adds a callback to the registry.""" + # Create mock connection pools + mock_pool1 = MagicMock() + mock_pool1.get_connection_count.return_value = [ + (5, {"state": "idle"}), + (2, {"state": "used"}), + ] + + mock_pool2 = MagicMock() + mock_pool2.get_connection_count.return_value = [ + (3, {"state": "idle"}), + ] + + recorder.register_pools_connection_count([mock_pool1, mock_pool2]) + + registry = get_observables_registry_instance() + callbacks = registry.get(recorder.CONNECTION_COUNT_REGISTRY_KEY) + + assert len(callbacks) == 1 + + # Call the registered callback and verify it returns 
+
+
+class TestInitCSCItems:
+    """Tests for init_csc_items and register_csc_items_callback."""
+
+    @pytest.fixture
+    def mock_observable_gauge(self):
+        """Create a mock observable gauge."""
+        return MagicMock()
+
+    @pytest.fixture
+    def mock_meter_with_observable(self, mock_observable_gauge):
+        """Create a mock meter that returns our mock observable gauge."""
+        meter = MagicMock()
+        meter.create_observable_gauge.return_value = mock_observable_gauge
+        return meter
+
+    @pytest.fixture
+    def mock_config_with_csc(self):
+        """Create a config with the CSC metric group enabled."""
+        return OTelConfig(metric_groups=[MetricGroup.CSC])
+
+    @pytest.fixture
+    def setup_csc_recorder(self, mock_meter_with_observable, mock_config_with_csc):
+        """Set up the recorder with a mocked meter for CSC tests."""
+        recorder.reset_collector()
+        get_observables_registry_instance().clear()
+
+        with patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(
+                mock_meter_with_observable, mock_config_with_csc
+            )
+
+        with patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            yield mock_meter_with_observable
+
+        recorder.reset_collector()
+        get_observables_registry_instance().clear()
+
+    def test_init_csc_items_creates_observable_gauge(self, setup_csc_recorder):
+        """Test that init_csc_items creates an observable gauge."""
+        mock_meter = setup_csc_recorder
+
+        recorder.init_csc_items()
+
+        mock_meter.create_observable_gauge.assert_called_once()
+        call_args = mock_meter.create_observable_gauge.call_args
+        assert call_args[1]['name'] == 'redis.client.csc.items'
+
+    def test_init_csc_items_callback_aggregates_registry_callbacks(
+        self, setup_csc_recorder
+    ):
+        """Test that the CSC observable callback aggregates all registered callbacks."""
+        mock_meter = setup_csc_recorder
+
+        recorder.init_csc_items()
+
+        # Get the callback that was passed to create_observable_gauge
+        call_args = mock_meter.create_observable_gauge.call_args
+        observable_callback = call_args[1]['callbacks'][0]
+
+        # Register some mock CSC callbacks
+        from opentelemetry.metrics import Observation
+
+        mock_observation1 = Observation(100, attributes={"db": 0})
+        mock_observation2 = Observation(50, attributes={"db": 1})
+
+        csc_callback1 = MagicMock(return_value=[mock_observation1])
+        csc_callback2 = MagicMock(return_value=[mock_observation2])
+
+        registry = get_observables_registry_instance()
+        registry.register(recorder.CSC_ITEMS_REGISTRY_KEY, csc_callback1)
+        registry.register(recorder.CSC_ITEMS_REGISTRY_KEY, csc_callback2)
+
+        # Call the observable callback
+        observations = observable_callback(None)
+
+        # Verify both callbacks were called and the observations were aggregated
+        csc_callback1.assert_called_once()
+        csc_callback2.assert_called_once()
+        assert len(observations) == 2
+        assert mock_observation1 in observations
+        assert mock_observation2 in observations
+
+    def test_register_csc_items_callback_adds_callback_to_registry(
+        self, setup_csc_recorder
+    ):
+        """Test that register_csc_items_callback adds a callback to the registry."""
+        # Create a mock cache size callback
+        cache_size_callback = MagicMock(return_value=42)
+
+        recorder.register_csc_items_callback(cache_size_callback, db_namespace=0)
+
+        registry = get_observables_registry_instance()
+        callbacks = registry.get(recorder.CSC_ITEMS_REGISTRY_KEY)
+
+        assert len(callbacks) == 1
+
+        # Call the registered callback and verify it returns an observation
+        observations = callbacks[0]()
+
+        assert len(observations) == 1
+        assert observations[0].value == 42
+        cache_size_callback.assert_called_once()
+
+    def test_register_csc_items_callback_multiple_registrations(
+        self, setup_csc_recorder
+    ):
+        """Test registering multiple CSC callbacks."""
+        callback1 = MagicMock(return_value=10)
+        callback2 = MagicMock(return_value=20)
+
+        recorder.register_csc_items_callback(callback1, db_namespace=0)
+        recorder.register_csc_items_callback(callback2, db_namespace=1)
+
+        registry = get_observables_registry_instance()
+        callbacks = registry.get(recorder.CSC_ITEMS_REGISTRY_KEY)
+
+        assert len(callbacks) == 2
+
+        # Verify each callback returns the correct observation
+        obs1 = callbacks[0]()
+        obs2 = callbacks[1]()
+
+        assert obs1[0].value == 10
+        assert obs2[0].value == 20
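+
+
+# Sketch (assumption): register_csc_items_callback presumably wraps a plain
+# "current cache size" callable into an Observation-producing callback tagged
+# with the db namespace. The helper name and the attribute key are
+# illustrative only; the real attribute name may differ.
+def _make_csc_items_producer_sketch(cache_size_callback, db_namespace):
+    from opentelemetry.metrics import Observation
+
+    def _producer():
+        # One observation per cache: its current item count plus the namespace.
+        return [
+            Observation(
+                cache_size_callback(),
+                attributes={"db.namespace": db_namespace},
+            )
+        ]
+    return _producer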
+
+
+class TestObservableGaugeIntegration:
+    """Integration tests for the observable gauge pattern with the registry."""
+
+    @pytest.fixture
+    def clean_registry(self):
+        """Ensure a clean registry before and after each test."""
+        get_observables_registry_instance().clear()
+        yield
+        get_observables_registry_instance().clear()
+
+    def test_full_observable_gauge_flow(self, clean_registry):
+        """Test the complete flow: init -> register -> callback invocation."""
+        # Create a mock meter and collector
+        mock_meter = MagicMock()
+        captured_callback = None
+
+        def capture_callback(name, **kwargs):
+            nonlocal captured_callback
+            captured_callback = kwargs.get('callbacks', [None])[0]
+            return MagicMock()
+
+        mock_meter.create_observable_gauge.side_effect = capture_callback
+        mock_meter.create_counter.return_value = MagicMock()
+        mock_meter.create_histogram.return_value = MagicMock()
+        mock_meter.create_up_down_counter.return_value = MagicMock()
+
+        config = OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC])
+
+        recorder.reset_collector()
+
+        with patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            # Step 1: Initialize the observable gauge
+            recorder.init_connection_count()
+
+            # Step 2: Register pool callbacks
+            mock_pool = MagicMock()
+            mock_pool.get_connection_count.return_value = [
+                (5, {"state": "idle", "pool": "pool1"}),
+            ]
+            recorder.register_pools_connection_count([mock_pool])
+
+            # Step 3: Simulate OTel calling the observable callback
+            assert captured_callback is not None
+            observations = captured_callback(None)
+
+            # Verify the observation was created correctly
+            assert len(observations) == 1
+            assert observations[0].value == 5
+
+        recorder.reset_collector()
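+
+    # The two tests below reuse the same capture pattern as above: a
+    # side_effect on create_observable_gauge grabs the callbacks= argument so
+    # the test can later invoke it the way the OpenTelemetry SDK would during
+    # metric collection.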
+    def test_observable_callback_handles_empty_registry(self, clean_registry):
+        """Test that the observable callback handles an empty registry gracefully."""
+        mock_meter = MagicMock()
+        captured_callback = None
+
+        def capture_callback(name, **kwargs):
+            nonlocal captured_callback
+            captured_callback = kwargs.get('callbacks', [None])[0]
+            return MagicMock()
+
+        mock_meter.create_observable_gauge.side_effect = capture_callback
+        mock_meter.create_counter.return_value = MagicMock()
+        mock_meter.create_histogram.return_value = MagicMock()
+        mock_meter.create_up_down_counter.return_value = MagicMock()
+
+        config = OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC])
+
+        recorder.reset_collector()
+
+        with patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            recorder.init_connection_count()
+
+            # Don't register any pools - the registry stays empty
+            assert captured_callback is not None
+            observations = captured_callback(None)
+
+            # Should return an empty list, not raise an error
+            assert observations == []
+
+        recorder.reset_collector()
+
+    def test_multiple_pools_aggregated_correctly(self, clean_registry):
+        """Test that observations from multiple pools are aggregated correctly."""
+        mock_meter = MagicMock()
+        captured_callback = None
+
+        def capture_callback(name, **kwargs):
+            nonlocal captured_callback
+            captured_callback = kwargs.get('callbacks', [None])[0]
+            return MagicMock()
+
+        mock_meter.create_observable_gauge.side_effect = capture_callback
+        mock_meter.create_counter.return_value = MagicMock()
+        mock_meter.create_histogram.return_value = MagicMock()
+        mock_meter.create_up_down_counter.return_value = MagicMock()
+
+        config = OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC])
+
+        recorder.reset_collector()
+
+        with patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            recorder.init_connection_count()
+
+            # Register multiple pools in separate calls
+            mock_pool1 = MagicMock()
+            mock_pool1.get_connection_count.return_value = [
+                (5, {"state": "idle"}),
+                (2, {"state": "used"}),
+            ]
+
+            mock_pool2 = MagicMock()
+            mock_pool2.get_connection_count.return_value = [
+                (3, {"state": "idle"}),
+            ]
+
+            recorder.register_pools_connection_count([mock_pool1])
+            recorder.register_pools_connection_count([mock_pool2])
+
+            # Simulate OTel calling the observable callback
+            observations = captured_callback(None)
+
+            # Should have 3 observations total (2 from pool1, 1 from pool2)
+            assert len(observations) == 3
+
+        recorder.reset_collector()
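+
+
+# Sketch (assumption): the end-to-end wiring these integration tests model,
+# condensed. init creates one observable gauge whose callback drains the
+# registry; each registration adds a producer under the same key. Names are
+# illustrative, not the recorder's public API.
+def _full_flow_sketch(meter, registry, pools):
+    from opentelemetry.metrics import Observation
+
+    key = "db.client.connection.count"
+    for pool in pools:
+        # Bind the current pool via a default argument so each lambda keeps
+        # its own pool rather than the loop's last value.
+        registry.register(
+            key,
+            lambda p=pool: [
+                Observation(value, attributes=attrs)
+                for value, attrs in p.get_connection_count()
+            ],
+        )
+    return meter.create_observable_gauge(
+        name=key,
+        callbacks=[
+            lambda options: [obs for producer in registry.get(key) for obs in producer()]
+        ],
+    )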