Skip to content
6 changes: 4 additions & 2 deletions dogpile/cache/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def get_serialized_multi(
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.

:return: list of bytes objects
:return: list of bytes objects or the :data:`.NO_VALUE` contant
if not present.

The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get_multi` method.
Expand Down Expand Up @@ -543,7 +544,8 @@ def get_serialized_multi(
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.

:return: list of bytes objects
:return: list of bytes objects or the :data:`.NO_VALUE`
constant if not present.

.. versionadded:: 1.1

Expand Down
1 change: 0 additions & 1 deletion dogpile/cache/backends/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ..api import NO_VALUE
from ... import util


if typing.TYPE_CHECKING:
import bmemcached
import memcache
Expand Down
101 changes: 75 additions & 26 deletions dogpile/cache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,28 @@

"""

import typing
from __future__ import annotations

from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypedDict
import warnings

from typing_extensions import NotRequired

from ..api import BytesBackend
from ..api import KeyType
from ..api import NO_VALUE
from ..api import SerializedReturnType

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
import redis
else:
# delayed import
Expand All @@ -21,6 +36,41 @@
__all__ = ("RedisBackend", "RedisSentinelBackend", "RedisClusterBackend")


class RedisKwargs(TypedDict):
"""
TypedDict of kwargs for `RedisBackend` and derived classes
.. versionadded:: 1.4.1
"""

url: NotRequired[str]
host: NotRequired[str]
username: NotRequired[Optional[str]]
password: NotRequired[Optional[str]]
port: NotRequired[int]
db: NotRequired[int]
redis_expiration_time: NotRequired[int]
distributed_lock: NotRequired[bool]
lock_timeout: NotRequired[int]
socket_timeout: NotRequired[float]
socket_connect_timeout: NotRequired[float]
socket_keepalive: NotRequired[bool]
socket_keepalive_options: NotRequired[Dict]
lock_sleep: NotRequired[int]
connection_pool: NotRequired["redis.ConnectionPool"]
thread_local_lock: NotRequired[bool]
connection_kwargs: NotRequired[Dict[str, Any]]


class RedisKwargs_Sentinel(RedisKwargs):
sentinels: List[Tuple[str, str]]
service_name: NotRequired[str]
sentinel_kwargs: NotRequired[Dict[str, Any]]


class RedisKwargs_Cluster(RedisKwargs):
startup_nodes: List["redis.cluster.ClusterNode"]


class RedisBackend(BytesBackend):
r"""A `Redis <http://redis.io/>`_ backend, using the
`redis-py <http://pypi.python.org/pypi/redis/>`_ driver.
Expand Down Expand Up @@ -114,12 +164,9 @@ class RedisBackend(BytesBackend):

.. versionadded:: 1.1.6




"""

def __init__(self, arguments):
def __init__(self, arguments: RedisKwargs):
arguments = arguments.copy()
self._imports()
self.url = arguments.pop("url", None)
Expand Down Expand Up @@ -152,12 +199,12 @@ def __init__(self, arguments):
self.connection_pool = arguments.pop("connection_pool", None)
self._create_client()

def _imports(self):
def _imports(self) -> None:
# defer imports until backend is used
global redis
import redis # noqa

def _create_client(self):
def _create_client(self) -> None:
if self.connection_pool is not None:
# the connection pool already has all other connection
# options present within, so here we disregard socket_timeout
Expand Down Expand Up @@ -195,7 +242,7 @@ def _create_client(self):
self.writer_client = redis.StrictRedis(**args)
self.reader_client = self.writer_client

def get_mutex(self, key):
def get_mutex(self, key: KeyType) -> Optional[_RedisLockWrapper]:
if self.distributed_lock:
return _RedisLockWrapper(
self.writer_client.lock(
Expand All @@ -208,25 +255,27 @@ def get_mutex(self, key):
else:
return None

def get_serialized(self, key):
def get_serialized(self, key: KeyType) -> SerializedReturnType:
value = self.reader_client.get(key)
if value is None:
return NO_VALUE
return value
return cast(SerializedReturnType, value)

def get_serialized_multi(self, keys):
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]:
if not keys:
return []
values = self.reader_client.mget(keys)
return [v if v is not None else NO_VALUE for v in values]

def set_serialized(self, key, value):
def set_serialized(self, key: KeyType, value: bytes) -> None:
if self.redis_expiration_time:
self.writer_client.setex(key, self.redis_expiration_time, value)
else:
self.writer_client.set(key, value)

def set_serialized_multi(self, mapping):
def set_serialized_multi(self, mapping: Mapping[KeyType, bytes]) -> None:
if not self.redis_expiration_time:
self.writer_client.mset(mapping)
else:
Expand All @@ -235,23 +284,23 @@ def set_serialized_multi(self, mapping):
pipe.setex(key, self.redis_expiration_time, value)
pipe.execute()

def delete(self, key):
def delete(self, key: KeyType) -> None:
self.writer_client.delete(key)

def delete_multi(self, keys):
def delete_multi(self, keys: Sequence[KeyType]) -> None:
self.writer_client.delete(*keys)


class _RedisLockWrapper:
__slots__ = ("mutex", "__weakref__")

def __init__(self, mutex: typing.Any):
def __init__(self, mutex: Any):
self.mutex = mutex

def acquire(self, wait: bool = True) -> typing.Any:
def acquire(self, wait: bool = True) -> Any:
return self.mutex.acquire(blocking=wait)

def release(self) -> typing.Any:
def release(self) -> Any:
return self.mutex.release()

def locked(self) -> bool:
Expand Down Expand Up @@ -356,7 +405,7 @@ class RedisSentinelBackend(RedisBackend):

"""

def __init__(self, arguments):
def __init__(self, arguments: RedisKwargs_Sentinel):
arguments = arguments.copy()

self.sentinels = arguments.pop("sentinels", None)
Expand All @@ -371,7 +420,7 @@ def __init__(self, arguments):
}
)

def _imports(self):
def _imports(self) -> None:
# defer imports until backend is used
global redis
import redis.sentinel # noqa
Expand Down Expand Up @@ -545,17 +594,17 @@ class RedisClusterBackend(RedisBackend):

"""

def __init__(self, arguments):
def __init__(self, arguments: RedisKwargs_Cluster):
arguments = arguments.copy()
self.startup_nodes = arguments.pop("startup_nodes", None)
super().__init__(arguments)

def _imports(self):
def _imports(self) -> None:
global redis
import redis.cluster

def _create_client(self):
redis_cluster: redis.cluster.RedisCluster[typing.Any]
def _create_client(self) -> None:
redis_cluster: redis.cluster.RedisCluster[Any]
if self.url is not None:
redis_cluster = redis.cluster.RedisCluster.from_url(
self.url, **self.connection_kwargs
Expand All @@ -565,5 +614,5 @@ def _create_client(self):
startup_nodes=self.startup_nodes,
**self.connection_kwargs,
)
self.writer_client = typing.cast("redis.Redis[bytes]", redis_cluster)
self.writer_client = cast("redis.Redis[bytes]", redis_cluster)
self.reader_client = self.writer_client
10 changes: 5 additions & 5 deletions dogpile/cache/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ class ProxyBackend(CacheBackend):
from dogpile.cache.proxy import ProxyBackend

class MyFirstProxy(ProxyBackend):
def get_serialized(self, key):
def get_serialized(self, key: KeyType) -> SerializedReturnType:
# ... custom code goes here ...
return self.proxied.get_serialized(key)

def get(self, key):
def get(self, key: KeyType) -> BackendFormatted:
# ... custom code goes here ...
return self.proxied.get(key)

def set(self, key, value):
def set(self, key: KeyType, value: BackendSetType) -> None:
# ... custom code goes here ...
self.proxied.set(key)

class MySecondProxy(ProxyBackend):
def get_serialized(self, key):
def get_serialized(self, key: KeyType) -> SerializedReturnType:
# ... custom code goes here ...
return self.proxied.get_serialized(key)

def get(self, key):
def get(self, key: KeyType) -> BackendFormatted:
# ... custom code goes here ...
return self.proxied.get(key)

Expand Down
Loading