From 6b62855385277de12ab6ea1bf77f588809b3dc8d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 13 Jan 2026 10:24:46 +0200 Subject: [PATCH 1/5] Added export of connection advanced metrics --- redis/connection.py | 48 ++- redis/event.py | 45 ++- redis/observability/metrics.py | 16 +- redis/observability/recorder.py | 7 +- tests/test_connection_pool.py | 345 +++++++++++++++++++++- tests/test_observability/test_config.py | 2 +- tests/test_observability/test_recorder.py | 14 +- 7 files changed, 448 insertions(+), 29 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 7fe672b7c2..333a605b68 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -36,7 +36,7 @@ from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \ - AfterConnectionCreatedEvent + AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -56,6 +56,7 @@ ) from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \ DB_CLIENT_CONNECTION_POOL_NAME +from .observability.metrics import CloseReason from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -1068,6 +1069,11 @@ def disconnect(self, *args): pass if len(args) > 0 and isinstance(args[0], Exception): + if len(args) > 2 and args[2]: + close_reason = CloseReason.HEALTHCHECK_FAILED + else: + close_reason = CloseReason.ERROR + if args[1] <= self.retry.get_retries(): self._event_dispatcher.dispatch( OnErrorEvent( @@ -1078,6 +1084,19 @@ def disconnect(self, *args): ) ) + self._event_dispatcher.dispatch( + AfterConnectionClosedEvent( + close_reason=close_reason, + error=args[0], + ) + ) + else: + self._event_dispatcher.dispatch( + AfterConnectionClosedEvent( + close_reason=CloseReason.APPLICATION_CLOSE + ) + ) + def mark_for_reconnect(self): self._should_reconnect = True @@ -1095,7 +1114,7 @@ def _send_ping(self): def _ping_failed(self, error, failure_count): """Function to call when PING fails""" - self.disconnect(error, failure_count) + self.disconnect(error, failure_count, True) def check_health(self): """Check the health of the connection with a PING/PONG""" @@ -2648,6 +2667,8 @@ def _checkpid(self) -> None: def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" + # Start timing for observability + start_time_acquired = time.monotonic() self._checkpid() is_created = False @@ -2656,7 +2677,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": connection = self._available_connections.pop() except IndexError: # Start timing for observability - start_time = time.monotonic() + start_time_created = time.monotonic() connection = self.make_connection() is_created = True @@ -2691,9 +2712,16 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": self._event_dispatcher.dispatch( AfterConnectionCreatedEvent( connection_pool=self, - duration_seconds=time.monotonic() - start_time, + duration_seconds=time.monotonic() - start_time_created, ) ) + + self._event_dispatcher.dispatch( + AfterConnectionAcquiredEvent( + connection_pool=self, + duration_seconds=time.monotonic() - start_time_acquired, + ) + ) return connection def get_encoder(self) -> Encoder: @@ -2957,6 +2985,7 @@ def get_connection(self, 
command_name=None, *keys, **options): create new connections when we need to, i.e.: the actual number of connections will only increase in response to demand. """ + start_time_acquired = time.monotonic() # Make sure we haven't changed process. self._checkpid() is_created = False @@ -2979,7 +3008,7 @@ def get_connection(self, command_name=None, *keys, **options): # a new connection to add to the pool. if connection is None: # Start timing for observability - start_time = time.monotonic() + start_time_created = time.monotonic() connection = self.make_connection() is_created = True finally: @@ -3014,10 +3043,17 @@ def get_connection(self, command_name=None, *keys, **options): self._event_dispatcher.dispatch( AfterConnectionCreatedEvent( connection_pool=self, - duration_seconds=time.monotonic() - start_time, + duration_seconds=time.monotonic() - start_time_created, ) ) + self._event_dispatcher.dispatch( + AfterConnectionAcquiredEvent( + connection_pool=self, + duration_seconds=time.monotonic() - start_time_acquired, + ) + ) + return connection def release(self, connection): diff --git a/redis/event.py b/redis/event.py index 3c12d6307d..c8d07d78c1 100644 --- a/redis/event.py +++ b/redis/event.py @@ -11,7 +11,8 @@ from redis.observability.attributes import PubSubDirection from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \ record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \ - record_pubsub_message, record_streaming_lag + record_pubsub_message, record_streaming_lag, record_connection_wait_time, record_connection_use_time, \ + record_connection_closed from redis.utils import str_if_bytes @@ -113,6 +114,12 @@ def __init__( AfterConnectionHandoffEvent: [ ExportConnectionHandoffMetric(), ], + AfterConnectionAcquiredEvent: [ + ExportConnectionWaitTimeMetric(), + ], + AfterConnectionClosedEvent: [ + ExportConnectionClosedMetric(), + ], OnPubSubMessageEvent: [ ExportPubSubMessageMetric(), ], @@ -397,6 +404,22 @@ class AfterConnectionHandoffEvent: """ connection_pool: "ConnectionPoolInterface" +@dataclass +class AfterConnectionAcquiredEvent: + """ + Event fired after connection is acquired from pool. + """ + connection_pool: "ConnectionPoolInterface" + duration_seconds: float + +@dataclass +class AfterConnectionClosedEvent: + """ + Event fired after connection is closed. + """ + close_reason: "CloseReason" + error: Optional[Exception] = None + class AsyncOnCommandsFailEvent(OnCommandsFailEvent): pass @@ -704,3 +727,23 @@ def listen(self, event: OnStreamMessageReceivedEvent): consumer_group=event.consumer_group, consumer_name=event.consumer_name, ) + +class ExportConnectionWaitTimeMetric(EventListenerInterface): + """ + Listener that exports connection wait time metric. + """ + def listen(self, event: AfterConnectionAcquiredEvent): + record_connection_wait_time( + pool_name=repr(event.connection_pool), + duration_seconds=event.duration_seconds, + ) + +class ExportConnectionClosedMetric(EventListenerInterface): + """ + Listener that exports connection closed metric. 
+ """ + def listen(self, event: AfterConnectionClosedEvent): + record_connection_closed( + close_reason=event.close_reason, + error_type=event.error, + ) diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index 21ae8c0a24..9a8793dfb6 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -7,6 +7,7 @@ import logging import time +from enum import Enum from typing import Any, Dict, Optional, Callable, List from redis.observability.attributes import AttributeBuilder, ConnectionState, REDIS_CLIENT_CONNECTION_NOTIFICATION, \ @@ -27,6 +28,11 @@ Meter = None UpDownCounter = None +class CloseReason(Enum): + APPLICATION_CLOSE = "application_close" + ERROR = "error" + HEALTHCHECK_FAILED = "healthcheck_failed" + class RedisMetricsCollector: """ @@ -401,24 +407,22 @@ def record_operation_duration( def record_connection_closed( self, - pool_name: str, - close_reason: Optional[str] = None, + close_reason: Optional[CloseReason] = None, error_type: Optional[Exception] = None, ) -> None: """ Record a connection closed event. Args: - pool_name: Connection pool name - close_reason: Reason for closing (e.g., 'idle_timeout', 'error', 'shutdown') + close_reason: Reason for closing (e.g., 'error', 'application_close') error_type: Error type if closed due to error """ if not hasattr(self, "connection_closed"): return - attrs = self.attr_builder.build_connection_attributes(pool_name=pool_name) + attrs = self.attr_builder.build_connection_attributes() if close_reason: - attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] = close_reason + attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] = close_reason.value attrs.update( self.attr_builder.build_error_attributes( diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index 573362e3e1..2e67062b52 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -23,7 +23,7 @@ from typing import Optional, Callable from redis.observability.attributes import PubSubDirection, ConnectionState -from redis.observability.metrics import RedisMetricsCollector +from redis.observability.metrics import RedisMetricsCollector, CloseReason from redis.observability.providers import get_observability_instance # Global metrics collector instance (lazy-initialized) @@ -251,15 +251,13 @@ def record_connection_use_time( def record_connection_closed( - pool_name: str, - close_reason: Optional[str] = None, + close_reason: Optional[CloseReason] = None, error_type: Optional[Exception] = None, ) -> None: """ Record a connection closed event. 
Args:
-        pool_name: Connection pool identifier
-        close_reason: Reason for closing (e.g., 'idle_timeout', 'error', 'shutdown')
+        close_reason: Reason for closing (e.g., 'error', 'application_close', 'healthcheck_failed')
         error_type: Error type if closed due to error
 
@@ -275,7 +273,6 @@ def record_connection_closed(
 
     # try:
     _metrics_collector.record_connection_closed(
-        pool_name=pool_name,
         close_reason=close_reason,
         error_type=error_type,
     )
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index e78fc976d3..e24f3112a5 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -10,7 +10,16 @@
 import redis
 from redis.cache import CacheConfig
 from redis.connection import CacheProxyConnection, Connection, to_bool
-from redis.event import AfterConnectionCreatedEvent, EventDispatcher, EventListenerInterface
+from redis.event import (
+    AfterConnectionCreatedEvent,
+    AfterConnectionAcquiredEvent,
+    AfterConnectionReleasedEvent,
+    AfterConnectionClosedEvent,
+    OnErrorEvent,
+    EventDispatcher,
+    EventListenerInterface,
+)
+from redis.connection import CloseReason
 from redis.utils import SSL_AVAILABLE
 
 from .conftest import (
@@ -1120,3 +1129,337 @@ def test_connection_created_event_emitted_multiple_times_for_new_connections(sel
         pool.get_connection()
 
         assert listener.listen.call_count == 2
+
+    def test_connection_acquired_event_emitted_on_get_connection(self):
+        """Test that AfterConnectionAcquiredEvent is emitted when getting a connection."""
+        event_dispatcher = EventDispatcher()
+        listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            AfterConnectionAcquiredEvent: [listener],
+        })
+
+        pool = redis.BlockingConnectionPool(
+            connection_class=DummyConnection,
+            event_dispatcher=event_dispatcher,
+            max_connections=10,
+            timeout=5,
+        )
+
+        pool.get_connection()
+
+        listener.listen.assert_called_once()
+        event = listener.listen.call_args[0][0]
+        assert isinstance(event, AfterConnectionAcquiredEvent)
+        assert event.connection_pool is pool
+        assert event.duration_seconds >= 0
+
+    def test_connection_acquired_event_emitted_on_reused_connection(self):
+        """Test that AfterConnectionAcquiredEvent is emitted even when reusing a connection."""
+        event_dispatcher = EventDispatcher()
+        listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            AfterConnectionAcquiredEvent: [listener],
+        })
+
+        pool = redis.BlockingConnectionPool(
+            connection_class=DummyConnection,
+            event_dispatcher=event_dispatcher,
+            max_connections=10,
+            timeout=5,
+        )
+
+        conn = pool.get_connection()
+        pool.release(conn)
+
+        # Reset the mock to clear the first call
+        listener.reset_mock()
+
+        # Get the same connection again (reused)
+        pool.get_connection()
+
+        # Event SHOULD be emitted for reused connection
+        listener.listen.assert_called_once()
+
+    def test_connection_released_event_emitted_on_release(self):
+        """Test that AfterConnectionReleasedEvent is emitted when releasing a connection."""
+        event_dispatcher = EventDispatcher()
+        listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            AfterConnectionReleasedEvent: [listener],
+        })
+
+        pool = redis.ConnectionPool(
+            connection_class=DummyConnection,
+            event_dispatcher=event_dispatcher,
+            max_connections=10,
+            timeout=5,
+        )
+
+        conn = pool.get_connection()
+        pool.release(conn)
+
+        listener.listen.assert_called_once()
+        event = listener.listen.call_args[0][0]
+        assert isinstance(event, AfterConnectionReleasedEvent)
+
+
+class TestConnectionPoolAcquiredEventEmission:
+    """Tests for AfterConnectionAcquiredEvent emission from
ConnectionPool.""" + + def test_connection_acquired_event_emitted_on_get_connection(self): + """Test that AfterConnectionAcquiredEvent is emitted when getting a connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionAcquiredEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + pool.get_connection() + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionAcquiredEvent) + assert event.connection_pool is pool + assert event.duration_seconds >= 0 + + def test_connection_acquired_event_emitted_on_reused_connection(self): + """Test that AfterConnectionAcquiredEvent is emitted even when reusing a connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionAcquiredEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + conn = pool.get_connection() + pool.release(conn) + + # Reset the mock to clear the first call + listener.reset_mock() + + # Get the same connection again (reused) + pool.get_connection() + + # Event SHOULD be emitted for reused connection + listener.listen.assert_called_once() + + +class TestConnectionPoolReleasedEventEmission: + """Tests for AfterConnectionReleasedEvent emission from ConnectionPool.""" + + def test_connection_released_event_emitted_on_release(self): + """Test that AfterConnectionReleasedEvent is emitted when releasing a connection.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionReleasedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + conn = pool.get_connection() + pool.release(conn) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionReleasedEvent) + + def test_connection_released_event_not_emitted_for_foreign_connection(self): + """Test that AfterConnectionReleasedEvent is NOT emitted for connections not owned by pool.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionReleasedEvent: [listener], + }) + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + event_dispatcher=event_dispatcher, + ) + + # Create a connection that doesn't belong to this pool + foreign_conn = DummyConnection() + + pool.release(foreign_conn) + + # Event should NOT be emitted for foreign connection + listener.listen.assert_not_called() + + +class TestConnectionClosedEventEmission: + """Tests for AfterConnectionClosedEvent emission from Connection.""" + + def test_connection_closed_event_emitted_on_disconnect(self): + """Test that AfterConnectionClosedEvent is emitted when disconnecting.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionClosedEvent: [listener], + }) + + conn = Connection( + host="localhost", + port=6379, + event_dispatcher=event_dispatcher, + ) + + # Mock the socket to simulate a connected state + mock_sock = MagicMock() + conn._sock = mock_sock + + 
conn.disconnect() + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionClosedEvent) + assert event.close_reason == CloseReason.APPLICATION_CLOSE + assert event.error is None + + def test_connection_closed_event_with_error_reason(self): + """Test that AfterConnectionClosedEvent is emitted with ERROR reason on error.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionClosedEvent: [listener], + }) + + conn = Connection( + host="localhost", + port=6379, + event_dispatcher=event_dispatcher, + ) + + # Mock the socket to simulate a connected state + mock_sock = MagicMock() + conn._sock = mock_sock + + error = ConnectionError("Connection lost") + conn.disconnect(error, 0) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionClosedEvent) + assert event.close_reason == CloseReason.ERROR + assert event.error is error + + def test_connection_closed_event_with_healthcheck_failed_reason(self): + """Test that AfterConnectionClosedEvent is emitted with HEALTHCHECK_FAILED reason.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + AfterConnectionClosedEvent: [listener], + }) + + conn = Connection( + host="localhost", + port=6379, + event_dispatcher=event_dispatcher, + ) + + # Mock the socket to simulate a connected state + mock_sock = MagicMock() + conn._sock = mock_sock + + error = ConnectionError("Health check failed") + # Third argument True indicates health check failure + conn.disconnect(error, 0, True) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, AfterConnectionClosedEvent) + assert event.close_reason == CloseReason.HEALTHCHECK_FAILED + assert event.error is error + + +class TestConnectionOnErrorEventEmission: + """Tests for OnErrorEvent emission from Connection.""" + + def test_on_error_event_emitted_on_disconnect_with_error(self): + """Test that OnErrorEvent is emitted when disconnecting with an error.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnErrorEvent: [listener], + }) + + conn = Connection( + host="localhost", + port=6379, + event_dispatcher=event_dispatcher, + ) + + # Mock the socket to simulate a connected state + mock_sock = MagicMock() + conn._sock = mock_sock + + error = ConnectionError("Connection lost") + # retry_attempts=0 which is <= default retries (0), so event should be emitted + conn.disconnect(error, 0) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, OnErrorEvent) + assert event.error is error + assert event.server_address == "localhost" + assert event.server_port == 6379 + assert event.retry_attempts == 0 + + def test_on_error_event_not_emitted_when_retry_exceeds_limit(self): + """Test that OnErrorEvent is NOT emitted when retry attempts exceed limit.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnErrorEvent: [listener], + }) + + conn = Connection( + host="localhost", + port=6379, + event_dispatcher=event_dispatcher, + ) + + # Mock the socket to simulate a connected state + mock_sock = MagicMock() + conn._sock = mock_sock + + error = 
ConnectionError("Connection lost")
+        # retry_attempts=5 which is > default retries (0), so event should NOT be emitted
+        conn.disconnect(error, 5)
+
+        # OnErrorEvent should NOT be called
+        listener.listen.assert_not_called()
+
+    def test_on_error_event_not_emitted_on_normal_disconnect(self):
+        """Test that OnErrorEvent is NOT emitted on normal disconnect without error."""
+        event_dispatcher = EventDispatcher()
+        listener = MagicMock(spec=EventListenerInterface)
+        event_dispatcher.register_listeners({
+            OnErrorEvent: [listener],
+        })
+
+        conn = Connection(
+            host="localhost",
+            port=6379,
+            event_dispatcher=event_dispatcher,
+        )
+
+        # Mock the socket to simulate a connected state
+        mock_sock = MagicMock()
+        conn._sock = mock_sock
+
+        conn.disconnect()
+
+        # OnErrorEvent should NOT be called for normal disconnect
+        listener.listen.assert_not_called()
diff --git a/tests/test_observability/test_config.py b/tests/test_observability/test_config.py
index 222ad9c830..d6bf15e381 100644
--- a/tests/test_observability/test_config.py
+++ b/tests/test_observability/test_config.py
@@ -25,7 +25,7 @@ def test_default_enabled_telemetry(self):
     def test_default_metric_groups(self):
-        """Test that default metric groups are COMMAND, CONNECTION_BASIC, RESILIENCY."""
+        """Test that default metric groups are CONNECTION_BASIC and RESILIENCY."""
         config = OTelConfig()
-        expected = MetricGroup.COMMAND | MetricGroup.CONNECTION_BASIC | MetricGroup.RESILIENCY
+        expected = MetricGroup.CONNECTION_BASIC | MetricGroup.RESILIENCY
         assert config.metric_groups == expected
 
     def test_default_sample_percentage(self):
diff --git a/tests/test_observability/test_recorder.py b/tests/test_observability/test_recorder.py
index 3460c4489f..09127de49e 100644
--- a/tests/test_observability/test_recorder.py
+++ b/tests/test_observability/test_recorder.py
@@ -44,7 +44,7 @@
     REDIS_CLIENT_CONSUMER_NAME,
     DB_CLIENT_CONNECTION_NAME,
 )
 from redis.observability.config import OTelConfig, MetricGroup
-from redis.observability.metrics import RedisMetricsCollector
+from redis.observability.metrics import RedisMetricsCollector, CloseReason
 from redis.observability.recorder import record_operation_duration, record_connection_create_time, \
     record_connection_timeout, record_connection_wait_time, record_connection_use_time, \
     record_connection_closed, record_connection_relaxed_timeout, record_connection_handoff, record_error_count, \
@@ -328,8 +328,7 @@ def test_record_connection_closed_with_reason(self, setup_recorder):
         instruments = setup_recorder
 
         record_connection_closed(
-            pool_name='ConnectionPool',
-            close_reason='idle_timeout',
+            close_reason=CloseReason.HEALTHCHECK_FAILED,
         )
 
         instruments.connection_closed.add.assert_called_once()
         call_args = instruments.connection_closed.add.call_args
         assert call_args[0][0] == 1
 
         attrs = call_args[1]['attributes']
-        assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool'
-        assert attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] == 'idle_timeout'
+        assert attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] == CloseReason.HEALTHCHECK_FAILED.value
 
     def test_record_connection_closed_with_error(self, setup_recorder):
         """Test recording connection closed with error type."""
         instruments = setup_recorder
 
         error = ConnectionResetError("Connection reset by peer")
         record_connection_closed(
-            pool_name='ConnectionPool',
-            close_reason='error',
+            close_reason=CloseReason.ERROR,
             error_type=error,
         )
 
@@ -828,8 +825,7 @@ def test_record_connection_closed_no_meter_call_when_connection_advanced_disable
         recorder.reset_collector()
with patch.object(recorder, '_get_or_create_collector', return_value=collector): record_connection_closed( - pool_name='test-pool', - close_reason='idle_timeout', + close_reason=CloseReason.APPLICATION_CLOSE, ) # Verify no call to the counter's add method From 21c75ec4a28f12cfaa1fd46258fc96663173c203 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 14 Jan 2026 17:29:46 +0200 Subject: [PATCH 2/5] Added CSC metrics export --- redis/cache.py | 72 ++++++- redis/connection.py | 29 ++- redis/event.py | 90 +++++++- redis/observability/attributes.py | 37 ++++ redis/observability/metrics.py | 103 ++++++++- redis/observability/recorder.py | 99 ++++++++- tests/conftest.py | 2 +- tests/test_cache.py | 337 ++++++++++++++++++++++++------ tests/test_connection.py | 182 +++++++++++++++- 9 files changed, 873 insertions(+), 78 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 949ad3ddf9..2384c43259 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -4,6 +4,10 @@ from enum import Enum from typing import Any, List, Optional, Union +from redis.event import EventDispatcherInterface, EventDispatcher, OnCacheInitialisationEvent, \ + OnCacheEvictionEvent, OnCacheHitEvent +from redis.observability.attributes import CSCResult, CSCReason + class CacheEntryStatus(Enum): VALID = "VALID" @@ -186,9 +190,6 @@ def set(self, entry: CacheEntry) -> bool: self._cache[entry.cache_key] = entry self._eviction_policy.touch(entry.cache_key) - if self._cache_config.is_exceeds_max_size(len(self._cache)): - self._eviction_policy.evict_next() - return True def get(self, key: CacheKey) -> Union[CacheEntry, None]: @@ -247,6 +248,69 @@ def flush(self) -> int: def is_cachable(self, key: CacheKey) -> bool: return self._cache_config.is_allowed_to_cache(key.command) +class CacheProxy(CacheInterface): + """ + Proxy object that wraps cache implementations to enable additional logic on top + """ + def __init__(self, cache: CacheInterface, event_dispatcher: Optional[EventDispatcherInterface] = None): + self._cache = cache + + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + + self._event_dispatcher.dispatch( + OnCacheInitialisationEvent( + cache_items_callback=lambda: self._cache.size, + ) + ) + + @property + def collection(self) -> OrderedDict: + return self._cache.collection + + @property + def config(self) -> CacheConfigurationInterface: + return self._cache.config + + @property + def eviction_policy(self) -> EvictionPolicyInterface: + return self._cache.eviction_policy + + @property + def size(self) -> int: + return self._cache.size + + def get(self, key: CacheKey) -> Union[CacheEntry, None]: + return self._cache.get(key) + + def set(self, entry: CacheEntry) -> bool: + is_set = self._cache.set(entry) + + if self.config.is_exceeds_max_size(self.size): + self._event_dispatcher.dispatch( + OnCacheEvictionEvent( + count=1, + reason=CSCReason.FULL, + ) + ) + self.eviction_policy.evict_next() + + return is_set + + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + return self._cache.delete_by_cache_keys(cache_keys) + + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + return self._cache.delete_by_redis_keys(redis_keys) + + def flush(self) -> int: + return self._cache.flush() + + def is_cachable(self, key: CacheKey) -> bool: + return self._cache.is_cachable(key) + class LRUPolicy(EvictionPolicyInterface): def __init__(self): @@ -422,4 +486,4 @@ def __init__(self, cache_config: Optional[CacheConfig] = 
None): def get_cache(self) -> CacheInterface: cache_class = self._config.get_cache_class() - return cache_class(cache_config=self._config) + return CacheProxy(cache_class(cache_config=self._config)) diff --git a/redis/connection.py b/redis/connection.py index 333a605b68..157c96bcb5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -28,7 +28,7 @@ CacheFactory, CacheFactoryInterface, CacheInterface, - CacheKey, + CacheKey, CacheProxy, ) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser @@ -36,7 +36,8 @@ from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \ - AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent + AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \ + OnCacheMissEvent, OnCacheEvictionEvent from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -55,7 +56,7 @@ MaintNotificationsPoolHandler, MaintenanceNotification, ) from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \ - DB_CLIENT_CONNECTION_POOL_NAME + DB_CLIENT_CONNECTION_POOL_NAME, CSCReason from .observability.metrics import CloseReason from .retry import Retry from .utils import ( @@ -1399,6 +1400,8 @@ def __init__( self.retry = self._conn.retry self.host = self._conn.host self.port = self._conn.port + self.db = self._conn.db + self._event_dispatcher = self._conn._event_dispatcher self.credential_provider = conn.credential_provider self._pool_lock = pool_lock self._cache = cache @@ -1551,8 +1554,20 @@ def read_response( self._cache.get(self._current_command_cache_key).cache_value ) self._current_command_cache_key = None + self._event_dispatcher.dispatch( + OnCacheHitEvent( + bytes_saved=len(res), + db_namespace=self.db, + ) + ) return res + self._event_dispatcher.dispatch( + OnCacheMissEvent( + db_namespace=self.db, + ) + ) + response = self._conn.read_response( disable_decoding=disable_decoding, disconnect_on_error=disconnect_on_error, @@ -1716,6 +1731,12 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]] self._cache.flush() else: self._cache.delete_by_redis_keys(data[1]) + self._event_dispatcher.dispatch( + OnCacheEvictionEvent( + count=len(data[1]), + reason=CSCReason.INVALIDATION, + ) + ) class SSLConnection(Connection): @@ -2539,7 +2560,7 @@ def __init__( self.cache = cache else: if self._cache_factory is not None: - self.cache = self._cache_factory.get_cache() + self.cache = CacheProxy(self._cache_factory.get_cache()) else: self.cache = CacheFactory( self._connection_kwargs.get("cache_config") diff --git a/redis/event.py b/redis/event.py index c8d07d78c1..92edd22e80 100644 --- a/redis/event.py +++ b/redis/event.py @@ -8,11 +8,11 @@ from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider -from redis.observability.attributes import PubSubDirection +from redis.observability.attributes import PubSubDirection, CSCResult, CSCReason from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \ record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \ record_pubsub_message, record_streaming_lag, record_connection_wait_time, 
record_connection_use_time, \ - record_connection_closed + record_connection_closed, record_csc_request, init_csc_items, record_csc_eviction, record_csc_network_saved from redis.utils import str_if_bytes @@ -103,10 +103,6 @@ def __init__( AsyncAfterConnectionReleasedEvent: [ AsyncReAuthConnectionListener(), ], - OnErrorEvent: [ExportErrorCountMetric()], - OnMaintenanceNotificationEvent: [ - ExportMaintenanceNotificationCountMetric(), - ], AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()], AfterConnectionTimeoutUpdatedEvent: [ ExportConnectionRelaxedTimeoutMetric(), @@ -126,6 +122,14 @@ def __init__( OnStreamMessageReceivedEvent: [ ExportStreamingLagMetric(), ], + OnErrorEvent: [ExportErrorCountMetric()], + OnMaintenanceNotificationEvent: [ + ExportMaintenanceNotificationCountMetric(), + ], + OnCacheInitialisationEvent: [InitialiseCSCItemsObservability()], + OnCacheEvictionEvent: [ExportCSCEvictionMetric()], + OnCacheHitEvent: [ExportCSCNetworkSavedMetric(), ExportCSCRequestMetric()], + OnCacheMissEvent: [ExportCSCRequestMetric()], } self._lock = threading.Lock() @@ -420,6 +424,36 @@ class AfterConnectionClosedEvent: close_reason: "CloseReason" error: Optional[Exception] = None +@dataclass +class OnCacheHitEvent: + """ + Event fired whenever a cache hit occurs. + """ + bytes_saved: int + db_namespace: Optional[int] = None + +@dataclass +class OnCacheMissEvent: + """ + Event fired whenever a cache miss occurs. + """ + db_namespace: Optional[int] = None + +@dataclass +class OnCacheInitialisationEvent: + """ + Event fired after cache is initialized. + """ + cache_items_callback: Callable + +@dataclass +class OnCacheEvictionEvent: + """ + Event fired after cache eviction. + """ + count: int + reason: CSCReason + class AsyncOnCommandsFailEvent(OnCommandsFailEvent): pass @@ -747,3 +781,47 @@ def listen(self, event: AfterConnectionClosedEvent): close_reason=event.close_reason, error_type=event.error, ) + +class ExportCSCRequestMetric(EventListenerInterface): + """ + Listener that exports CSC request metric. + """ + def listen(self, event: Union[OnCacheHitEvent, OnCacheMissEvent]): + if isinstance(event, OnCacheHitEvent): + result = CSCResult.HIT + else: + result = CSCResult.MISS + + record_csc_request( + db_namespace=event.db_namespace, + result=result, + ) + +class InitialiseCSCItemsObservability(EventListenerInterface): + """ + Listener that initializes CSC items observability. + """ + def listen(self, event: OnCacheInitialisationEvent): + init_csc_items( + callback=event.cache_items_callback, + ) + +class ExportCSCEvictionMetric(EventListenerInterface): + """ + Listener that exports CSC eviction metric. + """ + def listen(self, event: OnCacheEvictionEvent): + record_csc_eviction( + count=event.count, + reason=event.reason, + ) + +class ExportCSCNetworkSavedMetric(EventListenerInterface): + """ + Listener that exports CSC network saved metric. 
+ """ + def listen(self, event: OnCacheHitEvent): + record_csc_network_saved( + bytes_saved=event.bytes_saved, + db_namespace=event.db_namespace, + ) diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py index 9493ff73da..2e62a0ea10 100644 --- a/redis/observability/attributes.py +++ b/redis/observability/attributes.py @@ -48,6 +48,8 @@ REDIS_CLIENT_STREAM_NAME = "redis.client.stream.name" REDIS_CLIENT_CONSUMER_GROUP = "redis.client.consumer_group" REDIS_CLIENT_CONSUMER_NAME = "redis.client.consumer_name" +REDIS_CLIENT_CSC_RESULT = "redis.client.csc.result" +REDIS_CLIENT_CSC_REASON = "redis.client.csc.reason" class ConnectionState(Enum): IDLE = "idle" @@ -57,6 +59,14 @@ class PubSubDirection(Enum): PUBLISH = "publish" RECEIVE = "receive" +class CSCResult(Enum): + HIT = "hit" + MISS = "miss" + +class CSCReason(Enum): + FULL = 'full' + INVALIDATION = 'invalidation' + class AttributeBuilder: """ @@ -274,6 +284,33 @@ def build_streaming_attributes( return attrs + @staticmethod + def build_csc_attributes( + db_namespace: Optional[int] = None, + result: Optional[CSCResult] = None, + reason: Optional[CSCReason] = None, + ) -> Dict[str, Any]: + """ + Build attributes for a Client Side Caching (CSC) operation. + + Args: + db_namespace: Redis database index + result: CSC result ('hit' or 'miss') + reason: Reason for CSC miss ('full' or 'invalidation') + + Returns: + Dictionary of CSC attributes + """ + attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes(db_namespace=db_namespace) + + if result is not None: + attrs[REDIS_CLIENT_CSC_RESULT] = result.value + + if reason is not None: + attrs[REDIS_CLIENT_CSC_REASON] = reason.value + + return attrs + @staticmethod def build_pool_name( server_address: str, diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index 9a8793dfb6..250388f76a 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Callable, List from redis.observability.attributes import AttributeBuilder, ConnectionState, REDIS_CLIENT_CONNECTION_NOTIFICATION, \ - REDIS_CLIENT_CONNECTION_CLOSE_REASON, ERROR_TYPE, PubSubDirection + REDIS_CLIENT_CONNECTION_CLOSE_REASON, ERROR_TYPE, PubSubDirection, CSCResult, CSCReason from redis.observability.config import OTelConfig, MetricGroup logger = logging.getLogger(__name__) @@ -82,6 +82,9 @@ def __init__(self, meter: Meter, config: OTelConfig): if MetricGroup.STREAMING in self.config.metric_groups: self._init_streaming_metrics() + if MetricGroup.CSC in self.config.metric_groups: + self._init_csc_metrics() + logger.info("RedisMetricsCollector initialized") def _init_resiliency_metrics(self) -> None: @@ -169,6 +172,26 @@ def _init_streaming_metrics(self) -> None: description="End-to-end lag per message, showing how stale are the messages when the application starts processing them." 
) + def _init_csc_metrics(self) -> None: + """Initialize Client Side Caching (CSC) metric instruments.""" + self.csc_requests = self.meter.create_counter( + name="redis.client.csc.requests", + unit="{request}", + description="The total number of requests to the cache", + ) + + self.csc_evictions = self.meter.create_counter( + name="redis.client.csc.evictions", + unit="{eviction}", + description="The total number of cache evictions", + ) + + self.csc_network_saved = self.meter.create_counter( + name="redis.client.csc.network_saved", + unit="{bytes}", + description="The total number of bytes saved by using CSC", + ) + # Resiliency metric recording methods def record_error_count( @@ -273,6 +296,26 @@ def init_connection_count( callbacks=[callback], ) + def init_csc_items( + self, + callback: Callable, + ) -> None: + """ + Initialize observable gauge for CSC items metric. + + Args: + callback: Callback function to retrieve CSC items count + """ + if not MetricGroup.CSC in self.config.metric_groups: + return + + self.csc_items = self.meter.create_observable_gauge( + name="redis.client.csc.items", + unit="{item}", + description="The total number of cached responses currently stored", + callbacks=[callback], + ) + def record_connection_timeout(self, pool_name: str) -> None: """ Record a connection timeout event. @@ -523,6 +566,64 @@ def record_streaming_lag( ) self.stream_lag.record(lag_seconds, attributes=attrs) + # CSC metric recording methods + + def record_csc_request( + self, + db_namespace: Optional[int] = None, + result: Optional[CSCResult] = None, + ) -> None: + """ + Record a Client Side Caching (CSC) request. + + Args: + db_namespace: Redis database index + result: CSC result ('hit' or 'miss') + """ + if not hasattr(self, "csc_requests"): + return + + attrs = self.attr_builder.build_csc_attributes(result=result, db_namespace=db_namespace) + self.csc_requests.add(1, attributes=attrs) + + def record_csc_eviction( + self, + count: int, + db_namespace: Optional[int] = None, + reason: Optional[CSCReason] = None, + ) -> None: + """ + Record a Client Side Caching (CSC) eviction. + + Args: + count: Number of evictions + db_namespace: Redis database index + reason: Reason for eviction + """ + if not hasattr(self, "csc_evictions"): + return + + attrs = self.attr_builder.build_csc_attributes(reason=reason, db_namespace=db_namespace) + self.csc_evictions.add(count, attributes=attrs) + + def record_csc_network_saved( + self, + bytes_saved: int, + db_namespace: Optional[int] = None, + ) -> None: + """ + Record the number of bytes saved by using Client Side Caching (CSC). 
+ + Args: + db_namespace: Redis database index + bytes_saved: Number of bytes saved + """ + if not hasattr(self, "csc_network_saved"): + return + + attrs = self.attr_builder.build_csc_attributes(db_namespace=db_namespace) + self.csc_network_saved.add(bytes_saved, attributes=attrs) + # Utility methods @staticmethod diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index 2e67062b52..048b5bc373 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -22,7 +22,7 @@ import time from typing import Optional, Callable -from redis.observability.attributes import PubSubDirection, ConnectionState +from redis.observability.attributes import PubSubDirection, ConnectionState, CSCResult, CSCReason, AttributeBuilder from redis.observability.metrics import RedisMetricsCollector, CloseReason from redis.observability.providers import get_observability_instance @@ -486,6 +486,103 @@ def record_maint_notification_count( # except Exception: # pass +def record_csc_request( + db_namespace: Optional[int] = None, + result: Optional[CSCResult] = None, +): + """ + Record a Client Side Caching (CSC) request. + + Args: + db_namespace: Redis database index + result: CSC result ('hit' or 'miss') + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_request( + db_namespace=db_namespace, + result=result, + ) + +def init_csc_items( + callback: Callable +) -> None: + """ + Initialize observable gauge for CSC items metric. + + Args: + callback: Callback function to retrieve CSC items count + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # Lazy import + from opentelemetry.metrics import Observation + + def observation_wrapper(__): + return [Observation(callback(), attributes=AttributeBuilder.build_csc_attributes())] + + _metrics_collector.init_csc_items( + callback=observation_wrapper, + ) + +def record_csc_eviction( + count: int, + db_namespace: Optional[int] = None, + reason: Optional[CSCReason] = None, +) -> None: + """ + Record a Client Side Caching (CSC) eviction. + + Args: + count: Number of evictions + db_namespace: Redis database index + reason: Reason for eviction + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_eviction( + count=count, + db_namespace=db_namespace, + reason=reason, + ) + +def record_csc_network_saved( + bytes_saved: int, + db_namespace: Optional[int] = None, +) -> None: + """ + Record the number of bytes saved by using Client Side Caching (CSC). 
+ + Args: + bytes_saved: Number of bytes saved + db_namespace: Redis database index + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + _metrics_collector.record_csc_network_saved( + bytes_saved=bytes_saved, + db_namespace=db_namespace, + ) def _get_or_create_collector() -> Optional[RedisMetricsCollector]: """ diff --git a/tests/conftest.py b/tests/conftest.py index 9d2f51795a..f24cbd23b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ REDIS_INFO = {} default_redis_url = "redis://localhost:6379/0" -default_protocol = "2" +default_protocol = "3" default_redismod_url = "redis://localhost:6479" # default ssl client ignores verification for the purpose of testing diff --git a/tests/test_cache.py b/tests/test_cache.py index 265bcded04..60d6650363 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,5 @@ import time +from unittest.mock import MagicMock import pytest import redis @@ -8,11 +9,19 @@ CacheEntry, CacheEntryStatus, CacheKey, + CacheProxy, DefaultCache, EvictionPolicy, EvictionPolicyType, LRUPolicy, ) +from redis.event import ( + EventDispatcher, + EventListenerInterface, + OnCacheEvictionEvent, + OnCacheInitialisationEvent, +) +from redis.observability.attributes import CSCReason from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt @@ -1120,66 +1129,6 @@ def test_set_does_not_store_not_allowed_key(self, cache_key, mock_connection): ) ) - def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection): - cache = DefaultCache(CacheConfig(max_size=3)) - cache_key1 = CacheKey( - command="GET", redis_keys=("foo",), redis_args=("GET", "foo") - ) - cache_key2 = CacheKey( - command="GET", redis_keys=("foo1",), redis_args=("GET", "foo1") - ) - cache_key3 = CacheKey( - command="GET", redis_keys=("foo2",), redis_args=("GET", "foo2") - ) - - # Set 3 different keys - assert cache.set( - CacheEntry( - cache_key=cache_key1, - cache_value=b"bar", - status=CacheEntryStatus.VALID, - connection_ref=mock_connection, - ) - ) - assert cache.set( - CacheEntry( - cache_key=cache_key2, - cache_value=b"bar1", - status=CacheEntryStatus.VALID, - connection_ref=mock_connection, - ) - ) - assert cache.set( - CacheEntry( - cache_key=cache_key3, - cache_value=b"bar2", - status=CacheEntryStatus.VALID, - connection_ref=mock_connection, - ) - ) - - # Accessing key in the order that it makes 2nd key LRU - assert cache.get(cache_key1).cache_value == b"bar" - assert cache.get(cache_key2).cache_value == b"bar1" - assert cache.get(cache_key3).cache_value == b"bar2" - assert cache.get(cache_key1).cache_value == b"bar" - - cache_key4 = CacheKey( - command="GET", redis_keys=("foo3",), redis_args=("GET", "foo3") - ) - assert cache.set( - CacheEntry( - cache_key=cache_key4, - cache_value=b"bar3", - status=CacheEntryStatus.VALID, - connection_ref=mock_connection, - ) - ) - - # Make sure that new key was added and 2nd is evicted - assert cache.get(cache_key4).cache_value == b"bar3" - assert cache.get(cache_key2) is None - @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) @@ -1560,3 +1509,271 @@ def test_is_exceeds_max_size(self, cache_conf: CacheConfig): def test_is_allowed_to_cache(self, cache_conf: CacheConfig): assert cache_conf.is_allowed_to_cache("GET") assert not cache_conf.is_allowed_to_cache("SET") + + +class TestUnitCacheProxy: + """Unit tests for CacheProxy class with 
mocked event dispatcher.""" + + @pytest.fixture + def mock_cache(self, mock_connection): + """Create a DefaultCache for testing.""" + return DefaultCache(CacheConfig(max_size=5)) + + @pytest.fixture + def mock_event_dispatcher(self): + """Create a mock event dispatcher.""" + return MagicMock(spec=EventDispatcher) + + @pytest.fixture + def cache_key(self): + """Create a sample cache key.""" + return CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + + def test_initialization_emits_cache_initialisation_event(self, mock_cache): + """Test that CacheProxy emits OnCacheInitialisationEvent on initialization.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheInitialisationEvent: [listener], + }) + + CacheProxy(mock_cache, event_dispatcher) + + listener.listen.assert_called_once() + event = listener.listen.call_args[0][0] + assert isinstance(event, OnCacheInitialisationEvent) + assert callable(event.cache_items_callback) + + def test_initialization_event_callback_returns_cache_size( + self, mock_cache, mock_connection + ): + """Test that the cache_items_callback returns the current cache size.""" + event_dispatcher = EventDispatcher() + listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheInitialisationEvent: [listener], + }) + + proxy = CacheProxy(mock_cache, event_dispatcher) + + event = listener.listen.call_args[0][0] + assert event.cache_items_callback() == 0 + + # Add an entry and verify callback reflects new size + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert event.cache_items_callback() == 1 + + def test_initialization_creates_default_event_dispatcher_when_none_provided( + self, mock_cache + ): + """Test that CacheProxy creates a default EventDispatcher when none is provided.""" + # Should not raise an error + proxy = CacheProxy(mock_cache) + assert proxy is not None + + def test_set_emits_eviction_event_when_cache_exceeds_max_size( + self, mock_connection + ): + """Test that OnCacheEvictionEvent is emitted when cache exceeds max size.""" + # Create a cache with max_size=2 + cache = DefaultCache(CacheConfig(max_size=2)) + event_dispatcher = EventDispatcher() + eviction_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheEvictionEvent: [eviction_listener], + }) + + proxy = CacheProxy(cache, event_dispatcher) + + # Add 2 entries (at max capacity) + for i in range(2): + cache_key = CacheKey( + command="GET", redis_keys=(f"key{i}",), redis_args=("GET", f"key{i}") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=f"value{i}".encode(), + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # No eviction event yet + eviction_listener.listen.assert_not_called() + + # Add a 3rd entry, which should trigger eviction + cache_key = CacheKey( + command="GET", redis_keys=("key3",), redis_args=("GET", "key3") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"value3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + # Eviction event should be emitted + eviction_listener.listen.assert_called_once() + event = eviction_listener.listen.call_args[0][0] + assert isinstance(event, OnCacheEvictionEvent) + assert event.count == 1 
+ assert event.reason == CSCReason.FULL + + def test_set_does_not_emit_eviction_event_when_under_max_size( + self, mock_cache, mock_connection + ): + """Test that OnCacheEvictionEvent is NOT emitted when cache is under max size.""" + event_dispatcher = EventDispatcher() + eviction_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheEvictionEvent: [eviction_listener], + }) + + proxy = CacheProxy(mock_cache, event_dispatcher) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + eviction_listener.listen.assert_not_called() + + def test_collection_property_delegates_to_underlying_cache(self, mock_cache): + """Test that collection property returns the underlying cache's collection.""" + proxy = CacheProxy(mock_cache) + assert proxy.collection is mock_cache.collection + + def test_config_property_delegates_to_underlying_cache(self, mock_cache): + """Test that config property returns the underlying cache's config.""" + proxy = CacheProxy(mock_cache) + assert proxy.config is mock_cache.config + + def test_eviction_policy_property_delegates_to_underlying_cache(self, mock_cache): + """Test that eviction_policy property returns the underlying cache's eviction_policy.""" + proxy = CacheProxy(mock_cache) + assert proxy.eviction_policy is mock_cache.eviction_policy + + def test_size_property_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that size property returns the underlying cache's size.""" + proxy = CacheProxy(mock_cache) + assert proxy.size == 0 + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + assert proxy.size == 1 + + def test_get_delegates_to_underlying_cache(self, mock_cache, mock_connection): + """Test that get method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + entry = CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + proxy.set(entry) + + result = proxy.get(cache_key) + assert result is not None + assert result.cache_value == b"bar" + + def test_delete_by_cache_keys_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that delete_by_cache_keys method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + result = proxy.delete_by_cache_keys([cache_key]) + assert result == [True] + assert proxy.get(cache_key) is None + + def test_delete_by_redis_keys_delegates_to_underlying_cache( + self, mock_cache, mock_connection + ): + """Test that delete_by_redis_keys method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + cache_key = CacheKey(command="GET", redis_keys=(b"foo",), redis_args=("GET", "foo")) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + result 
= proxy.delete_by_redis_keys([b"foo"]) + assert result == [True] + assert proxy.get(cache_key) is None + + def test_flush_delegates_to_underlying_cache(self, mock_cache, mock_connection): + """Test that flush method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + for i in range(3): + cache_key = CacheKey( + command="GET", redis_keys=(f"key{i}",), redis_args=("GET", f"key{i}") + ) + proxy.set( + CacheEntry( + cache_key=cache_key, + cache_value=f"value{i}".encode(), + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + assert proxy.size == 3 + result = proxy.flush() + assert result == 3 + assert proxy.size == 0 + + def test_is_cachable_delegates_to_underlying_cache(self, mock_cache): + """Test that is_cachable method delegates to the underlying cache.""" + proxy = CacheProxy(mock_cache) + + # GET is cachable by default + cache_key = CacheKey(command="GET", redis_keys=("foo",), redis_args=()) + assert proxy.is_cachable(cache_key) is True + + # SET is not cachable + cache_key = CacheKey(command="SET", redis_keys=("foo",), redis_args=()) + assert proxy.is_cachable(cache_key) is False diff --git a/tests/test_connection.py b/tests/test_connection.py index 441528ca6b..498041fcd8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,7 +7,7 @@ from errno import ECONNREFUSED from typing import Any from unittest import mock -from unittest.mock import call, patch +from unittest.mock import call, patch, MagicMock, Mock import pytest import redis @@ -31,6 +31,12 @@ parse_url, ) from redis.credentials import UsernamePasswordCredentialProvider +from redis.event import ( + EventDispatcher, + EventListenerInterface, + OnCacheHitEvent, + OnCacheMissEvent, +) from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -446,7 +452,9 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = EventDispatcher() proxy_connection = CacheProxyConnection( mock_connection, cache, threading.RLock() @@ -463,7 +471,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = EventDispatcher() mock_cache.is_cachable.return_value = True mock_cache.get.side_effect = [ @@ -573,7 +583,9 @@ def test_triggers_invalidation_processing_on_another_connection( mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.db = 0 mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection._event_dispatcher = Mock(spec=EventDispatcher) another_conn = copy.deepcopy(mock_connection) another_conn.can_read.side_effect = [True, False] @@ -598,3 +610,171 @@ def test_triggers_invalidation_processing_on_another_connection( assert proxy_connection.read_response() == b"bar" assert another_conn.can_read.call_count == 2 another_conn.read_response.assert_called_once() + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def 
test_cache_hit_event_emitted_on_cached_response(self, mock_connection): + """Test that OnCacheHitEvent is emitted when returning a cached response.""" + cache = DefaultCache(CacheConfig(max_size=10)) + event_dispatcher = EventDispatcher() + cache_hit_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheHitEvent: [cache_hit_listener], + }) + + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + mock_connection.db = 0 + mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection.can_read.return_value = False + mock_connection._event_dispatcher = event_dispatcher + + cache_key = CacheKey( + command="GET", redis_keys=("foo",), redis_args=("GET", "foo") + ) + cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, + ) + ) + + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.RLock() + ) + # Manually set the cache key to simulate send_command having been called + proxy_connection._current_command_cache_key = cache_key + + result = proxy_connection.read_response() + + assert result == b"bar" + cache_hit_listener.listen.assert_called_once() + event = cache_hit_listener.listen.call_args[0][0] + assert isinstance(event, OnCacheHitEvent) + assert event.bytes_saved == len(b"bar") + assert event.db_namespace == 0 + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_cache_miss_event_emitted_on_uncached_response(self, mock_connection): + """Test that OnCacheMissEvent is emitted when cache miss occurs.""" + cache = DefaultCache(CacheConfig(max_size=10)) + event_dispatcher = EventDispatcher() + cache_miss_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheMissEvent: [cache_miss_listener], + }) + + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + mock_connection.db = 0 + mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection.can_read.return_value = False + mock_connection.send_command.return_value = None + mock_connection.read_response.return_value = b"bar" + mock_connection._event_dispatcher = event_dispatcher + + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.RLock() + ) + proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) + result = proxy_connection.read_response() + + assert result == b"bar" + cache_miss_listener.listen.assert_called_once() + event = cache_miss_listener.listen.call_args[0][0] + assert isinstance(event, OnCacheMissEvent) + assert event.db_namespace == 0 + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_cache_miss_event_emitted_for_non_cachable_command(self, mock_connection): + """Test that OnCacheMissEvent is emitted for non-cachable commands.""" + cache = DefaultCache(CacheConfig(max_size=10)) + event_dispatcher = EventDispatcher() + cache_miss_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheMissEvent: [cache_miss_listener], + }) + + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + mock_connection.db = 0 + mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection.can_read.return_value = False + 
mock_connection.send_command.return_value = None + mock_connection.read_response.return_value = b"OK" + mock_connection._event_dispatcher = event_dispatcher + + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.RLock() + ) + # SET is not cachable + proxy_connection.send_command(*["SET", "foo", "bar"]) + result = proxy_connection.read_response() + + assert result == b"OK" + cache_miss_listener.listen.assert_called_once() + event = cache_miss_listener.listen.call_args[0][0] + assert isinstance(event, OnCacheMissEvent) + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", + ) + def test_cache_hit_not_emitted_for_in_progress_entry(self, mock_connection): + """Test that OnCacheHitEvent is NOT emitted when cache entry is IN_PROGRESS.""" + cache = DefaultCache(CacheConfig(max_size=10)) + event_dispatcher = EventDispatcher() + cache_hit_listener = MagicMock(spec=EventListenerInterface) + cache_miss_listener = MagicMock(spec=EventListenerInterface) + event_dispatcher.register_listeners({ + OnCacheHitEvent: [cache_hit_listener], + OnCacheMissEvent: [cache_miss_listener], + }) + + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + mock_connection.db = 0 + mock_connection.credential_provider = UsernamePasswordCredentialProvider() + mock_connection.can_read.return_value = False + mock_connection.read_response.return_value = b"bar" + mock_connection._event_dispatcher = event_dispatcher + + cache_key = CacheKey( + command="GET", redis_keys=("foo",), redis_args=("GET", "foo") + ) + # Set entry with IN_PROGRESS status + cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + connection_ref=mock_connection, + ) + ) + + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.RLock() + ) + proxy_connection._current_command_cache_key = cache_key + + result = proxy_connection.read_response() + + assert result == b"bar" + # Cache hit should NOT be emitted for IN_PROGRESS entry + cache_hit_listener.listen.assert_not_called() + # Cache miss should be emitted instead + cache_miss_listener.listen.assert_called_once() From 7359dabce52b2cdd9a874a3df10e10facb0b3f9b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 14 Jan 2026 17:31:03 +0200 Subject: [PATCH 3/5] Revert changes --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index f24cbd23b8..9d2f51795a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ REDIS_INFO = {} default_redis_url = "redis://localhost:6379/0" -default_protocol = "3" +default_protocol = "2" default_redismod_url = "redis://localhost:6479" # default ssl client ignores verification for the purpose of testing From fe7c8509fdc686ecb34162a3726e3b1364b63f85 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Sun, 18 Jan 2026 13:12:10 +0200 Subject: [PATCH 4/5] Added observable gauge registry and refactored observables metric export --- redis/cache.py | 8 +- redis/connection.py | 17 +- redis/event.py | 18 +- redis/observability/metrics.py | 6 +- redis/observability/recorder.py | 96 +++-- redis/observability/registry.py | 64 ++++ tests/test_connection.py | 198 +++++++++- tests/test_observability/test_recorder.py | 429 +++++++++++++++++++++- 8 files changed, 790 insertions(+), 46 deletions(-) create mode 100644 redis/observability/registry.py diff --git a/redis/cache.py 
b/redis/cache.py index 2384c43259..70b22ea027 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, List, Optional, Union -from redis.event import EventDispatcherInterface, EventDispatcher, OnCacheInitialisationEvent, \ +from redis.event import EventDispatcherInterface, EventDispatcher, \ OnCacheEvictionEvent, OnCacheHitEvent from redis.observability.attributes import CSCResult, CSCReason @@ -260,12 +260,6 @@ def __init__(self, cache: CacheInterface, event_dispatcher: Optional[EventDispat else: self._event_dispatcher = event_dispatcher - self._event_dispatcher.dispatch( - OnCacheInitialisationEvent( - cache_items_callback=lambda: self._cache.size, - ) - ) - @property def collection(self) -> OrderedDict: return self._cache.collection diff --git a/redis/connection.py b/redis/connection.py index 157c96bcb5..524410e5e6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -37,7 +37,7 @@ from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \ AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \ - OnCacheMissEvent, OnCacheEvictionEvent + OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitialisationEvent from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -2547,6 +2547,10 @@ def __init__( self.cache = None self._cache_factory = cache_factory + self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) + if self._event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): if self._connection_kwargs.get("protocol") not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") @@ -2566,13 +2570,16 @@ def __init__( self._connection_kwargs.get("cache_config") ).get_cache() + self._event_dispatcher.dispatch( + OnCacheInitialisationEvent( + cache_items_callback=lambda: self.cache.size, + db_namespace=self._connection_kwargs.get("db"), + ) + ) + connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) - self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) - if self._event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as # after a fork. during this time, multiple threads in the child diff --git a/redis/event.py b/redis/event.py index 92edd22e80..e060287769 100644 --- a/redis/event.py +++ b/redis/event.py @@ -12,7 +12,8 @@ from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \ record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \ record_pubsub_message, record_streaming_lag, record_connection_wait_time, record_connection_use_time, \ - record_connection_closed, record_csc_request, init_csc_items, record_csc_eviction, record_csc_network_saved + record_connection_closed, record_csc_request, init_csc_items, record_csc_eviction, record_csc_network_saved, \ + register_pools_connection_count, register_csc_items_callback from redis.utils import str_if_bytes @@ -445,6 +446,7 @@ class OnCacheInitialisationEvent: Event fired after cache is initialized. 
""" cache_items_callback: Callable + db_namespace: Optional[int] = None @dataclass class OnCacheEvictionEvent: @@ -685,7 +687,11 @@ class InitializeConnectionCountObservability(EventListenerInterface): Listener that initializes connection count observability. """ def listen(self, event: AfterPooledConnectionsInstantiationEvent): - init_connection_count(event.connection_pools) + # Initialize gauge only once, subsequent calls won't have an affect. + init_connection_count() + + # Register pools for connection count observability. + register_pools_connection_count(event.connection_pools) class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface): """ @@ -802,9 +808,11 @@ class InitialiseCSCItemsObservability(EventListenerInterface): Listener that initializes CSC items observability. """ def listen(self, event: OnCacheInitialisationEvent): - init_csc_items( - callback=event.cache_items_callback, - ) + # Initialize gauge only once, subsequent calls won't have an affect. + init_csc_items() + + # Register cache items callback for CSC items observability. + register_csc_items_callback(event.cache_items_callback, event.db_namespace) class ExportCSCEvictionMetric(EventListenerInterface): """ diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index 250388f76a..74dd2168ed 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -286,7 +286,8 @@ def init_connection_count( Args: callback: Callback function to retrieve connection count """ - if not MetricGroup.CONNECTION_BASIC in self.config.metric_groups: + if not MetricGroup.CONNECTION_BASIC in self.config.metric_groups \ + and not self.connection_count: return self.connection_count = self.meter.create_observable_gauge( @@ -306,7 +307,8 @@ def init_csc_items( Args: callback: Callback function to retrieve CSC items count """ - if not MetricGroup.CSC in self.config.metric_groups: + if not MetricGroup.CSC in self.config.metric_groups \ + and not self.csc_items: return self.csc_items = self.meter.create_observable_gauge( diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index 048b5bc373..309585a69c 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -20,15 +20,19 @@ """ import time -from typing import Optional, Callable +from typing import Optional, Callable, List from redis.observability.attributes import PubSubDirection, ConnectionState, CSCResult, CSCReason, AttributeBuilder from redis.observability.metrics import RedisMetricsCollector, CloseReason from redis.observability.providers import get_observability_instance +from redis.observability.registry import get_observables_registry_instance # Global metrics collector instance (lazy-initialized) _metrics_collector: Optional[RedisMetricsCollector] = None +CONNECTION_COUNT_REGISTRY_KEY = "connection_count" +CSC_ITEMS_REGISTRY_KEY = "csc_items" + def record_operation_duration( command_name: str, @@ -125,14 +129,9 @@ def record_connection_create_time( # pass -def init_connection_count( - connection_pools: list, -) -> None: +def init_connection_count() -> None: """ Initialize observable gauge for connection count metric. - - Args: - connection_pools: Connection pools to collect metrics from. 
""" global _metrics_collector @@ -141,23 +140,48 @@ def init_connection_count( if _metrics_collector is None: return - # Lazy import - from opentelemetry.metrics import Observation - - def connection_count_callback(__): + def observable_callback(__): + observables_registry = get_observables_registry_instance() + callbacks = observables_registry.get(CONNECTION_COUNT_REGISTRY_KEY) observations = [] - for pool in connection_pools: - for count, attributes in pool.get_connection_count(): - observations.append(Observation(count, attributes)) + + for callback in callbacks: + observations.extend(callback()) + return observations # try: _metrics_collector.init_connection_count( - callback=connection_count_callback, + callback=observable_callback, ) # except Exception: # pass +def register_pools_connection_count( + connection_pools: List["ConnectionPoolInterface"], +) -> None: + """ + Add connection pools to connection count observable registry. + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # Lazy import + from opentelemetry.metrics import Observation + + def connection_count_callback(): + observations = [] + for connection_pool in connection_pools: + for count, attributes in connection_pool.get_connection_count(): + observations.append(Observation(count, attributes=attributes)) + return observations + + observables_registry = get_observables_registry_instance() + observables_registry.register(CONNECTION_COUNT_REGISTRY_KEY, connection_count_callback) def record_connection_timeout( pool_name: str, @@ -509,14 +533,37 @@ def record_csc_request( result=result, ) -def init_csc_items( - callback: Callable -) -> None: +def init_csc_items() -> None: """ Initialize observable gauge for CSC items metric. + """ + global _metrics_collector - Args: - callback: Callback function to retrieve CSC items count + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + def observable_callback(__): + observables_registry = get_observables_registry_instance() + callbacks = observables_registry.get(CSC_ITEMS_REGISTRY_KEY) + observations = [] + + for callback in callbacks: + observations.extend(callback()) + + return observations + + _metrics_collector.init_csc_items( + callback=observable_callback, + ) + +def register_csc_items_callback( + callback: Callable, + db_namespace: Optional[int] = None, +) -> None: + """ + Adds given callback to CSC items observable registry. 
""" global _metrics_collector @@ -528,12 +575,11 @@ def init_csc_items( # Lazy import from opentelemetry.metrics import Observation - def observation_wrapper(__): - return [Observation(callback(), attributes=AttributeBuilder.build_csc_attributes())] + def csc_items_callback(): + return [Observation(callback(), attributes=AttributeBuilder.build_csc_attributes(db_namespace=db_namespace))] - _metrics_collector.init_csc_items( - callback=observation_wrapper, - ) + observables_registry = get_observables_registry_instance() + observables_registry.register(CSC_ITEMS_REGISTRY_KEY, csc_items_callback) def record_csc_eviction( count: int, diff --git a/redis/observability/registry.py b/redis/observability/registry.py new file mode 100644 index 0000000000..4bdc56ff54 --- /dev/null +++ b/redis/observability/registry.py @@ -0,0 +1,64 @@ +import threading +from typing import Dict, List, Callable, Optional, Any + +from opentelemetry.metrics import Observation + + +class ObservablesRegistry: + """ + Global registry for storing callbacks for observable metrics. + """ + def __init__(self, registry: Dict[str, List[Callable[[], List[Observation]]]] = None): + self._registry = registry or {} + self._lock = threading.Lock() + + def register(self, name: str, callback: Callable[[], List[Observation]]) -> None: + """ + Register a callback for an observable metric. + """ + with self._lock: + self._registry.setdefault(name, []).append(callback) + + def get(self, name: str) -> List[Callable[[], List[Observation]]]: + """ + Get all callbacks for an observable metric. + """ + with self._lock: + return self._registry.get(name, []) + + def clear(self) -> None: + """ + Clear the registry. + """ + with self._lock: + self._registry.clear() + + def __len__(self) -> int: + """ + Get the number of registered callbacks. + """ + return len(self._registry) + +# Global singleton instance +_observables_registry_instance: Optional[ObservablesRegistry] = None + +def get_observables_registry_instance() -> ObservablesRegistry: + """ + Get the global observables registry singleton instance. + + This is the Pythonic way to get the singleton instance. 
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 498041fcd8..67f96f00be 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -28,7 +28,7 @@
     Connection,
     SSLConnection,
     UnixDomainSocketConnection,
-    parse_url,
+    parse_url, BlockingConnectionPool,
 )
 from redis.credentials import UsernamePasswordCredentialProvider
 from redis.event import (
@@ -38,6 +38,7 @@
     OnCacheMissEvent,
 )
 from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError
+from redis.observability.attributes import DB_CLIENT_CONNECTION_POOL_NAME, DB_CLIENT_CONNECTION_STATE, ConnectionState
 from redis.retry import Retry
 from redis.utils import HIREDIS_AVAILABLE
 
@@ -778,3 +779,198 @@ def test_cache_hit_not_emitted_for_in_progress_entry(self, mock_connection):
         cache_hit_listener.listen.assert_not_called()
         # Cache miss should be emitted instead
         cache_miss_listener.listen.assert_called_once()
+
+
+class TestConnectionPoolGetConnectionCount:
+    """Tests for ConnectionPool.get_connection_count() method."""
+
+    def test_get_connection_count_returns_idle_and_used_counts(self):
+        """Test that get_connection_count returns both idle and used connection counts."""
+        pool = ConnectionPool(max_connections=10)
+
+        # Initially, no connections exist
+        counts = pool.get_connection_count()
+        assert len(counts) == 2
+
+        # Check idle connections count
+        idle_count, idle_attrs = counts[0]
+        assert idle_count == 0
+        assert DB_CLIENT_CONNECTION_POOL_NAME in idle_attrs
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+
+        # Check used connections count
+        used_count, used_attrs = counts[1]
+        assert used_count == 0
+        assert DB_CLIENT_CONNECTION_POOL_NAME in used_attrs
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_with_connections_in_use(self):
+        """Test get_connection_count when connections are in use."""
+
+        pool = ConnectionPool(max_connections=10)
+
+        # Create mock connections
+        mock_conn1 = MagicMock()
+        mock_conn1.pid = pool.pid
+
+        mock_conn2 = MagicMock()
+        mock_conn2.pid = pool.pid
+
+        # Simulate connections in use
+        pool._in_use_connections.add(mock_conn1)
+        pool._in_use_connections.add(mock_conn2)
+
+        counts = pool.get_connection_count()
+
+        idle_count, idle_attrs = counts[0]
+        used_count, used_attrs = counts[1]
+
+        assert idle_count == 0
+        assert used_count == 2
+        assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value
+        assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value
+
+        pool.disconnect()
+
+    def test_get_connection_count_with_available_connections(self):
+        """Test get_connection_count when connections are available (idle)."""
+
+        pool = ConnectionPool(max_connections=10)
+
+        # Create mock connections
+        mock_conn1 = MagicMock()
+        mock_conn1.pid = pool.pid
+
+        mock_conn2 = MagicMock()
+        mock_conn2.pid = pool.pid
+
+        mock_conn3 = MagicMock()
+        mock_conn3.pid = pool.pid
+
+        # Simulate available connections
+        pool._available_connections.append(mock_conn1)
+        pool._available_connections.append(mock_conn2)
+
pool._available_connections.append(mock_conn3) + + counts = pool.get_connection_count() + + idle_count, idle_attrs = counts[0] + used_count, used_attrs = counts[1] + + assert idle_count == 3 + assert used_count == 0 + assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value + assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value + + pool.disconnect() + + def test_get_connection_count_mixed_connections(self): + """Test get_connection_count with both idle and used connections.""" + + pool = ConnectionPool(max_connections=10) + + # Create mock connections + mock_idle = MagicMock() + mock_idle.pid = pool.pid + + mock_used1 = MagicMock() + mock_used1.pid = pool.pid + + mock_used2 = MagicMock() + mock_used2.pid = pool.pid + + # Simulate mixed state + pool._available_connections.append(mock_idle) + pool._in_use_connections.add(mock_used1) + pool._in_use_connections.add(mock_used2) + + counts = pool.get_connection_count() + + idle_count, _ = counts[0] + used_count, _ = counts[1] + + assert idle_count == 1 + assert used_count == 2 + + pool.disconnect() + + def test_get_connection_count_includes_pool_name_in_attributes(self): + """Test that get_connection_count includes pool name in attributes.""" + + pool = ConnectionPool(max_connections=10) + + counts = pool.get_connection_count() + + idle_count, idle_attrs = counts[0] + used_count, used_attrs = counts[1] + + # Both should have the pool name + assert DB_CLIENT_CONNECTION_POOL_NAME in idle_attrs + assert DB_CLIENT_CONNECTION_POOL_NAME in used_attrs + + # Pool name should be the repr of the pool + assert repr(pool) in idle_attrs[DB_CLIENT_CONNECTION_POOL_NAME] + assert repr(pool) in used_attrs[DB_CLIENT_CONNECTION_POOL_NAME] + + pool.disconnect() + + +class TestBlockingConnectionPoolGetConnectionCount: + """Tests for BlockingConnectionPool.get_connection_count() method.""" + + def test_get_connection_count_returns_idle_and_used_counts(self): + """Test that BlockingConnectionPool.get_connection_count returns both counts.""" + + pool = BlockingConnectionPool(max_connections=10) + + # Initially, no connections exist + counts = pool.get_connection_count() + assert len(counts) == 2 + + idle_count, idle_attrs = counts[0] + used_count, used_attrs = counts[1] + + assert idle_count == 0 + assert used_count == 0 + assert idle_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value + assert used_attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value + + pool.disconnect() + + def test_get_connection_count_with_connections_in_queue(self): + """Test get_connection_count when connections are in the queue (idle).""" + + pool = BlockingConnectionPool(max_connections=10) + + # Create mock connections and add to queue + mock_conn1 = MagicMock() + mock_conn1.pid = pool.pid + + mock_conn2 = MagicMock() + mock_conn2.pid = pool.pid + + # Add connections to the pool's internal list and queue + pool._connections.append(mock_conn1) + pool._connections.append(mock_conn2) + + # Clear the queue and add our connections + while not pool.pool.empty(): + try: + pool.pool.get_nowait() + except Exception: + break + + pool.pool.put_nowait(mock_conn1) + pool.pool.put_nowait(mock_conn2) + + counts = pool.get_connection_count() + + idle_count, _ = counts[0] + used_count, _ = counts[1] + + assert idle_count == 2 + assert used_count == 0 + + pool.disconnect() diff --git a/tests/test_observability/test_recorder.py b/tests/test_observability/test_recorder.py index 09127de49e..54401bc421 100644 --- 
a/tests/test_observability/test_recorder.py +++ b/tests/test_observability/test_recorder.py @@ -8,9 +8,10 @@ """ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, call from redis.observability import recorder +from redis.observability.registry import ObservablesRegistry, get_observables_registry_instance from redis.observability.attributes import ( ConnectionState, PubSubDirection, @@ -998,3 +999,429 @@ def test_enabled_group_receives_meter_calls_disabled_group_does_not(self): instruments.client_errors.add.assert_not_called() instruments.maintenance_notifications.add.assert_not_called() instruments.stream_lag.record.assert_not_called() + + +class TestObservablesRegistry: + """Tests for ObservablesRegistry singleton and callback registration.""" + + def test_registry_singleton_returns_same_instance(self): + """Test that get_observables_registry_instance returns the same instance.""" + registry1 = get_observables_registry_instance() + registry2 = get_observables_registry_instance() + assert registry1 is registry2 + + def test_registry_register_and_get_callbacks(self): + """Test registering and retrieving callbacks from the registry.""" + registry = ObservablesRegistry() + + callback1 = MagicMock(return_value=[]) + callback2 = MagicMock(return_value=[]) + + registry.register("test_metric", callback1) + registry.register("test_metric", callback2) + + callbacks = registry.get("test_metric") + assert len(callbacks) == 2 + assert callback1 in callbacks + assert callback2 in callbacks + + def test_registry_get_returns_empty_list_for_unknown_key(self): + """Test that get returns empty list for unknown metric key.""" + registry = ObservablesRegistry() + callbacks = registry.get("unknown_metric") + assert callbacks == [] + + def test_registry_clear_removes_all_callbacks(self): + """Test that clear removes all registered callbacks.""" + registry = ObservablesRegistry() + + registry.register("metric1", MagicMock()) + registry.register("metric2", MagicMock()) + + registry.clear() + + assert registry.get("metric1") == [] + assert registry.get("metric2") == [] + assert len(registry) == 0 + + +class TestInitConnectionCount: + """Tests for init_connection_count and register_pools_connection_count.""" + + @pytest.fixture + def mock_observable_gauge(self): + """Create a mock observable gauge.""" + return MagicMock() + + @pytest.fixture + def mock_meter_with_observable(self, mock_observable_gauge): + """Create a mock meter that returns our mock observable gauge.""" + meter = MagicMock() + meter.create_observable_gauge.return_value = mock_observable_gauge + return meter + + @pytest.fixture + def mock_config_with_connection_basic(self): + """Create a config with CONNECTION_BASIC enabled.""" + return OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC]) + + @pytest.fixture + def setup_connection_count_recorder( + self, mock_meter_with_observable, mock_config_with_connection_basic + ): + """Setup recorder with mocked meter for connection count tests.""" + recorder.reset_collector() + get_observables_registry_instance().clear() + + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector( + mock_meter_with_observable, mock_config_with_connection_basic + ) + + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + yield mock_meter_with_observable + + recorder.reset_collector() + get_observables_registry_instance().clear() + + def test_init_connection_count_creates_observable_gauge( + self, 
setup_connection_count_recorder + ): + """Test that init_connection_count creates an observable gauge.""" + mock_meter = setup_connection_count_recorder + + recorder.init_connection_count() + + mock_meter.create_observable_gauge.assert_called_once() + call_kwargs = mock_meter.create_observable_gauge.call_args[1] + assert call_kwargs['name'] == 'db.client.connection.count' + + def test_init_connection_count_callback_aggregates_registry_callbacks( + self, setup_connection_count_recorder + ): + """Test that the observable callback aggregates all registered pool callbacks.""" + mock_meter = setup_connection_count_recorder + + recorder.init_connection_count() + + # Get the callback that was passed to create_observable_gauge + call_args = mock_meter.create_observable_gauge.call_args + observable_callback = call_args[1]['callbacks'][0] + + # Register some mock pool callbacks + from opentelemetry.metrics import Observation + + mock_observation1 = Observation(5, attributes={"pool": "pool1"}) + mock_observation2 = Observation(3, attributes={"pool": "pool2"}) + + pool_callback1 = MagicMock(return_value=[mock_observation1]) + pool_callback2 = MagicMock(return_value=[mock_observation2]) + + registry = get_observables_registry_instance() + registry.register(recorder.CONNECTION_COUNT_REGISTRY_KEY, pool_callback1) + registry.register(recorder.CONNECTION_COUNT_REGISTRY_KEY, pool_callback2) + + # Call the observable callback + observations = observable_callback(None) + + # Verify both pool callbacks were called and observations aggregated + pool_callback1.assert_called_once() + pool_callback2.assert_called_once() + assert len(observations) == 2 + assert mock_observation1 in observations + assert mock_observation2 in observations + + def test_register_pools_connection_count_adds_callback_to_registry( + self, setup_connection_count_recorder + ): + """Test that register_pools_connection_count adds a callback to the registry.""" + # Create mock connection pools + mock_pool1 = MagicMock() + mock_pool1.get_connection_count.return_value = [ + (5, {"state": "idle"}), + (2, {"state": "used"}), + ] + + mock_pool2 = MagicMock() + mock_pool2.get_connection_count.return_value = [ + (3, {"state": "idle"}), + ] + + recorder.register_pools_connection_count([mock_pool1, mock_pool2]) + + registry = get_observables_registry_instance() + callbacks = registry.get(recorder.CONNECTION_COUNT_REGISTRY_KEY) + + assert len(callbacks) == 1 + + # Call the registered callback and verify it returns observations + observations = callbacks[0]() + + assert len(observations) == 3 + mock_pool1.get_connection_count.assert_called_once() + mock_pool2.get_connection_count.assert_called_once() + + +class TestInitCSCItems: + """Tests for init_csc_items and register_csc_items_callback.""" + + @pytest.fixture + def mock_observable_gauge(self): + """Create a mock observable gauge.""" + return MagicMock() + + @pytest.fixture + def mock_meter_with_observable(self, mock_observable_gauge): + """Create a mock meter that returns our mock observable gauge.""" + meter = MagicMock() + meter.create_observable_gauge.return_value = mock_observable_gauge + return meter + + @pytest.fixture + def mock_config_with_csc(self): + """Create a config with CSC metric group enabled.""" + return OTelConfig(metric_groups=[MetricGroup.CSC]) + + @pytest.fixture + def setup_csc_recorder(self, mock_meter_with_observable, mock_config_with_csc): + """Setup recorder with mocked meter for CSC tests.""" + recorder.reset_collector() + get_observables_registry_instance().clear() + + 
with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector( + mock_meter_with_observable, mock_config_with_csc + ) + + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + yield mock_meter_with_observable + + recorder.reset_collector() + get_observables_registry_instance().clear() + + def test_init_csc_items_creates_observable_gauge(self, setup_csc_recorder): + """Test that init_csc_items creates an observable gauge.""" + mock_meter = setup_csc_recorder + + recorder.init_csc_items() + + mock_meter.create_observable_gauge.assert_called_once() + call_args = mock_meter.create_observable_gauge.call_args + assert call_args[1]['name'] == 'redis.client.csc.items' + + def test_init_csc_items_callback_aggregates_registry_callbacks( + self, setup_csc_recorder + ): + """Test that the CSC observable callback aggregates all registered callbacks.""" + mock_meter = setup_csc_recorder + + recorder.init_csc_items() + + # Get the callback that was passed to create_observable_gauge + call_args = mock_meter.create_observable_gauge.call_args + observable_callback = call_args[1]['callbacks'][0] + + # Register some mock CSC callbacks + from opentelemetry.metrics import Observation + + mock_observation1 = Observation(100, attributes={"db": 0}) + mock_observation2 = Observation(50, attributes={"db": 1}) + + csc_callback1 = MagicMock(return_value=[mock_observation1]) + csc_callback2 = MagicMock(return_value=[mock_observation2]) + + registry = get_observables_registry_instance() + registry.register(recorder.CSC_ITEMS_REGISTRY_KEY, csc_callback1) + registry.register(recorder.CSC_ITEMS_REGISTRY_KEY, csc_callback2) + + # Call the observable callback + observations = observable_callback(None) + + # Verify both callbacks were called and observations aggregated + csc_callback1.assert_called_once() + csc_callback2.assert_called_once() + assert len(observations) == 2 + assert mock_observation1 in observations + assert mock_observation2 in observations + + def test_register_csc_items_callback_adds_callback_to_registry( + self, setup_csc_recorder + ): + """Test that register_csc_items_callback adds a callback to the registry.""" + # Create a mock cache size callback + cache_size_callback = MagicMock(return_value=42) + + recorder.register_csc_items_callback(cache_size_callback, db_namespace=0) + + registry = get_observables_registry_instance() + callbacks = registry.get(recorder.CSC_ITEMS_REGISTRY_KEY) + + assert len(callbacks) == 1 + + # Call the registered callback and verify it returns an observation + observations = callbacks[0]() + + assert len(observations) == 1 + assert observations[0].value == 42 + cache_size_callback.assert_called_once() + + def test_register_csc_items_callback_multiple_registrations( + self, setup_csc_recorder + ): + """Test registering multiple CSC callbacks.""" + callback1 = MagicMock(return_value=10) + callback2 = MagicMock(return_value=20) + + recorder.register_csc_items_callback(callback1, db_namespace=0) + recorder.register_csc_items_callback(callback2, db_namespace=1) + + registry = get_observables_registry_instance() + callbacks = registry.get(recorder.CSC_ITEMS_REGISTRY_KEY) + + assert len(callbacks) == 2 + + # Verify each callback returns correct observation + obs1 = callbacks[0]() + obs2 = callbacks[1]() + + assert obs1[0].value == 10 + assert obs2[0].value == 20 + + +class TestObservableGaugeIntegration: + """Integration tests for observable gauge pattern with registry.""" + + @pytest.fixture + def 
clean_registry(self): + """Ensure clean registry before and after test.""" + get_observables_registry_instance().clear() + yield + get_observables_registry_instance().clear() + + def test_full_observable_gauge_flow(self, clean_registry): + """Test the complete flow: init -> register -> callback invocation.""" + from opentelemetry.metrics import Observation + + # Create mock meter and collector + mock_meter = MagicMock() + captured_callback = None + + def capture_callback(name, **kwargs): + nonlocal captured_callback + captured_callback = kwargs.get('callbacks', [None])[0] + return MagicMock() + + mock_meter.create_observable_gauge.side_effect = capture_callback + mock_meter.create_counter.return_value = MagicMock() + mock_meter.create_histogram.return_value = MagicMock() + mock_meter.create_up_down_counter.return_value = MagicMock() + + config = OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC]) + + recorder.reset_collector() + + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + # Step 1: Initialize the observable gauge + recorder.init_connection_count() + + # Step 2: Register pool callbacks + mock_pool = MagicMock() + mock_pool.get_connection_count.return_value = [ + (5, {"state": "idle", "pool": "pool1"}), + ] + recorder.register_pools_connection_count([mock_pool]) + + # Step 3: Simulate OTel calling the observable callback + assert captured_callback is not None + observations = captured_callback(None) + + # Verify the observation was created correctly + assert len(observations) == 1 + assert observations[0].value == 5 + + recorder.reset_collector() + + def test_observable_callback_handles_empty_registry(self, clean_registry): + """Test that observable callback handles empty registry gracefully.""" + mock_meter = MagicMock() + captured_callback = None + + def capture_callback(name, **kwargs): + nonlocal captured_callback + captured_callback = kwargs.get('callbacks', [None])[0] + return MagicMock() + + mock_meter.create_observable_gauge.side_effect = capture_callback + mock_meter.create_counter.return_value = MagicMock() + mock_meter.create_histogram.return_value = MagicMock() + mock_meter.create_up_down_counter.return_value = MagicMock() + + config = OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC]) + + recorder.reset_collector() + + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + recorder.init_connection_count() + + # Don't register any pools - registry is empty + assert captured_callback is not None + observations = captured_callback(None) + + # Should return empty list, not raise an error + assert observations == [] + + recorder.reset_collector() + + def test_multiple_pools_aggregated_correctly(self, clean_registry): + """Test that observations from multiple pools are aggregated correctly.""" + mock_meter = MagicMock() + captured_callback = None + + def capture_callback(name, **kwargs): + nonlocal captured_callback + captured_callback = kwargs.get('callbacks', [None])[0] + return MagicMock() + + mock_meter.create_observable_gauge.side_effect = capture_callback + mock_meter.create_counter.return_value = MagicMock() + mock_meter.create_histogram.return_value = MagicMock() + mock_meter.create_up_down_counter.return_value = MagicMock() + + config = 
OTelConfig(metric_groups=[MetricGroup.CONNECTION_BASIC])
+
+        recorder.reset_collector()
+
+        with patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            recorder.init_connection_count()
+
+            # Register multiple pools in separate calls
+            mock_pool1 = MagicMock()
+            mock_pool1.get_connection_count.return_value = [
+                (5, {"state": "idle"}),
+                (2, {"state": "used"}),
+            ]
+
+            mock_pool2 = MagicMock()
+            mock_pool2.get_connection_count.return_value = [
+                (3, {"state": "idle"}),
+            ]
+
+            recorder.register_pools_connection_count([mock_pool1])
+            recorder.register_pools_connection_count([mock_pool2])
+
+            # Simulate OTel calling the observable callback
+            observations = captured_callback(None)
+
+            # Should have 3 observations total (2 from pool1, 1 from pool2)
+            assert len(observations) == 3
+
+        recorder.reset_collector()

From 09b9abb4642f220539b046750babe348055f7611 Mon Sep 17 00:00:00 2001
From: vladvildanov
Date: Sun, 18 Jan 2026 13:35:27 +0200
Subject: [PATCH 5/5] Fixed cache-miss tracking for non-cachable commands

---
 redis/connection.py               | 56 ++++++++++++++++----------------
 redis/event.py                    | 10 +++---
 redis/observability/attributes.py |  2 +-
 tests/test_cache.py               |  8 ++---
 tests/test_connection.py          |  6 ++--
 5 files changed, 42 insertions(+), 40 deletions(-)
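The main behavioural change below: CacheProxyConnection.read_response() now dispatches cache hit/miss events only when send_command() actually produced a cache key, so non-cachable commands (e.g. SET) no longer inflate the miss counter. A condensed, simplified sketch of the new control flow (not the real method signature):

    def read_response_sketch(cache_key, cache_lookup, dispatch, read_from_socket):
        # cache_key is None for non-cachable commands: no hit/miss event at all.
        if cache_key is not None:
            entry = cache_lookup(cache_key)
            if entry is not None and entry["status"] != "IN_PROGRESS":
                dispatch("cache_hit")
                return entry["value"]
            dispatch("cache_miss")
        return read_from_socket()

    events = []
    assert read_response_sketch(None, lambda k: None, events.append, lambda: b"OK") == b"OK"
    assert events == []  # SET-style command: neither hit nor miss recorded

    hit = {"status": "VALID", "value": b"bar"}
    assert read_response_sketch("GET foo", lambda k: hit, events.append, lambda: b"") == b"bar"
    assert events == ["cache_hit"]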
diff --git a/redis/connection.py b/redis/connection.py
index 524410e5e6..49c2f0a12d 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -37,7 +37,7 @@
 from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
 from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
     AfterConnectionCreatedEvent, AfterConnectionAcquiredEvent, AfterConnectionClosedEvent, OnCacheHitEvent, \
-    OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitialisationEvent
+    OnCacheMissEvent, OnCacheEvictionEvent, OnCacheInitializationEvent
 from .exceptions import (
     AuthenticationError,
     AuthenticationWrongNumberOfArgsError,
@@ -1544,29 +1544,28 @@ def read_response(
     ):
         with self._cache_lock:
             # Check if command response exists in a cache and it's not in progress.
-            if (
-                self._current_command_cache_key is not None
-                and self._cache.get(self._current_command_cache_key) is not None
-                and self._cache.get(self._current_command_cache_key).status
-                != CacheEntryStatus.IN_PROGRESS
-            ):
-                res = copy.deepcopy(
-                    self._cache.get(self._current_command_cache_key).cache_value
-                )
-                self._current_command_cache_key = None
+            if self._current_command_cache_key is not None:
+                if (
+                    self._cache.get(self._current_command_cache_key) is not None
+                    and self._cache.get(self._current_command_cache_key).status
+                    != CacheEntryStatus.IN_PROGRESS
+                ):
+                    res = copy.deepcopy(
+                        self._cache.get(self._current_command_cache_key).cache_value
+                    )
+                    self._current_command_cache_key = None
+                    self._event_dispatcher.dispatch(
+                        OnCacheHitEvent(
+                            bytes_saved=len(res),
+                            db_namespace=self.db,
+                        )
+                    )
+                    return res
                 self._event_dispatcher.dispatch(
-                    OnCacheHitEvent(
-                        bytes_saved=len(res),
+                    OnCacheMissEvent(
                         db_namespace=self.db,
                     )
                 )
-                return res
-
-            self._event_dispatcher.dispatch(
-                OnCacheMissEvent(
-                    db_namespace=self.db,
-                )
-            )
 
         response = self._conn.read_response(
             disable_decoding=disable_decoding,
@@ -1730,13 +1729,16 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]
         if data[1] is None:
             self._cache.flush()
         else:
-            self._cache.delete_by_redis_keys(data[1])
-            self._event_dispatcher.dispatch(
-                OnCacheEvictionEvent(
-                    count=len(data[1]),
-                    reason=CSCReason.INVALIDATION,
+            keys_deleted = self._cache.delete_by_redis_keys(data[1])
+
+            if len(keys_deleted) > 0:
+                self._event_dispatcher.dispatch(
+                    OnCacheEvictionEvent(
+                        count=len(data[1]),
+                        reason=CSCReason.INVALIDATION,
+                        db_namespace=self.db,
+                    )
                 )
-            )
 
 
 class SSLConnection(Connection):
@@ -2571,7 +2573,7 @@ def __init__(
         ).get_cache()
 
         self._event_dispatcher.dispatch(
-            OnCacheInitialisationEvent(
+            OnCacheInitializationEvent(
                 cache_items_callback=lambda: self.cache.size,
                 db_namespace=self._connection_kwargs.get("db"),
             )
diff --git a/redis/event.py b/redis/event.py
index e060287769..3310efe425 100644
--- a/redis/event.py
+++ b/redis/event.py
@@ -127,7 +127,7 @@
             OnMaintenanceNotificationEvent: [
                 ExportMaintenanceNotificationCountMetric(),
             ],
-            OnCacheInitialisationEvent: [InitialiseCSCItemsObservability()],
+            OnCacheInitializationEvent: [InitializeCSCItemsObservability()],
             OnCacheEvictionEvent: [ExportCSCEvictionMetric()],
             OnCacheHitEvent: [ExportCSCNetworkSavedMetric(), ExportCSCRequestMetric()],
             OnCacheMissEvent: [ExportCSCRequestMetric()],
@@ -441,7 +441,7 @@ class OnCacheMissEvent:
     db_namespace: Optional[int] = None
 
 @dataclass
-class OnCacheInitialisationEvent:
+class OnCacheInitializationEvent:
     """
     Event fired after cache is initialized.
     """
@@ -455,6 +455,7 @@ class OnCacheEvictionEvent:
     """
     count: int
     reason: CSCReason
+    db_namespace: Optional[int] = None
 
 class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
     pass
@@ -803,11 +804,11 @@ def listen(self, event: Union[OnCacheHitEvent, OnCacheMissEvent]):
             result=result,
         )
 
-class InitialiseCSCItemsObservability(EventListenerInterface):
+class InitializeCSCItemsObservability(EventListenerInterface):
     """
     Listener that initializes CSC items observability.
     """
-    def listen(self, event: OnCacheInitialisationEvent):
+    def listen(self, event: OnCacheInitializationEvent):
         # Initialize gauge only once; subsequent calls won't have an effect.
         init_csc_items()
@@ -822,6 +823,7 @@ def listen(self, event: OnCacheEvictionEvent):
         record_csc_eviction(
             count=event.count,
             reason=event.reason,
+            db_namespace=event.db_namespace,
         )
 
 class ExportCSCNetworkSavedMetric(EventListenerInterface):
diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py
index 2e62a0ea10..5bd392d86e 100644
--- a/redis/observability/attributes.py
+++ b/redis/observability/attributes.py
@@ -296,7 +296,7 @@ def build_csc_attributes(
     Args:
         db_namespace: Redis database index
         result: CSC result ('hit' or 'miss')
-        reason: Reason for CSC miss ('full' or 'invalidation')
+        reason: Reason for CSC eviction ('full' or 'invalidation')
 
     Returns:
         Dictionary of CSC attributes
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 60d6650363..77e83c0962 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -19,7 +19,7 @@
     EventDispatcher,
     EventListenerInterface,
     OnCacheEvictionEvent,
-    OnCacheInitialisationEvent,
+    OnCacheInitializationEvent,
 )
 from redis.observability.attributes import CSCReason
 from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt
@@ -1534,14 +1534,14 @@ def test_initialization_emits_cache_initialisation_event(self, mock_cache):
         event_dispatcher = EventDispatcher()
         listener = MagicMock(spec=EventListenerInterface)
         event_dispatcher.register_listeners({
-            OnCacheInitialisationEvent: [listener],
+            OnCacheInitializationEvent: [listener],
         })
 
         CacheProxy(mock_cache, event_dispatcher)
 
         listener.listen.assert_called_once()
         event = listener.listen.call_args[0][0]
-        assert isinstance(event, OnCacheInitialisationEvent)
+        assert isinstance(event, OnCacheInitializationEvent)
         assert callable(event.cache_items_callback)
 
     def test_initialization_event_callback_returns_cache_size(
@@ -1551,7 +1551,7 @@ def test_initialization_event_callback_returns_cache_size(
         event_dispatcher = EventDispatcher()
         listener = MagicMock(spec=EventListenerInterface)
         event_dispatcher.register_listeners({
-            OnCacheInitialisationEvent: [listener],
+            OnCacheInitializationEvent: [listener],
         })
 
         proxy = CacheProxy(mock_cache, event_dispatcher)
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 67f96f00be..735ef19d27 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -699,7 +699,7 @@ def test_cache_miss_event_emitted_on_uncached_response(self, mock_connection):
         platform.python_implementation() == "PyPy",
         reason="Pypy doesn't support side_effect",
     )
-    def test_cache_miss_event_emitted_for_non_cachable_command(self, mock_connection):
-        """Test that OnCacheMissEvent is emitted for non-cachable commands."""
+    def test_cache_miss_event_not_emitted_for_non_cachable_command(self, mock_connection):
+        """Test that OnCacheMissEvent is NOT emitted for non-cachable commands."""
         cache = DefaultCache(CacheConfig(max_size=10))
         event_dispatcher = EventDispatcher()
@@ -726,9 +726,7 @@ def test_cache_miss_event_emitted_for_non_cachable_command(self, mock_connection
         result = proxy_connection.read_response()
 
         assert result == b"OK"
-        cache_miss_listener.listen.assert_called_once()
-        event = cache_miss_listener.listen.call_args[0][0]
-        assert isinstance(event, OnCacheMissEvent)
+        cache_miss_listener.listen.assert_not_called()
 
     @pytest.mark.skipif(
         platform.python_implementation() == "PyPy",