Skip to content

Fix various commands to work with Redis Cluster #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:
make test-all

test:
name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis ${{ matrix.redis-version }}]
name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} - redis-py ${{ matrix.redis-py-version }} [redis ${{ matrix.redis-version }}]
runs-on: ubuntu-latest
needs: service-tests
env:
Expand All @@ -89,6 +89,7 @@ jobs:
# 3.11 tests are run in the service-tests job
python-version: ["3.9", "3.10", 3.12, 3.13]
connection: ["hiredis", "plain"]
redis-py-version: ["5.x", "6.x"]
redis-version: ["6.2.6-v9", "latest", "8.0-M03"]

steps:
Expand Down Expand Up @@ -116,6 +117,14 @@ jobs:
run: |
poetry install --all-extras

- name: Install specific redis-py version
run: |
if [[ "${{ matrix.redis-py-version }}" == "5.x" ]]; then
poetry add "redis>=5.0.0,<6.0.0"
else
poetry add "redis>=6.0.0,<7.0.0"
fi

- name: Install hiredis if needed
if: matrix.connection == 'hiredis'
run: |
Expand All @@ -135,7 +144,7 @@ jobs:
credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }}

- name: Run tests
if: matrix.connection == 'plain' && matrix.redis-version == 'latest'
if: matrix.connection == 'plain' && matrix.redis-py-version == '6.x' && matrix.redis-version == 'latest'
env:
HF_HOME: ${{ github.workspace }}/hf_cache
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
Expand All @@ -144,12 +153,12 @@ jobs:
make test

- name: Run tests (alternate)
if: matrix.connection != 'plain' || matrix.redis-version != 'latest'
if: matrix.connection != 'plain' || matrix.redis-py-version != '6.x' || matrix.redis-version != 'latest'
run: |
make test

- name: Run notebooks
if: matrix.connection == 'plain' && matrix.redis-version == 'latest'
if: matrix.connection == 'plain' && matrix.redis-py-version == '6.x' && matrix.redis-version == 'latest'
env:
HF_HOME: ${{ github.workspace }}/hf_cache
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,17 @@ pyrightconfig.json
[Ll]ocal
pyvenv.cfg
pip-selfcheck.json
env
venv
.venv

libs/redis/docs/.Trash*
.python-version
.idea/*
.vscode/settings.json
.python-version
tests/data
.git
.cursor
.junie
.undodir
50 changes: 19 additions & 31 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ numpy = [
{ version = ">=1.26.0,<3", python = ">=3.12" },
]
pyyaml = ">=5.4,<7.0"
redis = "^5.0"
redis = ">=5.0,<7.0"
pydantic = "^2"
tenacity = ">=8.2.2"
ml-dtypes = ">=0.4.0,<1.0.0"
Expand Down Expand Up @@ -68,8 +68,8 @@ pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}
pre-commit = "^4.1.0"
mypy = "1.9.0"
nbval = "^0.11.0"
types-redis = "*"
types-pyyaml = "*"
types-pyopenssl = "*"
testcontainers = "^4.3.1"
cryptography = { version = ">=44.0.1", markers = "python_version > '3.9.1'" }

Expand Down
85 changes: 54 additions & 31 deletions redisvl/extensions/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
specific cache types such as LLM caches and embedding caches.
"""

from typing import Any, Dict, Optional
from collections.abc import Mapping
from typing import Any, Dict, Optional, Union

from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis import Redis # For backwards compatibility in type checking
from redis.cluster import RedisCluster

from redisvl.redis.connection import RedisConnectionFactory
from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster


class BaseCache:
Expand All @@ -19,14 +21,15 @@ class BaseCache:
including TTL management, connection handling, and basic cache operations.
"""

_redis_client: Optional[Redis]
_async_redis_client: Optional[AsyncRedis]
_redis_client: Optional[SyncRedisClient]
_async_redis_client: Optional[AsyncRedisClient]

def __init__(
self,
name: str,
ttl: Optional[int] = None,
redis_client: Optional[Redis] = None,
redis_client: Optional[SyncRedisClient] = None,
async_redis_client: Optional[AsyncRedisClient] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
):
Expand All @@ -36,7 +39,7 @@ def __init__(
name (str): The name of the cache.
ttl (Optional[int], optional): The time-to-live for records cached
in Redis. Defaults to None.
redis_client (Optional[Redis], optional): A redis client connection instance.
redis_client (Optional[SyncRedisClient], optional): A redis client connection instance.
Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
connection_kwargs (Dict[str, Any]): The connection arguments
Expand All @@ -53,14 +56,13 @@ def __init__(
}

# Initialize Redis clients
self._async_redis_client = None
self._async_redis_client = async_redis_client
self._redis_client = redis_client

if redis_client:
if redis_client or async_redis_client:
self._owns_redis_client = False
self._redis_client = redis_client
else:
self._owns_redis_client = True
self._redis_client = None # type: ignore

def _get_prefix(self) -> str:
"""Get the key prefix for Redis keys.
Expand Down Expand Up @@ -103,11 +105,11 @@ def set_ttl(self, ttl: Optional[int] = None) -> None:
else:
self._ttl = None

def _get_redis_client(self) -> Redis:
def _get_redis_client(self) -> SyncRedisClient:
"""Get or create a Redis client.

Returns:
Redis: A Redis client instance.
SyncRedisClient: A Redis client instance.
"""
if self._redis_client is None:
# Create new Redis client
Expand All @@ -116,22 +118,29 @@ def _get_redis_client(self) -> Redis:
self._redis_client = Redis.from_url(url, **kwargs) # type: ignore
return self._redis_client

async def _get_async_redis_client(self) -> AsyncRedis:
async def _get_async_redis_client(self) -> AsyncRedisClient:
"""Get or create an async Redis client.

Returns:
AsyncRedis: An async Redis client instance.
AsyncRedisClient: An async Redis client instance.
"""
if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
client = self.redis_kwargs.get("redis_client")
if isinstance(client, Redis):

if client and isinstance(client, (Redis, RedisCluster)):
self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
client
)
else:
url = self.redis_kwargs["redis_url"]
kwargs = self.redis_kwargs["connection_kwargs"]
self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore
url = str(self.redis_kwargs["redis_url"])
kwargs = self.redis_kwargs.get("connection_kwargs", {})
if not isinstance(kwargs, Mapping):
raise ValueError(
f"connection_kwargs must be a mapping, got {type(kwargs)}"
)
self._async_redis_client = (
RedisConnectionFactory.get_async_redis_connection(url, **kwargs)
)
return self._async_redis_client

def expire(self, key: str, ttl: Optional[int] = None) -> None:
Expand Down Expand Up @@ -183,7 +192,14 @@ def clear(self) -> None:
client.delete(*keys)
if cursor_int == 0: # Redis returns 0 when scan is complete
break
cursor = cursor_int # Update cursor for next iteration
# Cluster returns a dict of cursor values. We need to stop if these all
# come back as 0.
elif isinstance(cursor_int, Mapping):
cursor_values = list(cursor_int.values())
if all(v == 0 for v in cursor_values):
break
else:
cursor = cursor_int # Update cursor for next iteration

async def aclear(self) -> None:
"""Async clear the cache of all keys."""
Expand All @@ -193,12 +209,21 @@ async def aclear(self) -> None:
# Scan for all keys with our prefix
cursor = 0 # Start with cursor 0
while True:
cursor_int, keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore
cursor_int, keys = await client.scan(
cursor=cursor, match=f"{prefix}*", count=100
) # type: ignore
if keys:
await client.delete(*keys)
if cursor_int == 0: # Redis returns 0 when scan is complete
break
cursor = cursor_int # Update cursor for next iteration
# Cluster returns a dict of cursor values. We need to stop if these all
# come back as 0.
elif isinstance(cursor_int, Mapping):
cursor_values = list(cursor_int.values())
if all(v == 0 for v in cursor_values):
break
else:
cursor = cursor_int # Update cursor for next iteration

def disconnect(self) -> None:
"""Disconnect from Redis."""
Expand All @@ -207,12 +232,10 @@ def disconnect(self) -> None:

if self._redis_client:
self._redis_client.close()
self._redis_client = None # type: ignore

if hasattr(self, "_async_redis_client") and self._async_redis_client:
# Use synchronous close for async client in synchronous context
self._async_redis_client.close() # type: ignore
self._async_redis_client = None # type: ignore
self._redis_client = None
# Async clients don't have a sync close method, so we just
# zero them out to allow garbage collection.
self._async_redis_client = None

async def adisconnect(self) -> None:
"""Async disconnect from Redis."""
Expand All @@ -221,9 +244,9 @@ async def adisconnect(self) -> None:

if self._redis_client:
self._redis_client.close()
self._redis_client = None # type: ignore
self._redis_client = None

if hasattr(self, "_async_redis_client") and self._async_redis_client:
# Use proper async close method
await self._async_redis_client.aclose() # type: ignore
self._async_redis_client = None # type: ignore
await self._async_redis_client.aclose()
self._async_redis_client = None
Loading