diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 940cefd2..281bdc3e 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -36,27 +36,4 @@ jobs: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} # Define which tools Claude can use - allowed_tools: | - Bash(git status) - Bash(git log) - Bash(git show) - Bash(git blame) - Bash(git reflog) - Bash(git stash list) - Bash(git ls-files) - Bash(git branch) - Bash(git tag) - Bash(git diff) - Bash(make:*) - Bash(pytest:*) - Bash(cd:*) - Bash(ls:*) - Bash(make) - Bash(make:*) - View - GlobTool - GrepTool - BatchTool - - # Your Anthropic API key (stored as a GitHub secret) - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + allowed_tools: "Bash(git status),Bash(git log),Bash(git show),Bash(git blame),Bash(git reflog),Bash(git stash list),Bash(git ls-files),Bash(git branch),Bash(git tag),Bash(git diff),Bash(make:*),Bash(pytest:*),Bash(cd:*),Bash(ls:*),Bash(make),Bash(make:*),View,GlobTool,GrepTool,BatchTool" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c006b8f3..008b2e7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -78,7 +78,7 @@ jobs: make test-all test: - name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis ${{ matrix.redis-version }}] + name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} - redis-py ${{ matrix.redis-py-version }} [redis ${{ matrix.redis-version }}] runs-on: ubuntu-latest needs: service-tests env: @@ -89,7 +89,8 @@ jobs: # 3.11 tests are run in the service-tests job python-version: ["3.9", "3.10", 3.12, 3.13] connection: ["hiredis", "plain"] - redis-version: ["6.2.6-v9", "latest", "8.0-M03"] + redis-py-version: ["5.x", "6.x"] + redis-version: ["6.2.6-v9", "latest", "8.0.1"] steps: - name: Check out repository @@ -116,6 +117,14 @@ jobs: run: | poetry install --all-extras + - name: Install specific redis-py version + run: | + if [[ "${{ matrix.redis-py-version }}" == "5.x" ]]; then + poetry add "redis>=5.0.0,<6.0.0" + else + poetry add "redis>=6.0.0,<7.0.0" + fi + - name: Install hiredis if needed if: matrix.connection == 'hiredis' run: | @@ -123,7 +132,7 @@ jobs: - name: Set Redis image name run: | - if [[ "${{ matrix.redis-version }}" == "8.0-M03" ]]; then + if [[ "${{ matrix.redis-version }}" == "8.0.1" ]]; then echo "REDIS_IMAGE=redis:${{ matrix.redis-version }}" >> $GITHUB_ENV else echo "REDIS_IMAGE=redis/redis-stack-server:${{ matrix.redis-version }}" >> $GITHUB_ENV @@ -135,7 +144,7 @@ jobs: credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }} - name: Run tests - if: matrix.connection == 'plain' && matrix.redis-version == 'latest' + if: matrix.connection == 'plain' && matrix.redis-py-version == '6.x' && matrix.redis-version == 'latest' env: HF_HOME: ${{ github.workspace }}/hf_cache GCP_LOCATION: ${{ secrets.GCP_LOCATION }} @@ -144,12 +153,12 @@ jobs: make test - name: Run tests (alternate) - if: matrix.connection != 'plain' || matrix.redis-version != 'latest' + if: matrix.connection != 'plain' || matrix.redis-py-version != '6.x' || matrix.redis-version != 'latest' run: | make test - name: Run notebooks - if: matrix.connection == 'plain' && matrix.redis-version == 'latest' + if: matrix.connection == 'plain' && matrix.redis-py-version == '6.x' && matrix.redis-version == 'latest' env: HF_HOME: ${{ github.workspace }}/hf_cache OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.gitignore b/.gitignore index dca4f4e3..6b25de1b 100644 --- a/.gitignore +++ b/.gitignore @@ 
-216,9 +216,17 @@ pyrightconfig.json [Ll]ocal pyvenv.cfg pip-selfcheck.json +env +venv +.venv libs/redis/docs/.Trash* .python-version .idea/* .vscode/settings.json .python-version +tests/data +.git +.cursor +.junie +.undodir diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..9b407044 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,116 @@ +# CLAUDE.md - RedisVL Project Context + +## Frequently Used Commands + +```bash +# Development workflow +make install # Install dependencies +make format # Format code (black + isort) +make check-types # Run mypy type checking +make lint # Run all linting (format + types) +make test # Run tests (no external APIs) +make test-all # Run all tests (includes API tests) +make check # Full check (lint + test) + +# Redis setup +make redis-start # Start Redis Stack container +make redis-stop # Stop Redis Stack container + +# Documentation +make docs-build # Build documentation +make docs-serve # Serve docs locally +``` + +Pre-commit hooks are also configured, which you should +run before you commit: +```bash +pre-commit run --all-files +``` + +## Important Architectural Patterns + +### Async/Sync Dual Interfaces +- Most core classes have both sync and async versions (e.g., `SearchIndex` / `AsyncSearchIndex`) +- Follow existing patterns when adding new functionality + +### Schema-Driven Design +```python +# Index schemas define structure +schema = IndexSchema.from_yaml("schema.yaml") +index = SearchIndex(schema, redis_url="redis://localhost:6379") +``` + +## Critical Rules + +### Do Not Modify +- **CRITICAL**: Do not change this line unless explicitly asked: + ```python + token.strip().strip(",").replace("“", "").replace("”", "").lower() + ``` + +### README.md Maintenance +**IMPORTANT**: DO NOT modify README.md unless explicitly requested. + +**If you need to document something, use these alternatives:** +- Development info → CONTRIBUTING.md +- API details → docs/ directory +- Examples → docs/examples/ +- Project memory (explicit preferences, directives, etc.) → CLAUDE.md + +## Testing Notes +RedisVL uses `pytest` with `testcontainers` for testing. + +- `make test` - unit tests only (no external APIs) +- `make test-all` - includes integration tests requiring API keys + +## Project Structure + +``` +redisvl/ +├── cli/ # Command-line interface (rvl command) +├── extensions/ # AI extensions (cache, memory, routing) +│ ├── cache/ # Semantic caching for LLMs +│ ├── llmcache/ # LLM-specific caching +│ ├── message_history/ # Chat history management +│ ├── router/ # Semantic routing +│ └── session_manager/ # Session management +├── index/ # SearchIndex classes (sync/async) +├── query/ # Query builders (Vector, Range, Filter, Count) +├── redis/ # Redis client utilities +├── schema/ # Index schema definitions +└── utils/ # Utilities (vectorizers, rerankers, optimization) + ├── optimize/ # Threshold optimization + ├── rerank/ # Result reranking + └── vectorize/ # Embedding providers integration +``` + +## Core Components + +### 1. Index Management +- `SearchIndex` / `AsyncSearchIndex` - Main interface for Redis vector indices +- `IndexSchema` - Define index structure with fields (text, tags, vectors, etc.) +- Support for JSON and Hash storage types + +### 2. Query System +- `VectorQuery` - Semantic similarity search +- `RangeQuery` - Vector search within distance range +- `FilterQuery` - Metadata filtering and full-text search +- `CountQuery` - Count matching records +- Etc. + +### 3.
AI Extensions +- `SemanticCache` - LLM response caching with semantic similarity +- `EmbeddingsCache` - Cache for vector embeddings +- `MessageHistory` - Chat history with recency/relevancy retrieval +- `SemanticRouter` - Route queries to topics/intents + +### 4. Vectorizers (Optional Dependencies) +- OpenAI, Azure OpenAI, Cohere, HuggingFace, Mistral, VoyageAI +- Custom vectorizer support +- Batch processing capabilities + +## Documentation +- Main docs: https://docs.redisvl.com +- Built with Sphinx from `docs/` directory +- Includes API reference and user guides +- Example notebooks in documentation `docs/user_guide/...` diff --git a/poetry.lock b/poetry.lock index d9cdd41c..9c4ba90e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5615,14 +5615,14 @@ tqdm = "*" [[package]] name = "redis" -version = "5.2.1" +version = "6.1.0" description = "Python client for Redis database and key-value store" optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"}, - {file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"}, + {file = "redis-6.1.0-py3-none-any.whl", hash = "sha256:3b72622f3d3a89df2a6041e82acd896b0e67d9f54e9bcd906d091d23ba5219f6"}, + {file = "redis-6.1.0.tar.gz", hash = "sha256:c928e267ad69d3069af28a9823a07726edf72c7e37764f43dc0123f37928c075"}, ] [package.dependencies] @@ -5630,7 +5630,8 @@ async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\ [package.extras] hiredis = ["hiredis (>=3.0.0)"] -ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] +jwt = ["pyjwt (>=2.9.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (>=20.0.1)", "requests (>=2.31.0)"] [[package]] name = "referencing" @@ -6206,14 +6207,14 @@ train = ["accelerate (>=0.20.3)", "datasets"] name = "setuptools" version = "75.8.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" -optional = true +optional = false python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"sentence-transformers\" and python_version >= \"3.12\"" +groups = ["main", "dev"] files = [ {file = "setuptools-75.8.2-py3-none-any.whl", hash = "sha256:558e47c15f1811c1fa7adbd0096669bf76c1d3f433f58324df69f3f5ecac4e8f"}, {file = "setuptools-75.8.2.tar.gz", hash = "sha256:4880473a969e5f23f2a2be3646b2dfd84af9028716d398e46192f84bc36900d2"}, ] +markers = {main = "extra == \"sentence-transformers\" and python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -7106,14 +7107,14 @@ tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "types-cffi" -version = "1.16.0.20241221" +version = "1.17.0.20250326" description = "Typing stubs for cffi" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "types_cffi-1.16.0.20241221-py3-none-any.whl", hash = "sha256:e5b76b4211d7a9185f6ab8d06a106d56c7eb80af7cdb8bfcb4186ade10fb112f"}, - {file = "types_cffi-1.16.0.20241221.tar.gz", hash = "sha256:1c96649618f4b6145f58231acb976e0b448be6b847f7ab733dabe62dfbff6591"}, + {file = "types_cffi-1.17.0.20250326-py3-none-any.whl", hash = "sha256:5af4ecd7374ae0d5fa9e80864e8d4b31088cc32c51c544e3af7ed5b5ed681447"}, + {file = "types_cffi-1.17.0.20250326.tar.gz", hash = 
"sha256:6c8fea2c2f34b55e5fb77b1184c8ad849d57cf0ddccbc67a62121ac4b8b32254"}, ] [package.dependencies] @@ -7147,22 +7148,6 @@ files = [ {file = "types_pyyaml-6.0.12.20241230.tar.gz", hash = "sha256:7f07622dbd34bb9c8b264fe860a17e0efcad00d50b5f27e93984909d9363498c"}, ] -[[package]] -name = "types-redis" -version = "4.6.0.20241004" -description = "Typing stubs for redis" -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e"}, - {file = "types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed"}, -] - -[package.dependencies] -cryptography = ">=35.0.0" -types-pyOpenSSL = "*" - [[package]] name = "types-requests" version = "2.31.0.6" @@ -7197,16 +7182,19 @@ urllib3 = ">=2" [[package]] name = "types-setuptools" -version = "75.8.0.20250225" +version = "80.3.0.20250505" description = "Typing stubs for setuptools" optional = false python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "types_setuptools-75.8.0.20250225-py3-none-any.whl", hash = "sha256:94c86b439cc60bcc68c1cda3fd2c301f007f8f9502f4fbb54c66cb5ce9b875af"}, - {file = "types_setuptools-75.8.0.20250225.tar.gz", hash = "sha256:6038f7e983d55792a5f90d8fdbf5d4c186026214a16bb65dd6ae83c624ae9636"}, + {file = "types_setuptools-80.3.0.20250505-py3-none-any.whl", hash = "sha256:117c86a82367306388b55310d04da807ff4c3ecdf769656a5fdc0fdd06a2c1b6"}, + {file = "types_setuptools-80.3.0.20250505.tar.gz", hash = "sha256:5fd3d34b8fa3441d68d010fef95e232d1e48f3f5cb578f3477b7aae4f8374502"}, ] +[package.dependencies] +setuptools = "*" + [[package]] name = "types-urllib3" version = "1.26.25.14" @@ -7667,4 +7655,4 @@ voyageai = ["voyageai"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.14" -content-hash = "f68c5ee8aebfc273918a13adc848cc7fc868e2a4ba0a35b59266a30471328cd3" +content-hash = "8ea19be9a1ad40e1ad2bdabb367bf49832fb0319f8c463ae433230956d94f723" diff --git a/pyproject.toml b/pyproject.toml index 14130821..dd900ab7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ numpy = [ { version = ">=1.26.0,<3", python = ">=3.12" }, ] pyyaml = ">=5.4,<7.0" -redis = "^5.0" +redis = ">=5.0,<7.0" pydantic = "^2" tenacity = ">=8.2.2" ml-dtypes = ">=0.4.0,<1.0.0" @@ -68,8 +68,8 @@ pytest-xdist = {extras = ["psutil"], version = "^3.6.1"} pre-commit = "^4.1.0" mypy = "1.9.0" nbval = "^0.11.0" -types-redis = "*" types-pyyaml = "*" +types-pyopenssl = "*" testcontainers = "^4.3.1" cryptography = { version = ">=44.0.1", markers = "python_version > '3.9.1'" } diff --git a/redisvl/extensions/cache/base.py b/redisvl/extensions/cache/base.py index aabc548e..75975b52 100644 --- a/redisvl/extensions/cache/base.py +++ b/redisvl/extensions/cache/base.py @@ -4,12 +4,14 @@ specific cache types such as LLM caches and embedding caches. """ -from typing import Any, Dict, Optional +from collections.abc import Mapping +from typing import Any, Dict, Optional, Union -from redis import Redis -from redis.asyncio import Redis as AsyncRedis +from redis import Redis # For backwards compatibility in type checking +from redis.cluster import RedisCluster from redisvl.redis.connection import RedisConnectionFactory +from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster class BaseCache: @@ -19,14 +21,15 @@ class BaseCache: including TTL management, connection handling, and basic cache operations. 
""" - _redis_client: Optional[Redis] - _async_redis_client: Optional[AsyncRedis] + _redis_client: Optional[SyncRedisClient] + _async_redis_client: Optional[AsyncRedisClient] def __init__( self, name: str, ttl: Optional[int] = None, - redis_client: Optional[Redis] = None, + redis_client: Optional[SyncRedisClient] = None, + async_redis_client: Optional[AsyncRedisClient] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, ): @@ -36,7 +39,7 @@ def __init__( name (str): The name of the cache. ttl (Optional[int], optional): The time-to-live for records cached in Redis. Defaults to None. - redis_client (Optional[Redis], optional): A redis client connection instance. + redis_client (Optional[SyncRedisClient], optional): A redis client connection instance. Defaults to None. redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. connection_kwargs (Dict[str, Any]): The connection arguments @@ -53,14 +56,13 @@ def __init__( } # Initialize Redis clients - self._async_redis_client = None + self._async_redis_client = async_redis_client + self._redis_client = redis_client - if redis_client: + if redis_client or async_redis_client: self._owns_redis_client = False - self._redis_client = redis_client else: self._owns_redis_client = True - self._redis_client = None # type: ignore def _get_prefix(self) -> str: """Get the key prefix for Redis keys. @@ -103,11 +105,11 @@ def set_ttl(self, ttl: Optional[int] = None) -> None: else: self._ttl = None - def _get_redis_client(self) -> Redis: + def _get_redis_client(self) -> SyncRedisClient: """Get or create a Redis client. Returns: - Redis: A Redis client instance. + SyncRedisClient: A Redis client instance. """ if self._redis_client is None: # Create new Redis client @@ -116,22 +118,30 @@ def _get_redis_client(self) -> Redis: self._redis_client = Redis.from_url(url, **kwargs) # type: ignore return self._redis_client - async def _get_async_redis_client(self) -> AsyncRedis: + async def _get_async_redis_client(self) -> AsyncRedisClient: """Get or create an async Redis client. Returns: - AsyncRedis: An async Redis client instance. + AsyncRedisClient: An async Redis client instance. """ if not hasattr(self, "_async_redis_client") or self._async_redis_client is None: client = self.redis_kwargs.get("redis_client") - if isinstance(client, Redis): + + if client and isinstance(client, (Redis, RedisCluster)): self._async_redis_client = RedisConnectionFactory.sync_to_async_redis( client ) else: - url = self.redis_kwargs["redis_url"] - kwargs = self.redis_kwargs["connection_kwargs"] - self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore + url = str(self.redis_kwargs["redis_url"]) + kwargs = self.redis_kwargs.get("connection_kwargs", {}) + if not isinstance(kwargs, Mapping): + raise TypeError( + "Expected `connection_kwargs` to be a dictionary (e.g. {'decode_responses': True}), " + f"but got type: {type(kwargs).__name__}" + ) + self._async_redis_client = ( + RedisConnectionFactory.get_async_redis_connection(url, **kwargs) + ) return self._async_redis_client def expire(self, key: str, ttl: Optional[int] = None) -> None: @@ -183,7 +193,14 @@ def clear(self) -> None: client.delete(*keys) if cursor_int == 0: # Redis returns 0 when scan is complete break - cursor = cursor_int # Update cursor for next iteration + # Cluster returns a dict of cursor values. We need to stop if these all + # come back as 0. 
+ elif isinstance(cursor_int, Mapping): + cursor_values = list(cursor_int.values()) + if all(v == 0 for v in cursor_values): + break + else: + cursor = cursor_int # Update cursor for next iteration async def aclear(self) -> None: """Async clear the cache of all keys.""" @@ -193,12 +210,21 @@ async def aclear(self) -> None: # Scan for all keys with our prefix cursor = 0 # Start with cursor 0 while True: - cursor_int, keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore + cursor_int, keys = await client.scan( + cursor=cursor, match=f"{prefix}*", count=100 + ) # type: ignore if keys: await client.delete(*keys) if cursor_int == 0: # Redis returns 0 when scan is complete break - cursor = cursor_int # Update cursor for next iteration + # Cluster returns a dict of cursor values. We need to stop if these all + # come back as 0. + elif isinstance(cursor_int, Mapping): + cursor_values = list(cursor_int.values()) + if all(v == 0 for v in cursor_values): + break + else: + cursor = cursor_int # Update cursor for next iteration def disconnect(self) -> None: """Disconnect from Redis.""" @@ -207,12 +233,10 @@ def disconnect(self) -> None: if self._redis_client: self._redis_client.close() - self._redis_client = None # type: ignore - - if hasattr(self, "_async_redis_client") and self._async_redis_client: - # Use synchronous close for async client in synchronous context - self._async_redis_client.close() # type: ignore - self._async_redis_client = None # type: ignore + self._redis_client = None + # Async clients don't have a sync close method, so we just + # zero them out to allow garbage collection. + self._async_redis_client = None async def adisconnect(self) -> None: """Async disconnect from Redis.""" @@ -221,9 +245,9 @@ async def adisconnect(self) -> None: if self._redis_client: self._redis_client.close() - self._redis_client = None # type: ignore + self._redis_client = None if hasattr(self, "_async_redis_client") and self._async_redis_client: # Use proper async close method - await self._async_redis_client.aclose() # type: ignore - self._async_redis_client = None # type: ignore + await self._async_redis_client.aclose() + self._async_redis_client = None diff --git a/redisvl/extensions/cache/embeddings/embeddings.py b/redisvl/extensions/cache/embeddings/embeddings.py index 795096bd..31f73220 100644 --- a/redisvl/extensions/cache/embeddings/embeddings.py +++ b/redisvl/extensions/cache/embeddings/embeddings.py @@ -1,13 +1,14 @@ """Embeddings cache implementation for RedisVL.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -from redis import Redis -from redis.asyncio import Redis as AsyncRedis +from typing import Any, Awaitable, Dict, List, Optional, Tuple, cast from redisvl.extensions.cache.base import BaseCache from redisvl.extensions.cache.embeddings.schema import CacheEntry from redisvl.redis.utils import convert_bytes, hashify +from redisvl.types import AsyncRedisClient, SyncRedisClient +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) class EmbeddingsCache(BaseCache): @@ -17,7 +18,8 @@ def __init__( self, name: str = "embedcache", ttl: Optional[int] = None, - redis_client: Optional[Redis] = None, + redis_client: Optional[SyncRedisClient] = None, + async_redis_client: Optional[AsyncRedisClient] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, ): @@ -26,7 +28,7 @@ def __init__( Args: name (str): The name of the cache. Defaults to "embedcache". 
ttl (Optional[int]): The time-to-live for cached embeddings. Defaults to None. - redis_client (Optional[Redis]): Redis client instance. Defaults to None. + redis_client (Optional[SyncRedisClient]): Redis client instance. Defaults to None. redis_url (str): Redis URL for connection. Defaults to "redis://localhost:6379". connection_kwargs (Dict[str, Any]): Redis connection arguments. Defaults to {}. @@ -45,6 +47,7 @@ def __init__( name=name, ttl=ttl, redis_client=redis_client, + async_redis_client=async_redis_client, redis_url=redis_url, connection_kwargs=connection_kwargs, ) @@ -173,7 +176,7 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]: if data: self.expire(key) - return self._process_cache_data(data) + return self._process_cache_data(data) # type: ignore def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]: """Get multiple embeddings by their Redis keys. @@ -570,7 +573,7 @@ async def aget_by_key(self, key: str) -> Optional[Dict[str, Any]]: client = await self._get_async_redis_client() # Get all fields - data = await client.hgetall(key) + data = await client.hgetall(key) # type: ignore # Refresh TTL if data exists if data: @@ -608,7 +611,7 @@ async def amget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]] async with client.pipeline(transaction=False) as pipeline: # Queue all hgetall operations for key in keys: - await pipeline.hgetall(key) + pipeline.hgetall(key) results = await pipeline.execute() # Process results and refresh TTLs separately diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index ab86f240..1bfbc9b0 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,10 +1,9 @@ from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Mapping, Optional, Type, Union import redis.commands.search.reducers as reducers import yaml from pydantic import BaseModel, ConfigDict, Field, PrivateAttr -from redis import Redis from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer from redis.exceptions import ResponseError @@ -22,6 +21,7 @@ from redisvl.query.filter import Tag from redisvl.redis.connection import RedisConnectionFactory from redisvl.redis.utils import convert_bytes, hashify, make_dict +from redisvl.types import SyncRedisClient from redisvl.utils.log import get_logger from redisvl.utils.utils import deprecated_argument, model_to_dict, scan_by_pattern from redisvl.utils.vectorize.base import BaseVectorizer @@ -53,7 +53,7 @@ def __init__( routes: List[Route], vectorizer: Optional[BaseVectorizer] = None, routing_config: Optional[RoutingConfig] = None, - redis_client: Optional[Redis] = None, + redis_client: Optional[SyncRedisClient] = None, redis_url: str = "redis://localhost:6379", overwrite: bool = False, connection_kwargs: Dict[str, Any] = {}, @@ -66,7 +66,7 @@ def __init__( routes (List[Route]): List of Route objects. vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer. routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig. - redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. + redis_client (Optional[SyncRedisClient], optional): Redis client for connection. Defaults to None. redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. 
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. connection_kwargs (Dict[str, Any]): The connection arguments @@ -113,7 +113,7 @@ def __init__( def from_existing( cls, name: str, - redis_client: Optional[Redis] = None, + redis_client: Optional[SyncRedisClient] = None, redis_url: str = "redis://localhost:6379", **kwargs, ) -> "SemanticRouter": @@ -130,8 +130,16 @@ def from_existing( raise RedisModuleVersionError( f"Loading from existing index failed. {str(e)}" ) + if redis_client is None: + raise ValueError( + "Creating Redis client failed. Please check the redis_url and connection_kwargs." + ) - router_dict = redis_client.json().get(f"{name}:route_config") # type: ignore + router_dict = redis_client.json().get(f"{name}:route_config") + if not isinstance(router_dict, dict): + raise ValueError( + f"No valid router config found for {name}. Received: {router_dict!r}" + ) return cls.from_dict( router_dict, redis_url=redis_url, redis_client=redis_client ) @@ -139,7 +147,7 @@ def from_existing( @deprecated_argument("dtype") def _initialize_index( self, - redis_client: Optional[Redis] = None, + redis_client: Optional[SyncRedisClient] = None, redis_url: str = "redis://localhost:6379", overwrite: bool = False, dtype: str = "float32", @@ -300,9 +308,10 @@ def _build_aggregate_request( aggregate_request = ( AggregateRequest(aggregate_query) .group_by( - "@route_name", aggregation_func("vector_distance").alias("distance") + "@route_name", # type: ignore + aggregation_func("vector_distance").alias("distance"), ) - .sort_by("@distance", max=max_k) + .sort_by("@distance", max=max_k) # type: ignore .dialect(2) ) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 77271e4e..c89c263f 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -17,10 +17,28 @@ Sequence, Tuple, Union, + cast, ) +import redis.exceptions + +# Add missing imports +from redis import Redis +from redis.asyncio import Redis as AsyncRedis +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.cluster import RedisCluster + from redisvl.query.query import VectorQuery -from redisvl.redis.utils import convert_bytes, make_dict +from redisvl.redis.utils import ( + _keys_share_hash_tag, + async_cluster_create_index, + async_cluster_search, + cluster_create_index, + cluster_search, + convert_bytes, + make_dict, +) +from redisvl.types import AsyncRedisClient, SyncRedisClient from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper if TYPE_CHECKING: @@ -29,11 +47,22 @@ from redis.commands.search.result import Result from redisvl.query.query import BaseQuery -import redis -import redis.asyncio as aredis +from redis import __version__ as redis_version from redis.client import NEVER_DECODE from redis.commands.helpers import get_protocol_version # type: ignore -from redis.commands.search.indexDefinition import IndexDefinition + +# Redis 5.x compatibility (6 fixed the import path) +if redis_version.startswith("5"): + from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped] + IndexDefinition, + ) +else: + from redis.commands.search.index_definition import ( # type: ignore[no-redef] + IndexDefinition, + ) + +# Need Result outside TYPE_CHECKING for cast +from redis.commands.search.result import Result from redisvl.exceptions import ( QueryValidationError, @@ -154,7 +183,7 @@ def _process(doc: "Document") -> Dict[str, Any]: def process_aggregate_results( results: "AggregateResult", query: AggregationQuery, 
storage_type: StorageType ) -> List[Dict[str, Any]]: - """Convert an aggregate reslt object into a list of document dictionaries. + """Convert an aggregate result object into a list of document dictionaries. This function processes results from Redis, handling different storage types and query types. For JSON storage with empty return fields, it @@ -338,7 +367,7 @@ class SearchIndex(BaseSearchIndex): def __init__( self, schema: IndexSchema, - redis_client: Optional[redis.Redis] = None, + redis_client: Optional[SyncRedisClient] = None, redis_url: Optional[str] = None, connection_kwargs: Optional[Dict[str, Any]] = None, validate_on_load: bool = False, @@ -350,7 +379,7 @@ def __init__( Args: schema (IndexSchema): Index schema object. - redis_client(Optional[redis.Redis]): An + redis_client(Optional[Redis]): An instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. @@ -393,7 +422,7 @@ def disconnect(self): def from_existing( cls, name: str, - redis_client: Optional[redis.Redis] = None, + redis_client: Optional[SyncRedisClient] = None, redis_url: Optional[str] = None, **kwargs, ): @@ -402,7 +431,7 @@ def from_existing( Args: name (str): Name of the search index in Redis. - redis_client(Optional[redis.Redis]): An + redis_client(Optional[Redis]): An instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. @@ -437,12 +466,12 @@ def from_existing( return cls(schema, redis_client, **kwargs) @property - def client(self) -> Optional[redis.Redis]: + def client(self) -> Optional[SyncRedisClient]: """The underlying redis-py client object.""" return self.__redis_client @property - def _redis_client(self) -> redis.Redis: + def _redis_client(self) -> SyncRedisClient: """ Get a Redis client instance. @@ -487,7 +516,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): ) @deprecated_function("set_client", "Pass connection parameters in __init__.") - def set_client(self, redis_client: redis.Redis, **kwargs): + def set_client(self, redis_client: SyncRedisClient, **kwargs): """Manually set the Redis client to use with the search index. This method configures the search index to use a specific Redis or @@ -495,7 +524,7 @@ def set_client(self, redis_client: redis.Redis, **kwargs): custom-configured client is preferred instead of creating a new one. Args: - redis_client (redis.Redis): A Redis or Async Redis + redis_client (Redis): A Redis or Async Redis client instance to be used for the connection. 
Raises: @@ -544,15 +573,30 @@ def create(self, overwrite: bool = False, drop: bool = False) -> None: self.delete(drop=drop) try: - self._redis_client.ft(self.name).create_index( # type: ignore - fields=redis_fields, - definition=IndexDefinition( - prefix=[self.schema.index.prefix], index_type=self._storage.type - ), + definition = IndexDefinition( + prefix=[self.schema.index.prefix], index_type=self._storage.type ) - except: + if isinstance(self._redis_client, RedisCluster): + cluster_create_index( + index_name=self.name, + client=self._redis_client, + fields=redis_fields, + definition=definition, + ) + else: + self._redis_client.ft(self.name).create_index( + fields=redis_fields, + definition=definition, + ) + except redis.exceptions.RedisError as e: + raise RedisSearchError( + f"Failed to create index '{self.name}' on Redis: {str(e)}" + ) from e + except Exception as e: logger.exception("Error while trying to create the index") - raise + raise RedisSearchError( + f"Unexpected error creating index '{self.name}': {str(e)}" + ) from e def delete(self, drop: bool = True): """Delete the search index while optionally dropping all keys associated @@ -566,9 +610,24 @@ def delete(self, drop: bool = True): redis.exceptions.ResponseError: If the index does not exist. """ try: - self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore - delete_documents=drop - ) + # For Redis Cluster with drop=True, we need to handle key deletion manually + # to avoid cross-slot errors since we control the keys opaquely + if drop and isinstance(self._redis_client, RedisCluster): + # First clear all keys manually (handles cluster compatibility) + self.clear() + # Then drop the index without the DD flag + cmd_args = ["FT.DROPINDEX", self.schema.index.name] + else: + # Standard approach for non-cluster or when not dropping keys + cmd_args = ["FT.DROPINDEX", self.schema.index.name] + if drop: + cmd_args.append("DD") + + if isinstance(self._redis_client, RedisCluster): + target_nodes = [self._redis_client.get_default_node()] + self._redis_client.execute_command(*cmd_args, target_nodes=target_nodes) + else: + self._redis_client.execute_command(*cmd_args) except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e @@ -576,19 +635,35 @@ def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index available and in-place for future insertions or updates. + NOTE: This method requires custom behavior for Redis Cluster because + here, we can't easily give control of the keys we're clearing to the + user so they can separate them based on hash tag. + Returns: int: Count of records deleted from Redis. 
""" - # Track deleted records + client = cast(SyncRedisClient, self._redis_client) total_records_deleted: int = 0 - # Paginate using queries and delete in batches for batch in self.paginate( FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 ): batch_keys = [record["id"] for record in batch] - record_deleted = self._redis_client.delete(*batch_keys) # type: ignore - total_records_deleted += record_deleted # type: ignore + if batch_keys: + is_cluster = isinstance(client, RedisCluster) + if is_cluster: + records_deleted_in_batch = 0 + for key_to_delete in batch_keys: + try: + records_deleted_in_batch += cast( + int, client.delete(key_to_delete) + ) + except redis.exceptions.RedisError as e: + logger.warning(f"Failed to delete key {key_to_delete}: {e}") + total_records_deleted += records_deleted_in_batch + else: + record_deleted = cast(int, client.delete(*batch_keys)) + total_records_deleted += record_deleted return total_records_deleted @@ -612,6 +687,10 @@ def drop_documents(self, ids: Union[str, List[str]]) -> int: This method converts document IDs to Redis keys automatically by applying the index's key prefix and separator configuration. + NOTE: Cluster users will need to incorporate hash tags into their + document IDs and only call this method with documents from a single hash + tag at a time. + Args: ids (Union[str, List[str]]): The document ID or IDs to remove from the index. @@ -622,6 +701,13 @@ def drop_documents(self, ids: Union[str, List[str]]) -> int: if not ids: return 0 keys = [self.key(id) for id in ids] + # Check for cluster compatibility + if isinstance( + self._redis_client, RedisCluster + ) and not _keys_share_hash_tag(keys): + raise ValueError( + "All keys must share a hash tag when using Redis Cluster." + ) return self._redis_client.delete(*keys) # type: ignore else: key = self.key(ids) @@ -684,7 +770,7 @@ def load( """ try: return self._storage.write( - self._redis_client, # type: ignore + self._redis_client, objects=data, id_field=id_field, keys=keys, @@ -716,15 +802,16 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]: Returns: Dict[str, Any]: The fetched object. """ - obj = self._storage.get(self._redis_client, [self.key(id)]) # type: ignore + obj = self._storage.get(self._redis_client, [self.key(id)]) if obj: return convert_bytes(obj[0]) return None def _aggregate(self, aggregation_query: AggregationQuery) -> List[Dict[str, Any]]: - """Execute an aggretation query and processes the results.""" + """Execute an aggregation query and processes the results.""" results = self.aggregate( - aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined] + aggregation_query, + query_params=aggregation_query.params, # type: ignore[attr-defined] ) return process_aggregate_results( results, @@ -743,16 +830,22 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult": Result: Raw Redis aggregation results. """ try: - return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore + return self._redis_client.ft(self.schema.index.name).aggregate( *args, **kwargs ) - except Exception as e: + except redis.exceptions.RedisError as e: + if "CROSSSLOT" in str(e): + raise RedisSearchError( + "Cross-slot error during aggregation. Ensure consistent hash tags in your keys." 
+ ) raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + except Exception as e: + raise RedisSearchError( + f"Unexpected error while aggregating: {str(e)}" + ) from e def batch_search( - self, - queries: List[SearchParams], - batch_size: int = 10, + self, queries: List[SearchParams], batch_size: int = 10 ) -> List["Result"]: """Perform a search against the index for multiple queries. @@ -760,15 +853,18 @@ def batch_search( returns a list of Result objects for each query. Results are returned in the same order as the queries. + NOTE: Cluster users may need to incorporate hash tags into their query + to avoid cross-slot operations. + Args: - queries (List[SearchParams]): The queries to search for. batch_size - (int, optional): The number of queries to search for at a time. + queries (List[SearchParams]): The queries to search for. + batch_size (int, optional): The number of queries to search for at a time. Defaults to 10. Returns: List[Result]: The search results for each query. """ - all_parsed = [] + all_results = [] search = self._redis_client.ft(self.schema.index.name) options = {} if get_protocol_version(self._redis_client) not in ["3", 3]: @@ -805,16 +901,16 @@ def batch_search( # for all queries in the batch as the duration for each query duration = (time.time() - st) * 1000.0 - for i, query_results in enumerate(results): - _built_query = batch_built_queries[i] + for j, query_results in enumerate(results): + _built_query = batch_built_queries[j] parsed_result = search._parse_search( # type: ignore query_results, query=_built_query, duration=duration, ) # Return a parsed Result object for each query - all_parsed.append(parsed_result) - return all_parsed + all_results.append(parsed_result) + return all_results def search(self, *args, **kwargs) -> "Result": """Perform a search against the index. @@ -827,11 +923,25 @@ def search(self, *args, **kwargs) -> "Result": Result: Raw Redis search results. """ try: - return self._redis_client.ft(self.schema.index.name).search( # type: ignore - *args, **kwargs - ) - except Exception as e: + if isinstance(self._redis_client, RedisCluster): + # Use special cluster search for RedisCluster + return cluster_search( + self._redis_client.ft(self.schema.index.name), + *args, + **kwargs, # type: ignore + ) + else: + return self._redis_client.ft(self.schema.index.name).search( + *args, **kwargs + ) # type: ignore + except redis.exceptions.RedisError as e: + if "CROSSSLOT" in str(e): + raise RedisSearchError( + "Cross-slot error during search. Ensure consistent hash tags in your keys." + ) raise RedisSearchError(f"Error while searching: {str(e)}") from e + except Exception as e: + raise RedisSearchError(f"Unexpected error while searching: {str(e)}") from e def batch_query( self, queries: Sequence[BaseQuery], batch_size: int = 10 @@ -943,7 +1053,7 @@ def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - return convert_bytes(self._redis_client.execute_command("FT._LIST")) # type: ignore + return convert_bytes(self._redis_client.execute_command("FT._LIST")) def exists(self) -> bool: """Check if the index exists in Redis. 
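A note on why the cluster-specific branches above insist on hash tags: Redis Cluster only hashes the substring inside `{...}` when assigning a key to a slot, so keys that share that tag are guaranteed to colocate and multi-key commands against them avoid CROSSSLOT errors. A minimal sketch of how a caller might shape document IDs for `drop_documents`; the `docs` prefix, the `{user:42}` tag, and the `index` variable are illustrative assumptions, not part of this diff.

```python
# Assumes an existing SearchIndex `index` whose schema prefix is "docs".
# The {user:42} hash tag is the only part of each key that Redis Cluster
# hashes, so every key below maps to the same slot.
doc_ids = [f"{{user:42}}:doc:{i}" for i in range(3)]

keys = [index.key(doc_id) for doc_id in doc_ids]
# -> ["docs:{user:42}:doc:0", "docs:{user:42}:doc:1", "docs:{user:42}:doc:2"]

# Because all keys share one slot, the cluster check in drop_documents passes
# and the underlying multi-key DEL is legal on a cluster.
deleted = index.drop_documents(doc_ids)
```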
@@ -954,7 +1064,7 @@ def exists(self) -> bool: return self.schema.index.name in self.listall() @staticmethod - def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]: + def _info(name: str, redis_client: SyncRedisClient) -> Dict[str, Any]: """Run FT.INFO to fetch information about the index.""" try: return convert_bytes(redis_client.ft(name).info()) # type: ignore @@ -974,7 +1084,7 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]: dict: A dictionary containing the information about the index. """ index_name = name or self.schema.index.name - return self._info(index_name, self._redis_client) # type: ignore + return self._info(index_name, self._redis_client) def __enter__(self): return self @@ -1019,7 +1129,7 @@ def __init__( schema: IndexSchema, *, redis_url: Optional[str] = None, - redis_client: Optional[aredis.Redis] = None, + redis_client: Optional[AsyncRedisClient] = None, connection_kwargs: Optional[Dict[str, Any]] = None, validate_on_load: bool = False, **kwargs, @@ -1030,7 +1140,7 @@ def __init__( schema (IndexSchema): Index schema object. redis_url (Optional[str], optional): The URL of the Redis server to connect to. - redis_client (Optional[aredis.Redis]): An + redis_client (Optional[AsyncRedis]): An instantiated redis client. connection_kwargs (Optional[Dict[str, Any]]): Redis client connection args. @@ -1063,7 +1173,7 @@ def __init__( async def from_existing( cls, name: str, - redis_client: Optional[aredis.Redis] = None, + redis_client: Optional[AsyncRedisClient] = None, redis_url: Optional[str] = None, **kwargs, ): @@ -1072,7 +1182,7 @@ async def from_existing( Args: name (str): Name of the search index in Redis. - redis_client(Optional[redis.Redis]): An + redis_client(Optional[Redis]): An instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. @@ -1111,7 +1221,7 @@ async def from_existing( return cls(schema, redis_client=redis_client, **kwargs) @property - def client(self) -> Optional[aredis.Redis]: + def client(self) -> Optional[AsyncRedisClient]: """The underlying redis-py client object.""" return self._redis_client @@ -1128,7 +1238,7 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): await self.set_client(client) @deprecated_function("set_client", "Pass connection parameters in __init__.") - async def set_client(self, redis_client: Union[aredis.Redis, redis.Redis]): + async def set_client(self, redis_client: Union[AsyncRedisClient, SyncRedisClient]): """ [DEPRECATED] Manually set the Redis client to use with the search index. This method is deprecated; please provide connection parameters in __init__. 
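With `connect()` and `set_client()` deprecated in favor of constructor arguments, basic async usage looks roughly like the sketch below. It assumes a local Redis Stack at `redis://localhost:6379`, a `schema.yaml` on disk, and an existing index named `my-index`; none of these are part of the diff.

```python
import asyncio

from redisvl.index import AsyncSearchIndex
from redisvl.schema import IndexSchema


async def main():
    # Preferred path: pass connection parameters in __init__ rather than
    # calling the deprecated connect()/set_client() methods.
    schema = IndexSchema.from_yaml("schema.yaml")
    index = AsyncSearchIndex(schema, redis_url="redis://localhost:6379")
    await index.create(overwrite=True)

    # Attach to an index that already exists on the server.
    existing = await AsyncSearchIndex.from_existing(
        "my-index", redis_url="redis://localhost:6379"
    )
    print(await existing.info())


asyncio.run(main())
```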
@@ -1139,7 +1249,7 @@ async def set_client(self, redis_client: Union[aredis.Redis, redis.Redis]): self._redis_client = redis_client return self - async def _get_client(self) -> aredis.Redis: + async def _get_client(self) -> AsyncRedisClient: """Lazily instantiate and return the async Redis client.""" if self._redis_client is None: async with self._lock: @@ -1160,20 +1270,38 @@ async def _get_client(self) -> aredis.Redis: return self._redis_client async def _validate_client( - self, redis_client: Union[aredis.Redis, redis.Redis] - ) -> aredis.Redis: - if isinstance(redis_client, redis.Redis): + self, redis_client: Union[AsyncRedisClient, SyncRedisClient] + ) -> AsyncRedisClient: + # Handle deprecated sync client conversion + if isinstance(redis_client, (Redis, RedisCluster)): warnings.warn( - "Converting sync Redis client to async client is deprecated " + "Passing a sync Redis client to AsyncSearchIndex is deprecated " "and will be removed in the next major version. Please use an " "async Redis client instead.", DeprecationWarning, ) - redis_client = RedisConnectionFactory.sync_to_async_redis(redis_client) - elif not isinstance(redis_client, aredis.Redis): - raise ValueError("Invalid client type: must be redis.asyncio.Redis") + # Use a new variable name + async_redis_client: AsyncRedisClient = ( + RedisConnectionFactory.sync_to_async_redis(redis_client) + ) + return async_redis_client # Return the converted client + # Check if it's a valid async client (standard or cluster) + elif not isinstance(redis_client, (AsyncRedis, AsyncRedisCluster)): + raise ValueError( + "Invalid async client type: must be AsyncRedis or AsyncRedisCluster" + ) + # If it passed the elif, it's already an AsyncRedisClient return redis_client + @staticmethod + async def _info(name: str, redis_client: AsyncRedisClient) -> Dict[str, Any]: + try: + return convert_bytes(await redis_client.ft(name).info()) + except Exception as e: + raise RedisSearchError( + f"Error while fetching {name} index info: {str(e)}" + ) from e + async def create(self, overwrite: bool = False, drop: bool = False) -> None: """Asynchronously create an index in Redis with the current schema and properties. @@ -1215,15 +1343,30 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: await self.delete(drop) try: - await client.ft(self.schema.index.name).create_index( - fields=redis_fields, - definition=IndexDefinition( - prefix=[self.schema.index.prefix], index_type=self._storage.type - ), + definition = IndexDefinition( + prefix=[self.schema.index.prefix], index_type=self._storage.type ) - except: + if isinstance(client, AsyncRedisCluster): + await async_cluster_create_index( + index_name=self.schema.index.name, + client=client, + fields=redis_fields, + definition=definition, + ) + else: + await client.ft(self.schema.index.name).create_index( + fields=redis_fields, + definition=definition, + ) + except redis.exceptions.RedisError as e: + raise RedisSearchError( + f"Failed to create index '{self.name}' on Redis: {str(e)}" + ) from e + except Exception as e: logger.exception("Error while trying to create the index") - raise + raise RedisSearchError( + f"Unexpected error creating index '{self.name}': {str(e)}" + ) from e async def delete(self, drop: bool = True): """Delete the search index. 
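Both `index.py` and `storage.py` now gate the `IndexDefinition`/`IndexType` import on the installed redis-py major version. The same compatibility can be expressed as a try/except fallback that avoids parsing the version string; a sketch of that alternative, not what the diff itself does:

```python
# redis-py 6.x renamed the module to snake_case; try the new location first
# and fall back for redis-py 5.x. Equivalent in effect to the
# redis_version.startswith("5") check used in the diff.
try:
    from redis.commands.search.index_definition import IndexDefinition, IndexType
except ImportError:  # redis-py 5.x
    from redis.commands.search.indexDefinition import IndexDefinition, IndexType
```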
@@ -1237,7 +1380,24 @@ async def delete(self, drop: bool = True): """ client = await self._get_client() try: - await client.ft(self.schema.index.name).dropindex(delete_documents=drop) + # For Redis Cluster with drop=True, we need to handle key deletion manually + # to avoid cross-slot errors since we control the keys opaquely + if drop and isinstance(client, AsyncRedisCluster): + # First clear all keys manually (handles cluster compatibility) + await self.clear() + # Then drop the index without the DD flag + cmd_args = ["FT.DROPINDEX", self.schema.index.name] + else: + # Standard approach for non-cluster or when not dropping keys + cmd_args = ["FT.DROPINDEX", self.schema.index.name] + if drop: + cmd_args.append("DD") + + if isinstance(client, AsyncRedisCluster): + target_nodes = [client.get_default_node()] + await client.execute_command(*cmd_args, target_nodes=target_nodes) + else: + await client.execute_command(*cmd_args) except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e @@ -1245,6 +1405,10 @@ async def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index available and in-place for future insertions or updates. + NOTE: This method requires custom behavior for Redis Cluster because here, + we can't easily give control of the keys we're clearing to the user so they + can separate them based on hash tag. + Returns: int: Count of records deleted from Redis. """ @@ -1255,8 +1419,21 @@ async def clear(self) -> int: FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 ): batch_keys = [record["id"] for record in batch] - records_deleted = await client.delete(*batch_keys) - total_records_deleted += records_deleted + if batch_keys: + is_cluster = isinstance(client, AsyncRedisCluster) + if is_cluster: + records_deleted_in_batch = 0 + for key_to_delete in batch_keys: + try: + records_deleted_in_batch += cast( + int, await client.delete(key_to_delete) + ) + except redis.exceptions.RedisError as e: + logger.warning(f"Failed to delete key {key_to_delete}: {e}") + total_records_deleted += records_deleted_in_batch + else: + records_deleted = await client.delete(*batch_keys) + total_records_deleted += records_deleted return total_records_deleted @@ -1281,6 +1458,10 @@ async def drop_documents(self, ids: Union[str, List[str]]) -> int: This method converts document IDs to Redis keys automatically by applying the index's key prefix and separator configuration. + NOTE: Cluster users will need to incorporate hash tags into their + document IDs and only call this method with documents from a single hash + tag at a time. + Args: ids (Union[str, List[str]]): The document ID or IDs to remove from the index. @@ -1292,6 +1473,11 @@ async def drop_documents(self, ids: Union[str, List[str]]) -> int: if not ids: return 0 keys = [self.key(id) for id in ids] + # Check for cluster compatibility + if isinstance(client, AsyncRedisCluster) and not _keys_share_hash_tag(keys): + raise ValueError( + "All keys must share a hash tag when using Redis Cluster." 
+ ) return await client.delete(*keys) else: key = self.key(ids) @@ -1440,25 +1626,46 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": """ client = await self._get_client() try: - return client.ft(self.schema.index.name).aggregate(*args, **kwargs) - except Exception as e: + return await client.ft(self.schema.index.name).aggregate(*args, **kwargs) + except redis.exceptions.RedisError as e: + if "CROSSSLOT" in str(e): + raise RedisSearchError( + "Cross-slot error during aggregation. Ensure consistent hash tags in your keys." + ) raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + except Exception as e: + raise RedisSearchError( + f"Unexpected error while aggregating: {str(e)}" + ) from e async def batch_search( self, queries: List[SearchParams], batch_size: int = 10 ) -> List["Result"]: - """Perform a search against the index for multiple queries. + """Asynchronously execute a batch of search queries. + + This method takes a list of search queries and executes them in batches + to improve performance when dealing with multiple queries. - This method takes a list of queries and returns a list of Result objects - for each query. Results are returned in the same order as the queries. + NOTE: Cluster users may need to incorporate hash tags into their query + to avoid cross-slot operations. Args: - queries (List[SearchParams]): The queries to search for. batch_size - (int, optional): The number of queries to search for at a time. - Defaults to 10. + queries (List[SearchParams]): A list of search queries to execute. + Each query can be either a string or a tuple of (query, params). + batch_size (int, optional): The number of queries to execute in each + batch. Defaults to 10. Returns: - List[Result]: The search results for each query. + List[Result]: A list of search results corresponding to each query. + + .. code-block:: python + + queries = [ + "hello world", + ("goodbye world", {"num_results": 5}), + ] + + results = await index.batch_search(queries) """ all_results = [] client = await self._get_client() @@ -1470,7 +1677,7 @@ async def batch_search( for i in range(0, len(queries), batch_size): batch_queries = queries[i : i + batch_size] - # redis-py doesn't support calling `search` in a pipeline, + # redis-py doesn't support calling `search` in an async pipeline, # so we need to manually execute each command in a pipeline # and parse the results async with client.pipeline(transaction=False) as pipe: @@ -1498,8 +1705,8 @@ async def batch_search( # for all queries in the batch as the duration for each query duration = (time.time() - st) * 1000.0 - for i, query_results in enumerate(results): - _built_query = batch_built_queries[i] + for j, query_results in enumerate(results): + _built_query = batch_built_queries[j] parsed_result = search._parse_search( # type: ignore query_results, query=_built_query, @@ -1510,20 +1717,34 @@ async def batch_search( return all_results async def search(self, *args, **kwargs) -> "Result": - """Perform a search on this index. + """Perform an async search against the index. - Wrapper around redis.search.Search that adds the index name - to the search query and passes along the rest of the arguments - to the redis-py ft.search() method. + Wrapper around the search API that adds the index name + to the query and passes along the rest of the arguments + to the redis-py ft().search() method. Returns: Result: Raw Redis search results. 
""" - client = await self._get_client() try: - return await client.ft(self.schema.index.name).search(*args, **kwargs) # type: ignore - except Exception as e: + client = await self._get_client() + if isinstance(client, AsyncRedisCluster): + # Use special cluster search for AsyncRedisCluster + return await async_cluster_search( + client.ft(self.schema.index.name), + *args, + **kwargs, # type: ignore + ) + else: + return await client.ft(self.schema.index.name).search(*args, **kwargs) # type: ignore + except redis.exceptions.RedisError as e: + if "CROSSSLOT" in str(e): + raise RedisSearchError( + "Cross-slot error during search. Ensure consistent hash tags in your keys." + ) raise RedisSearchError(f"Error while searching: {str(e)}") from e + except Exception as e: + raise RedisSearchError(f"Unexpected error while searching: {str(e)}") from e async def batch_query( self, queries: List[BaseQuery], batch_size: int = 10 @@ -1641,7 +1862,11 @@ async def listall(self) -> List[str]: List[str]: The list of indices in the database. """ client = await self._get_client() - return convert_bytes(await client.execute_command("FT._LIST")) + if isinstance(client, AsyncRedisCluster): + target_nodes = client.get_random_node() + return convert_bytes(await target_nodes.execute_command("FT._LIST")) + else: + return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: """Check if the index exists in Redis. @@ -1663,22 +1888,13 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: """ client = await self._get_client() index_name = name or self.schema.index.name - return await type(self)._info(index_name, client) - - @staticmethod - async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: - try: - return convert_bytes(await redis_client.ft(name).info()) # type: ignore - except Exception as e: - raise RedisSearchError( - f"Error while fetching {name} index info: {str(e)}" - ) from e + return await self._info(index_name, client) async def disconnect(self): if self._owns_redis_client is False: return if self._redis_client is not None: - await self._redis_client.aclose() # type: ignore + await self._redis_client.aclose() self._redis_client = None def disconnect_sync(self): diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 792b6bc4..f372bd56 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -1,14 +1,46 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from collections.abc import Collection +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) from pydantic import BaseModel, ValidationError -from redis import Redis -from redis.asyncio import Redis as AsyncRedis -from redis.commands.search.indexDefinition import IndexType +from redis import __version__ as redis_version + +# Add imports for Pipeline types +from redis.asyncio.client import Pipeline as AsyncPipeline +from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline + +# Redis 5.x compatibility (6 fixed the import path) +if redis_version.startswith("5"): + from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped] + IndexType, + ) +else: + from redis.commands.search.index_definition import ( # type: ignore[no-redef] + IndexType, + ) from redisvl.exceptions import SchemaValidationError from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema from redisvl.schema.validation import validate_object 
+from redisvl.types import ( + AsyncRedisClient, + AsyncRedisClientOrPipeline, + AsyncRedisPipeline, + RedisClientOrPipeline, + SyncRedisClient, + SyncRedisPipeline, +) from redisvl.utils.log import get_logger from redisvl.utils.utils import create_ulid @@ -96,51 +128,37 @@ def _preprocess(obj: Any, preprocess: Optional[Callable] = None) -> Dict[str, An return obj @staticmethod - def _set(client: Redis, key: str, obj: Dict[str, Any]): + def _set( + client: RedisClientOrPipeline, key: str, obj: Dict[str, Any] + ) -> Union[SyncRedisPipeline, Dict[str, Any]]: """Synchronously set the value in Redis for the given key. Args: - client (Redis): The Redis client instance. + client (RedisClientOrPipeline): The Redis client instance. key (str): The key under which to store the object. obj (Dict[str, Any]): The object to store in Redis. """ raise NotImplementedError @staticmethod - async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]): - """Asynchronously set the value in Redis for the given key. - - Args: - client (AsyncRedis): The Redis client instance. - key (str): The key under which to store the object. - obj (Dict[str, Any]): The object to store in Redis. - """ + async def _aset( + client: AsyncRedisClientOrPipeline, key: str, obj: Dict[str, Any] + ) -> Union[AsyncRedisPipeline, Dict[str, Any]]: + """Asynchronously set data in Redis using the provided client or pipeline.""" raise NotImplementedError @staticmethod - def _get(client: Redis, key: str) -> Dict[str, Any]: - """Synchronously get the value from Redis for the given key. - - Args: - client (Redis): The Redis client instance. - key (str): The key for which to retrieve the object. - - Returns: - Dict[str, Any]: The retrieved object from Redis. - """ + def _get( + client: RedisClientOrPipeline, key: str + ) -> Union[SyncRedisPipeline, Dict[str, Any]]: + """Synchronously get data from Redis using the provided client or pipeline.""" raise NotImplementedError @staticmethod - async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]: - """Asynchronously get the value from Redis for the given key. - - Args: - client (AsyncRedis): The Redis client instance. - key (str): The key for which to retrieve the object. - - Returns: - Dict[str, Any]: The retrieved object from Redis. - """ + async def _aget( + client: AsyncRedisClientOrPipeline, key: str + ) -> Union[AsyncRedisPipeline, Dict[str, Any]]: + """Asynchronously get data from Redis using the provided client or pipeline.""" raise NotImplementedError def _validate(self, obj: Dict[str, Any]) -> Dict[str, Any]: @@ -159,6 +177,29 @@ def _validate(self, obj: Dict[str, Any]) -> Dict[str, Any]: # Pass directly to validation function and let any errors propagate return validate_object(self.index_schema, obj) + def _get_keys( + self, + objects: List[Any], + keys: Optional[Iterable[str]] = None, + id_field: Optional[str] = None, + ) -> List[str]: + """Generate Redis keys for a list of objects.""" + generated_keys: List[str] = [] + keys_iterator = iter(keys) if keys else None + + if keys and len(list(keys)) != len(objects): + raise ValueError( + "Length of provided keys does not match the length of objects." 
+ ) + + for obj in objects: + if keys_iterator: + key = next(keys_iterator) + else: + key = self._create_key(obj, id_field) + generated_keys.append(key) + return generated_keys + def _preprocess_and_validate_objects( self, objects: Iterable[Any], @@ -220,7 +261,7 @@ def _preprocess_and_validate_objects( def write( self, - redis_client: Redis, + redis_client: SyncRedisClient, objects: Iterable[Any], id_field: Optional[str] = None, keys: Optional[Iterable[str]] = None, @@ -233,7 +274,7 @@ def write( returns a list of Redis keys written to the database. Args: - redis_client (Redis): A Redis client used for writing data. + redis_client (RedisClient): A Redis client used for writing data. objects (Iterable[Any]): An iterable of objects to store. id_field (Optional[str], optional): Field used as the key for each object. Defaults to None. @@ -295,7 +336,7 @@ def write( async def awrite( self, - redis_client: AsyncRedis, + redis_client: AsyncRedisClient, objects: Iterable[Any], id_field: Optional[str] = None, keys: Optional[Iterable[str]] = None, @@ -308,7 +349,7 @@ async def awrite( The method returns a list of keys written to the database. Args: - redis_client (AsyncRedis): An asynchronous Redis client used + redis_client (AsyncRedisClient): An asynchronous Redis client used for writing data. objects (Iterable[Any]): An iterable of objects to store. id_field (Optional[str], optional): Field used as the key for each @@ -373,13 +414,16 @@ async def awrite( return added_keys def get( - self, redis_client: Redis, keys: Iterable[str], batch_size: Optional[int] = None + self, + redis_client: SyncRedisClient, + keys: Collection[str], + batch_size: Optional[int] = None, ) -> List[Dict[str, Any]]: """Retrieve objects from Redis by keys. Args: - redis_client (Redis): Synchronous Redis client. - keys (Iterable[str]): Keys to retrieve from Redis. + redis_client (SyncRedisClient): Synchronous Redis client. + keys (Collection[str]): Keys to retrieve from Redis. batch_size (Optional[int], optional): Number of objects to write in a single Redis pipeline execution. Defaults to class's default batch size. @@ -389,10 +433,10 @@ def get( """ results: List = [] - if not isinstance(keys, Iterable): # type: ignore - raise TypeError("Keys must be an iterable of strings") + if not isinstance(keys, Collection): + raise TypeError("Keys must be a collection of strings") - if len(keys) == 0: # type: ignore + if len(keys) == 0: return [] if batch_size is None: @@ -412,15 +456,15 @@ def get( async def aget( self, - redis_client: AsyncRedis, - keys: Iterable[str], + redis_client: AsyncRedisClient, + keys: Collection[str], batch_size: Optional[int] = None, ) -> List[Dict[str, Any]]: """Asynchronously retrieve objects from Redis by keys. Args: - redis_client (AsyncRedis): Asynchronous Redis client. - keys (Iterable[str]): Keys to retrieve from Redis. + redis_client (AsyncRedisClient): Asynchronous Redis client. + keys (Collection[str]): Keys to retrieve from Redis. batch_size (Optional[int], optional): Number of objects to write in a single Redis pipeline execution. Defaults to class's default batch size. 
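One Redis Cluster detail behind the key-based reads and writes above: only the substring inside the first {...} of a key is hashed when the cluster picks a slot, which is what the CROSSSLOT guidance earlier in this diff (and the _keys_share_hash_tag helper added later in redisvl/redis/utils.py) relies on. A small illustrative sketch, with hypothetical key names and the same local cluster address as the test fixtures in this diff:

from redis.cluster import RedisCluster

rc = RedisCluster.from_url("redis://localhost:7001")
# Keys that share a hash tag map to the same slot, so multi-key operations on
# them cannot fail with CROSSSLOT; only the text inside the first {...} is hashed.
assert rc.keyslot("docs:{tenant-42}:0001") == rc.keyslot("docs:{tenant-42}:0002")
# Without a shared tag the whole key name is hashed and the slots will usually differ.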
@@ -431,10 +475,10 @@ async def aget( """ results: List = [] - if not isinstance(keys, Iterable): # type: ignore - raise TypeError("Keys must be an iterable of strings") + if not isinstance(keys, Collection): + raise TypeError("Keys must be a collection of strings") - if len(keys) == 0: # type: ignore + if len(keys) == 0: return [] if batch_size is None: @@ -465,52 +509,60 @@ class HashStorage(BaseStorage): """Hash data type for the index""" @staticmethod - def _set(client: Redis, key: str, obj: Dict[str, Any]): + def _set(client: RedisClientOrPipeline, key: str, obj: Dict[str, Any]): """Synchronously set a hash value in Redis for the given key. Args: - client (Redis): The Redis client instance. + client (SyncRedisClient): The Redis client instance. key (str): The key under which to store the hash. obj (Dict[str, Any]): The hash to store in Redis. """ - client.hset(name=key, mapping=obj) # type: ignore + client.hset(name=key, mapping=obj) @staticmethod - async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]): + async def _aset(client: AsyncRedisClientOrPipeline, key: str, obj: Dict[str, Any]): """Asynchronously set a hash value in Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. + client (AsyncClientOrPipeline): The async Redis client or pipeline instance. key (str): The key under which to store the hash. obj (Dict[str, Any]): The hash to store in Redis. """ - await client.hset(name=key, mapping=obj) # type: ignore + if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)): + client.hset(name=key, mapping=obj) # type: ignore + else: + await client.hset(name=key, mapping=obj) # type: ignore @staticmethod - def _get(client: Redis, key: str) -> Dict[str, Any]: + def _get(client: SyncRedisClient, key: str) -> Dict[str, Any]: """Synchronously retrieve a hash value from Redis for the given key. Args: - client (Redis): The Redis client instance. + client (SyncRedisClient): The Redis client instance. key (str): The key for which to retrieve the hash. Returns: Dict[str, Any]: The retrieved hash from Redis. """ - return client.hgetall(key) + return client.hgetall(key) # type: ignore @staticmethod - async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]: + async def _aget( + client: AsyncRedisClientOrPipeline, key: str + ) -> Union[AsyncRedisPipeline, Dict[str, Any]]: """Asynchronously retrieve a hash value from Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. + client (AsyncRedisClient): The async Redis client or pipeline instance. key (str): The key for which to retrieve the hash. Returns: Dict[str, Any]: The retrieved hash from Redis. """ - return await client.hgetall(key) + if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)): + return client.hgetall(key) # type: ignore[return-value] + else: + return await client.hgetall(key) # type: ignore[return-value, misc] class JsonStorage(BaseStorage): @@ -525,49 +577,55 @@ class JsonStorage(BaseStorage): """JSON data type for the index""" @staticmethod - def _set(client: Redis, key: str, obj: Dict[str, Any]): + def _set(client: RedisClientOrPipeline, key: str, obj: Dict[str, Any]): """Synchronously set a JSON obj in Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. + client (SyncRedisClient): The Redis client instance. key (str): The key under which to store the JSON obj. obj (Dict[str, Any]): The JSON obj to store in Redis. 
""" client.json().set(key, "$", obj) @staticmethod - async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]): + async def _aset(client: AsyncRedisClientOrPipeline, key: str, obj: Dict[str, Any]): """Asynchronously set a JSON obj in Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. + client (AsyncClientOrPipeline): The async Redis client or pipeline instance. key (str): The key under which to store the JSON obj. obj (Dict[str, Any]): The JSON obj to store in Redis. """ - await client.json().set(key, "$", obj) + if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)): + client.json().set(key, "$", obj) # type: ignore[return-value, misc] + else: + await client.json().set(key, "$", obj) # type: ignore[return-value, misc] @staticmethod - def _get(client: Redis, key: str) -> Dict[str, Any]: + def _get(client: RedisClientOrPipeline, key: str) -> Dict[str, Any]: """Synchronously retrieve a JSON obj from Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. + client (SyncRedisClient): The Redis client instance. key (str): The key for which to retrieve the JSON obj. Returns: Dict[str, Any]: The retrieved JSON obj from Redis. """ - return client.json().get(key) + return client.json().get(key) # type: ignore[return-value, misc] @staticmethod - async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]: - """Asynchronously retrieve a JSON obj from Redis for the given key. + async def _aget(client: AsyncRedisClientOrPipeline, key: str) -> Dict[str, Any]: + """Asynchronously retrieve a JSON object from Redis for the given key. Args: - client (AsyncRedis): The Redis client instance. - key (str): The key for which to retrieve the JSON obj. + client (AsyncRedisClient): The async Redis client or pipeline instance. + key (str): The key for which to retrieve the JSON object. Returns: - Dict[str, Any]: The retrieved JSON obj from Redis. + Dict[str, Any]: The retrieved JSON object from Redis. """ - return await client.json().get(key) + if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)): + return client.json().get(key) # type: ignore[return-value, misc] + else: + return await client.json().get(key) # type: ignore[return-value, misc] diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index 8b4f9003..ee5b0279 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -117,16 +117,16 @@ def __init__( query_string = self._build_query_string() super().__init__(query_string) - self.scorer(text_scorer) # type: ignore[attr-defined] - self.add_scores() # type: ignore[attr-defined] + self.scorer(text_scorer) + self.add_scores() self.apply( vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score" ) self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity") - self.sort_by(Desc("@hybrid_score"), max=num_results) - self.dialect(dialect) # type: ignore[attr-defined] + self.sort_by(Desc("@hybrid_score"), max=num_results) # type: ignore + self.dialect(dialect) if return_fields: - self.load(*return_fields) + self.load(*return_fields) # type: ignore[arg-type] @property def params(self) -> Dict[str, Any]: @@ -135,10 +135,10 @@ def params(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters for the aggregation. 
""" - if isinstance(self._vector, bytes): - vector = self._vector - else: + if isinstance(self._vector, list): vector = array_to_buffer(self._vector, dtype=self._dtype) + else: + vector = self._vector params = {self.VECTOR_PARAM: vector} diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 9690fb56..7bdbcdfa 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,18 +1,21 @@ import os -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from warnings import warn -from redis import Redis -from redis.asyncio import Connection as AsyncConnection +from redis import Redis, RedisCluster from redis.asyncio import ConnectionPool as AsyncConnectionPool from redis.asyncio import Redis as AsyncRedis -from redis.asyncio import SSLConnection as AsyncSSLConnection -from redis.connection import AbstractConnection, SSLConnection +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.asyncio.connection import AbstractConnection as AsyncAbstractConnection +from redis.asyncio.connection import Connection as AsyncConnection +from redis.asyncio.connection import SSLConnection as AsyncSSLConnection +from redis.connection import SSLConnection from redis.exceptions import ResponseError from redisvl.exceptions import RedisModuleVersionError -from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES -from redisvl.redis.utils import convert_bytes +from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES, REDIS_URL_ENV_VAR +from redisvl.redis.utils import convert_bytes, is_cluster_url +from redisvl.types import AsyncRedisClient, RedisClient, SyncRedisClient from redisvl.utils.utils import deprecated_function from redisvl.version import __version__ @@ -52,14 +55,11 @@ def unpack_redis_modules(module_list: List[Dict[str, Any]]) -> Dict[str, Any]: def get_address_from_env() -> str: - """Get a redis connection from environment variables. - - Returns: - str: Redis URL - """ - if "REDIS_URL" not in os.environ: - raise ValueError("REDIS_URL env var not set") - return os.environ["REDIS_URL"] + """Get Redis URL from environment variable.""" + redis_url = os.getenv(REDIS_URL_ENV_VAR) + if not redis_url: + raise ValueError(f"{REDIS_URL_ENV_VAR} environment variable not set.") + return redis_url def make_lib_name(*args) -> str: @@ -196,7 +196,7 @@ class RedisConnectionFactory: ) def connect( cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs - ) -> Union[Redis, AsyncRedis]: + ) -> RedisClient: """Create a connection to the Redis database based on a URL and some connection kwargs. @@ -226,7 +226,7 @@ def get_redis_connection( redis_url: Optional[str] = None, required_modules: Optional[List[Dict[str, Any]]] = None, **kwargs, - ) -> Redis: + ) -> SyncRedisClient: """Creates and returns a synchronous Redis client. Args: @@ -246,12 +246,13 @@ def get_redis_connection( RedisModuleVersionError: If required Redis modules are not installed. 
""" url = redis_url or get_address_from_env() - client = Redis.from_url(url, **kwargs) - + if is_cluster_url(url, **kwargs): + client = RedisCluster.from_url(url, **kwargs) + else: + client = Redis.from_url(url, **kwargs) RedisConnectionFactory.validate_sync_redis( client, required_modules=required_modules ) - return client @staticmethod @@ -259,7 +260,7 @@ async def _get_aredis_connection( url: Optional[str] = None, required_modules: Optional[List[Dict[str, Any]]] = None, **kwargs, - ) -> AsyncRedis: + ) -> AsyncRedisClient: """Creates and returns an asynchronous Redis client. NOTE: This method is the future form of `get_async_redis_connection` but is @@ -274,7 +275,7 @@ async def _get_aredis_connection( Redis client constructor. Returns: - AsyncRedis: An asynchronous Redis client instance. + AsyncRedisClient: An asynchronous Redis client instance (either AsyncRedis or AsyncRedisCluster). Raises: ValueError: If url is not provided and REDIS_URL environment @@ -282,7 +283,11 @@ async def _get_aredis_connection( RedisModuleVersionError: If required Redis modules are not installed. """ url = url or get_address_from_env() - client = AsyncRedis.from_url(url, **kwargs) + + if is_cluster_url(url, **kwargs): + client = AsyncRedisCluster.from_url(url, **kwargs) + else: + client = AsyncRedis.from_url(url, **kwargs) await RedisConnectionFactory.validate_async_redis( client, required_modules=required_modules @@ -293,7 +298,7 @@ async def _get_aredis_connection( def get_async_redis_connection( url: Optional[str] = None, **kwargs, - ) -> AsyncRedis: + ) -> AsyncRedisClient: """Creates and returns an asynchronous Redis client. Args: @@ -317,16 +322,41 @@ def get_async_redis_connection( return AsyncRedis.from_url(url, **kwargs) @staticmethod - def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: + def get_redis_cluster_connection( + redis_url: Optional[str] = None, + **kwargs, + ) -> RedisCluster: + """Creates and returns a synchronous Redis client for a Redis cluster.""" + url = redis_url or get_address_from_env() + return RedisCluster.from_url(url, **kwargs) + + @staticmethod + def get_async_redis_cluster_connection( + redis_url: Optional[str] = None, + **kwargs, + ) -> AsyncRedisCluster: + """Creates and returns an asynchronous Redis client for a Redis cluster.""" + url = redis_url or get_address_from_env() + return AsyncRedisCluster.from_url(url, **kwargs) + + @staticmethod + def sync_to_async_redis( + redis_client: SyncRedisClient, + ) -> AsyncRedisClient: """Convert a synchronous Redis client to an asynchronous one.""" + if isinstance(redis_client, RedisCluster): + raise ValueError( + "RedisCluster is not supported for sync-to-async conversion." 
+ ) + # pick the right connection class - connection_class: Type[AbstractConnection] = ( + connection_class: Type[AsyncAbstractConnection] = ( AsyncSSLConnection if redis_client.connection_pool.connection_class == SSLConnection else AsyncConnection - ) # type: ignore + ) # make async client - return AsyncRedis.from_pool( # type: ignore + return AsyncRedis.from_pool( AsyncConnectionPool( connection_class=connection_class, **redis_client.connection_pool.connection_kwargs, @@ -334,27 +364,29 @@ def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: ) @staticmethod - def get_modules(client: Redis) -> Dict[str, Any]: + def get_modules(client: SyncRedisClient) -> Dict[str, Any]: return unpack_redis_modules(convert_bytes(client.module_list())) @staticmethod - async def get_modules_async(client: AsyncRedis) -> Dict[str, Any]: + async def get_modules_async(client: AsyncRedisClient) -> Dict[str, Any]: return unpack_redis_modules(convert_bytes(await client.module_list())) @staticmethod def validate_sync_redis( - redis_client: Redis, + redis_client: SyncRedisClient, lib_name: Optional[str] = None, required_modules: Optional[List[Dict[str, Any]]] = None, ) -> None: """Validates the sync Redis client.""" - if not isinstance(redis_client, Redis): - raise TypeError("Invalid Redis client instance") + if not issubclass(type(redis_client), (Redis, RedisCluster)): + raise TypeError( + "Invalid Redis client instance. Must be Redis or RedisCluster." + ) # Set client library name _lib_name = make_lib_name(lib_name) try: - redis_client.client_setinfo("LIB-NAME", _lib_name) # type: ignore + redis_client.client_setinfo("LIB-NAME", _lib_name) except ResponseError: # Fall back to a simple log echo redis_client.echo(_lib_name) @@ -367,15 +399,19 @@ def validate_sync_redis( @staticmethod async def validate_async_redis( - redis_client: AsyncRedis, + redis_client: AsyncRedisClient, lib_name: Optional[str] = None, required_modules: Optional[List[Dict[str, Any]]] = None, ) -> None: """Validates the async Redis client.""" + if not issubclass(type(redis_client), (AsyncRedis, AsyncRedisCluster)): + raise TypeError( + "Invalid async Redis client instance. Must be async Redis or async RedisCluster." 
+ ) # Set client library name _lib_name = make_lib_name(lib_name) try: - await redis_client.client_setinfo("LIB-NAME", _lib_name) # type: ignore + await redis_client.client_setinfo("LIB-NAME", _lib_name) except ResponseError: # Fall back to a simple log echo await redis_client.echo(_lib_name) diff --git a/redisvl/redis/constants.py b/redisvl/redis/constants.py index aeae2541..434a2ff5 100644 --- a/redisvl/redis/constants.py +++ b/redisvl/redis/constants.py @@ -6,3 +6,6 @@ # default tag separator REDIS_TAG_SEPARATOR = "," + + +REDIS_URL_ENV_VAR = "REDIS_URL" diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 0778c592..7de3b036 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -1,5 +1,40 @@ import hashlib -from typing import Any, Dict, List, Optional +import itertools +import time +from typing import Any, Dict, List, Optional, Union + +from redis import RedisCluster +from redis import __version__ as redis_version +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.client import NEVER_DECODE, Pipeline +from redis.commands.helpers import get_protocol_version +from redis.commands.search import AsyncSearch, Search +from redis.commands.search.commands import ( + CREATE_CMD, + MAXTEXTFIELDS, + NOFIELDS, + NOFREQS, + NOHL, + NOOFFSETS, + SEARCH_CMD, + SKIPINITIALSCAN, + STOPWORDS, + TEMPORARY, +) +from redis.commands.search.field import Field + +# Redis 5.x compatibility (6 fixed the import path) +if redis_version.startswith("5"): + from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped] + IndexDefinition, + ) +else: + from redis.commands.search.index_definition import ( # type: ignore[no-redef] + IndexDefinition, + ) + +from redis.commands.search.query import Query +from redis.commands.search.result import Result from redisvl.utils.utils import lazy_import @@ -78,3 +113,257 @@ def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str: extra_string = " ".join([str(k) + str(v) for k, v in sorted(extras.items())]) content = content + extra_string return hashlib.sha256(content.encode("utf-8")).hexdigest() + + +def cluster_create_index( + index_name: str, + client: RedisCluster, + fields: List[Field], + no_term_offsets: bool = False, + no_field_flags: bool = False, + stopwords: Optional[List[str]] = None, + definition: Optional[IndexDefinition] = None, + max_text_fields=False, + temporary=None, + no_highlight: bool = False, + no_term_frequencies: bool = False, + skip_initial_scan: bool = False, +): + """ + Creates the search index. The index must not already exist. + + For more information, see https://redis.io/commands/ft.create/ + + Args: + index_name: The name of the index to create. + client: The redis client to use. + fields: A list of Field objects. + no_term_offsets: If `true`, term offsets will not be saved in the index. + no_field_flags: If true, field flags that allow searching in specific fields + will not be saved. + stopwords: If provided, the index will be created with this custom stopword + list. The list can be empty. + definition: If provided, the index will be created with this custom index + definition. + max_text_fields: If true, indexes will be encoded as if there were more than + 32 text fields, allowing for additional fields beyond 32. + temporary: Creates a lightweight temporary index which will expire after the + specified period of inactivity. The internal idle timer is reset + whenever the index is searched or added to. 
+ no_highlight: If true, disables highlighting support. Also implied by + `no_term_offsets`. + no_term_frequencies: If true, term frequencies will not be saved in the + index. + skip_initial_scan: If true, the initial scan and indexing will be skipped. + + """ + args = [CREATE_CMD, index_name] + if definition is not None: + args += definition.args + if max_text_fields: + args.append(MAXTEXTFIELDS) + if temporary is not None and isinstance(temporary, int): + args.append(TEMPORARY) + args.append(str(temporary)) + if no_term_offsets: + args.append(NOOFFSETS) + if no_highlight: + args.append(NOHL) + if no_field_flags: + args.append(NOFIELDS) + if no_term_frequencies: + args.append(NOFREQS) + if skip_initial_scan: + args.append(SKIPINITIALSCAN) + if stopwords is not None and isinstance(stopwords, (list, tuple, set)): + args += [STOPWORDS, str(len(stopwords))] + if len(stopwords) > 0: + args += list(stopwords) + + args.append("SCHEMA") + try: + args += list(itertools.chain(*(f.redis_args() for f in fields))) + except TypeError: + args += fields.redis_args() # type: ignore + + default_node = client.get_default_node() + return client.execute_command(*args, target_nodes=[default_node]) + + +async def async_cluster_create_index( + index_name: str, + client: AsyncRedisCluster, + fields: List[Field], + no_term_offsets: bool = False, + no_field_flags: bool = False, + stopwords: Optional[List[str]] = None, + definition: Optional[IndexDefinition] = None, + max_text_fields=False, + temporary=None, + no_highlight: bool = False, + no_term_frequencies: bool = False, + skip_initial_scan: bool = False, +): + """ + Creates the search index. The index must not already exist. + + For more information, see https://redis.io/commands/ft.create/ + + Args: + index_name: The name of the index to create. + client: The redis client to use. + fields: A list of Field objects. + no_term_offsets: If `true`, term offsets will not be saved in the index. + no_field_flags: If true, field flags that allow searching in specific fields + will not be saved. + stopwords: If provided, the index will be created with this custom stopword + list. The list can be empty. + definition: If provided, the index will be created with this custom index + definition. + max_text_fields: If true, indexes will be encoded as if there were more than + 32 text fields, allowing for additional fields beyond 32. + temporary: Creates a lightweight temporary index which will expire after the + specified period of inactivity. The internal idle timer is reset + whenever the index is searched or added to. + no_highlight: If true, disables highlighting support. Also implied by + `no_term_offsets`. + no_term_frequencies: If true, term frequencies will not be saved in the + index. + skip_initial_scan: If true, the initial scan and indexing will be skipped. 
+ + """ + args = [CREATE_CMD, index_name] + if definition is not None: + args += definition.args + if max_text_fields: + args.append(MAXTEXTFIELDS) + if temporary is not None and isinstance(temporary, int): + args.append(TEMPORARY) + args.append(str(temporary)) + if no_term_offsets: + args.append(NOOFFSETS) + if no_highlight: + args.append(NOHL) + if no_field_flags: + args.append(NOFIELDS) + if no_term_frequencies: + args.append(NOFREQS) + if skip_initial_scan: + args.append(SKIPINITIALSCAN) + if stopwords is not None and isinstance(stopwords, (list, tuple, set)): + args += [STOPWORDS, str(len(stopwords))] + if len(stopwords) > 0: + args += list(stopwords) + + args.append("SCHEMA") + try: + args += list(itertools.chain(*(f.redis_args() for f in fields))) + except TypeError: + args += fields.redis_args() # type: ignore + + default_node = client.get_default_node() + return await default_node.execute_command(*args) + + +# TODO: The return type is incorrect because 5.x doesn't have "ProfileInformation" +def cluster_search( + client: Search, + query: Union[str, Query], + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, +) -> Union[Result, Pipeline, Any]: # type: ignore[type-arg] + args, query = client._mk_query_args(query, query_params=query_params) + st = time.monotonic() + + options = {} + if get_protocol_version(client.client) not in ["3", 3]: + options[NEVER_DECODE] = True + + node = client.client.get_default_node() + res = client.execute_command(SEARCH_CMD, *args, **options, target_nodes=[node]) + + if isinstance(res, Pipeline): + return res + + return client._parse_results( + SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 + ) + + +# TODO: The return type is incorrect because 5.x doesn't have "ProfileInformation" +async def async_cluster_search( + client: AsyncSearch, + query: Union[str, Query], + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, +) -> Union[Result, Pipeline, Any]: # type: ignore[type-arg] + args, query = client._mk_query_args(query, query_params=query_params) + st = time.monotonic() + + options = {} + if get_protocol_version(client.client) not in ["3", 3]: + options[NEVER_DECODE] = True + + node = client.client.get_default_node() + res = await client.execute_command( + SEARCH_CMD, *args, **options, target_nodes=[node] + ) + + if isinstance(res, Pipeline): + return res + + return client._parse_results( + SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 + ) + + +def _extract_hash_tag(key: str) -> str: + """Extract hash tag from key. Returns empty string if no hash tag. + + Args: + key (str): Redis key that may contain a hash tag. + + Returns: + str: The hash tag including braces, or empty string if no hash tag. + """ + start = key.find("{") + if start == -1: + return "" + end = key.find("}", start + 1) + if end == -1: + return "" + return key[start : end + 1] + + +def _keys_share_hash_tag(keys: List[str]) -> bool: + """Check if all keys share the same hash tag for Redis Cluster compatibility. + + Args: + keys (List[str]): List of Redis keys to check. + + Returns: + bool: True if all keys share the same hash tag, False otherwise. + """ + if not keys: + return True + + first_tag = _extract_hash_tag(keys[0]) + return all(_extract_hash_tag(key) == first_tag for key in keys) + + +def is_cluster_url(url: str, **kwargs) -> bool: + """ + Determine if the given URL and/or kwargs indicate a Redis Cluster connection. + + Args: + url (str): The Redis connection URL. 
+ **kwargs: Additional keyword arguments that may indicate cluster usage. + + Returns: + bool: True if the connection should be a cluster, False otherwise. + """ + if "cluster" in kwargs and kwargs["cluster"]: + return True + if url: + # Check if URL contains multiple hosts or has cluster flag + if "," in url or "cluster=true" in url.lower(): + return True + return False diff --git a/redisvl/types.py b/redisvl/types.py new file mode 100644 index 00000000..4fb05c6b --- /dev/null +++ b/redisvl/types.py @@ -0,0 +1,20 @@ +from typing import Union + +from redis import Redis as SyncRedis +from redis.asyncio import Redis as AsyncRedis +from redis.asyncio.client import Pipeline as AsyncPipeline +from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.client import Pipeline as SyncPipeline +from redis.cluster import ClusterPipeline as SyncClusterPipeline +from redis.cluster import RedisCluster as SyncRedisCluster + +SyncRedisClient = Union[SyncRedis, SyncRedisCluster] +AsyncRedisClient = Union[AsyncRedis, AsyncRedisCluster] +RedisClient = Union[SyncRedisClient, AsyncRedisClient] + +SyncRedisPipeline = Union[SyncPipeline, SyncClusterPipeline] +AsyncRedisPipeline = Union[AsyncPipeline, AsyncClusterPipeline] + +RedisClientOrPipeline = Union[SyncRedisClient, SyncRedisPipeline] +AsyncRedisClientOrPipeline = Union[AsyncRedisClient, AsyncRedisPipeline] diff --git a/tests/cluster-compose.yml b/tests/cluster-compose.yml new file mode 100644 index 00000000..8411fbd4 --- /dev/null +++ b/tests/cluster-compose.yml @@ -0,0 +1,199 @@ +services: + redis-node-1: + image: "redis:8.0" + command: + - redis-server + - --port 7001 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip host.docker.internal + - --cluster-announce-port 7001 + - --cluster-announce-bus-port 17001 + - --protected-mode no + ports: + - "7001:7001" + - "17001:17001" + volumes: + - ./data/7001:/data + + redis-node-2: + image: "redis:8.0" + command: + - redis-server + - --port 7002 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip host.docker.internal + - --cluster-announce-port 7002 + - --cluster-announce-bus-port 17002 + - --protected-mode no + ports: + - "7002:7002" + - "17002:17002" + volumes: + - ./data/7002:/data + + redis-node-3: + image: "redis:8.0" + command: + - redis-server + - --port 7003 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip host.docker.internal + - --cluster-announce-port 7003 + - --cluster-announce-bus-port 17003 + - --protected-mode no + ports: + - "7003:7003" + - "17003:17003" + volumes: + - ./data/7003:/data + + redis-node-4: + image: "redis:8.0" + command: + - redis-server + - --port 7004 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip host.docker.internal + - --cluster-announce-port 7004 + - --cluster-announce-bus-port 17004 + - --protected-mode no + ports: + - "7004:7004" + - "17004:17004" + volumes: + - ./data/7004:/data + + redis-node-5: + image: "redis:8.0" + command: + - redis-server + - --port 7005 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip 
host.docker.internal + - --cluster-announce-port 7005 + - --cluster-announce-bus-port 17005 + - --protected-mode no + ports: + - "7005:7005" + - "17005:17005" + volumes: + - ./data/7005:/data + + redis-node-6: + image: "redis:8.0" + command: + - redis-server + - --port 7006 + - --cluster-enabled yes + - --cluster-config-file nodes.conf + - --cluster-node-timeout 5000 + - --appendonly yes + - --cluster-announce-ip host.docker.internal + - --cluster-announce-port 7006 + - --cluster-announce-bus-port 17006 + - --protected-mode no + ports: + - "7006:7006" + - "17006:17006" + volumes: + - ./data/7006:/data + + redis-cluster-setup: + image: "redis:8.0" + depends_on: + - redis-node-1 + - redis-node-2 + - redis-node-3 + - redis-node-4 + - redis-node-5 + - redis-node-6 + entrypoint: + - sh + - -c + - | + echo "Waiting for Redis nodes to be available..." + sleep 15 + + NODES="redis-node-1:7001 redis-node-2:7002 redis-node-3:7003 redis-node-4:7004 redis-node-5:7005 redis-node-6:7006" + + echo "Force resetting all nodes before cluster creation..." + for NODE_ADDR_PORT in $$NODES; do + NODE_HOST=$$(echo $$NODE_ADDR_PORT | cut -d':' -f1) + NODE_PORT=$$(echo $$NODE_ADDR_PORT | cut -d':' -f2) + echo "Resetting node $$NODE_HOST:$$NODE_PORT" + + # Wait for node to be responsive + retry_count=0 + max_retries=10 + until redis-cli -h $$NODE_HOST -p $$NODE_PORT ping 2>/dev/null | grep -q PONG; do + retry_count=$$((retry_count+1)) + if [ "$$retry_count" -gt "$$max_retries" ]; then + echo "Error: Node $$NODE_HOST:$$NODE_PORT did not respond after $$max_retries retries." + exit 1 # Exit if a node is unresponsive + fi + echo "Waiting for $$NODE_HOST:$$NODE_PORT to respond (attempt $$retry_count/$$max_retries)..." + sleep 3 # Increased sleep between pings + done + + echo "Flushing and hard resetting $$NODE_HOST:$$NODE_PORT" + redis-cli -h $$NODE_HOST -p $$NODE_PORT FLUSHALL || echo "Warning: FLUSHALL failed on $$NODE_HOST:$$NODE_PORT, attempting to continue..." + # Use CLUSTER RESET HARD + redis-cli -h $$NODE_HOST -p $$NODE_PORT CLUSTER RESET HARD || echo "Warning: CLUSTER RESET HARD failed on $$NODE_HOST:$$NODE_PORT, attempting to continue..." + done + echo "Node reset complete." + sleep 5 # Give a moment for resets to settle + + MAX_ATTEMPTS=5 + ATTEMPT=1 + CLUSTER_CREATED=false + + while [ $$ATTEMPT -le $$MAX_ATTEMPTS ]; do + echo "Attempting to create Redis cluster (Attempt $$ATTEMPT/$$MAX_ATTEMPTS)..." + output=$$(echo yes | redis-cli --cluster create \ + $$NODES \ + --cluster-replicas 1 2>&1) + + if echo "$$output" | grep -q "\[OK\] All 16384 slots covered."; then + echo "Cluster created successfully." + CLUSTER_CREATED=true + break + else + echo "Failed to create cluster on attempt $$ATTEMPT." + echo "Output from redis-cli: $$output" + if [ $$ATTEMPT -lt $$MAX_ATTEMPTS ]; then + echo "Retrying in 10 seconds..." + sleep 10 + fi + fi + ATTEMPT=$$((ATTEMPT + 1)) + done + + if [ "$$CLUSTER_CREATED" = "false" ]; then + echo "Failed to create cluster after $$MAX_ATTEMPTS attempts. Exiting." + exit 1 + fi + + echo "Redis cluster setup complete. Container will remain active for health checks." 
+ tail -f /dev/null + healthcheck: + test: > + sh -c "redis-cli -h redis-node-1 -p 7001 cluster info | grep -q 'cluster_state:ok'" + interval: 5s + timeout: 5s + retries: 12 + start_period: 10s diff --git a/tests/conftest.py b/tests/conftest.py index 9995ff88..c4c708cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ +import logging import os +import subprocess from datetime import datetime, timezone import pytest @@ -9,6 +11,8 @@ from redisvl.redis.utils import array_to_buffer from redisvl.utils.vectorize import HFTextVectorizer +logger = logging.getLogger(__name__) + @pytest.fixture(scope="session") def worker_id(request): @@ -52,6 +56,109 @@ def redis_container(worker_id): compose.stop() +@pytest.fixture(scope="session") +def redis_cluster_container(worker_id): + project_name = f"redis_test_cluster_{worker_id}" + # Use cwd if not running in GitHub Actions + pwd = os.getcwd() + compose_file = os.path.join( + os.environ.get("GITHUB_WORKSPACE", pwd), "tests", "cluster-compose.yml" + ) + os.environ["COMPOSE_PROJECT_NAME"] = ( + project_name # For docker compose to pick it up if needed + ) + # redis-stack-server comes up without modules in cluster mode, so we hard-code + # the Redis 8 image for now. + os.environ.setdefault("REDIS_IMAGE", "redis:8") + + # The DockerCompose helper isn't working with multiple services because the + # subprocess command returns non-zero exit codes even on successful + # completion. Here, we run the commands manually. + + # First attempt the docker-compose up command and handle its errors directly + docker_cmd = [ + "docker", + "compose", + "-f", + compose_file, + "-p", # Explicitly pass project name + project_name, + "up", + "--wait", # Wait for healthchecks + "-d", # Detach + ] + + try: + result = subprocess.run( + docker_cmd, + capture_output=True, + check=False, # Don't raise exception, we'll handle it ourselves + ) + + if result.returncode != 0: + logger.error(f"Docker Compose up failed with exit code {result.returncode}") + if result.stdout: + logger.error( + f"STDOUT: {result.stdout.decode('utf-8', errors='replace')}" + ) + if result.stderr: + logger.error( + f"STDERR: {result.stderr.decode('utf-8', errors='replace')}" + ) + + # Try to get logs for more details + logger.info("Attempting to fetch container logs...") + try: + logs_result = subprocess.run( + [ + "docker", + "compose", + "-f", + compose_file, + "-p", + project_name, + "logs", + ], + capture_output=True, + text=True, + ) + logger.info("Docker Compose logs:\n%s", logs_result.stdout) + if logs_result.stderr: + logger.error("Docker Compose logs stderr: \n%s", logs_result.stderr) + except Exception as log_e: + logger.error(f"Failed to get Docker Compose logs: {repr(log_e)}") + + # Now raise the exception with the original result + raise subprocess.CalledProcessError( + result.returncode, + docker_cmd, + output=result.stdout, + stderr=result.stderr, + ) + + # If we get here, setup was successful + yield + finally: + # Always clean up + try: + subprocess.run( + [ + "docker", + "compose", + "-f", + compose_file, + "-p", + project_name, + "down", + "-v", # Remove volumes + ], + check=False, # Don't raise on cleanup failure + capture_output=True, + ) + except Exception as e: + logger.error(f"Error during cleanup: {repr(e)}") + + @pytest.fixture(scope="session") def redis_url(redis_container): """ @@ -62,6 +169,12 @@ def redis_url(redis_container): return f"redis://{host}:{port}" +@pytest.fixture(scope="session") +def redis_cluster_url(redis_cluster_container): + # Hard-coded due to 
Docker issues + return "redis://localhost:7001" + + @pytest.fixture async def async_client(redis_url): """ @@ -80,7 +193,18 @@ def client(redis_url): yield conn -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture +def cluster_client(redis_cluster_url): + """ + A sync Redis client that uses the dynamic `redis_cluster_url`. + """ + conn = RedisConnectionFactory.get_redis_cluster_connection( + redis_url=redis_cluster_url + ) + yield conn + + +@pytest.fixture(scope="session") def hf_vectorizer(): return HFTextVectorizer( model="sentence-transformers/all-mpnet-base-v2", @@ -191,27 +315,44 @@ def pytest_addoption(parser: pytest.Parser) -> None: default=False, help="Run tests that require API keys", ) + parser.addoption( + "--run-cluster-tests", + action="store_true", + default=False, + help="Run tests that require a Redis cluster", + ) def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line( "markers", "requires_api_keys: mark test as requiring API keys" ) + config.addinivalue_line( + "markers", "requires_cluster: mark test as requiring a Redis cluster" + ) def pytest_collection_modifyitems( config: pytest.Config, items: list[pytest.Item] ) -> None: - if config.getoption("--run-api-tests"): - return + # Check each flag independently + run_api_tests = config.getoption("--run-api-tests") + run_cluster_tests = config.getoption("--run-cluster-tests") - # Otherwise skip all tests requiring an API key + # Create skip markers skip_api = pytest.mark.skip( reason="Skipping test because API keys are not provided. Use --run-api-tests to run these tests." ) + skip_cluster = pytest.mark.skip( + reason="Skipping test because Redis cluster is not available. Use --run-cluster-tests to run these tests." + ) + + # Apply skip markers independently based on flags for item in items: - if item.get_closest_marker("requires_api_keys"): + if item.get_closest_marker("requires_api_keys") and not run_api_tests: item.add_marker(skip_api) + if item.get_closest_marker("requires_cluster") and not run_cluster_tests: + item.add_marker(skip_cluster) @pytest.fixture diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index 222f8b9f..73625cb9 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -1,10 +1,8 @@ import pytest -from redis.commands.search.aggregation import AggregateResult -from redis.commands.search.result import Result from redisvl.index import SearchIndex from redisvl.query import HybridQuery -from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text +from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text from redisvl.redis.connection import compare_versions from redisvl.redis.utils import array_to_buffer diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 10794726..94da698a 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -166,15 +166,18 @@ def test_search_index_no_prefix(index_schema): @pytest.mark.asyncio async def test_search_index_redis_url(redis_url, index_schema): - async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) - # Client is None until a command is run - assert async_index.client is None + async with AsyncSearchIndex( + schema=index_schema, redis_url=redis_url + ) as async_index: + # Client is None until a command is run + assert async_index.client is None - # Lazily create the client by running a command - await 
async_index.create(overwrite=True, drop=True) - assert async_index.client + # Lazily create the client by running a command + await async_index.create(overwrite=True, drop=True) + assert async_index.client - await async_index.disconnect() + # After exiting async with, if the index owned the client, it should be disconnected + # and client attribute should be None again by __aexit__ assert async_index.client is None @@ -186,20 +189,25 @@ async def test_search_index_client(async_client, index_schema): @pytest.mark.asyncio async def test_search_index_set_client(client, redis_url, index_schema): - async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) - # Ignore deprecation warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - await async_index.create(overwrite=True, drop=True) - assert isinstance(async_index.client, AsyncRedis) + # Use async with for the index that owns its initial client via redis_url + async with AsyncSearchIndex( + schema=index_schema, redis_url=redis_url + ) as async_index: + # Ignore deprecation warnings for set_client + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + await async_index.create(overwrite=True, drop=True) + assert isinstance(async_index.client, AsyncRedis) - # Tests deprecated sync -> async conversion behavior - assert isinstance(client, SyncRedis) - await async_index.set_client(client) - assert isinstance(async_index.client, AsyncRedis) + # Tests deprecated sync -> async conversion behavior + assert isinstance(client, SyncRedis) - await async_index.disconnect() - assert async_index.client is None + await async_index.set_client(client) + assert isinstance(async_index.client, AsyncRedis) + + if async_index.client: + await async_index.disconnect() + assert async_index.client is None @pytest.mark.asyncio diff --git a/tests/integration/test_cross_encoder_reranker.py b/tests/integration/test_cross_encoder_reranker.py index 8db57bb8..fa109347 100644 --- a/tests/integration/test_cross_encoder_reranker.py +++ b/tests/integration/test_cross_encoder_reranker.py @@ -1,10 +1,9 @@ import pytest -from sentence_transformers import CrossEncoder from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker -@pytest.fixture +@pytest.fixture(scope="session") def reranker(): return HFCrossEncoderReranker() diff --git a/tests/integration/test_embedcache.py b/tests/integration/test_embedcache.py index 45150989..ca2ed301 100644 --- a/tests/integration/test_embedcache.py +++ b/tests/integration/test_embedcache.py @@ -11,10 +11,10 @@ @pytest.fixture -def cache(redis_url): +def cache(redis_url, worker_id): """Basic EmbeddingsCache fixture with cleanup.""" cache_instance = EmbeddingsCache( - name="test_embed_cache", + name=f"test_embed_cache_{worker_id}", redis_url=redis_url, ) yield cache_instance diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index d3e4d34b..9a16f17c 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,5 +1,4 @@ import asyncio -import os import warnings from collections import namedtuple from time import sleep, time @@ -15,7 +14,7 @@ from redisvl.utils.vectorize import HFTextVectorizer -@pytest.fixture +@pytest.fixture(scope="session") def vectorizer(): return HFTextVectorizer("redis/langcache-embed-v1") @@ -820,10 +819,11 @@ def test_complex_filters(cache_with_filters): assert len(results) == 1 -def test_cache_index_overwrite(redis_url, worker_id): 
+def test_cache_index_overwrite(redis_url, worker_id, hf_vectorizer): cache_no_tags = SemanticCache( name=f"test_cache_{worker_id}", redis_url=redis_url, + vectorizer=hf_vectorizer, ) cache_no_tags.store( @@ -853,12 +853,14 @@ def test_cache_index_overwrite(redis_url, worker_id): SemanticCache( name=f"test_cache_{worker_id}", redis_url=redis_url, + vectorizer=hf_vectorizer, filterable_fields=[{"name": "some_tag", "type": "tag"}], ) cache_overwrite = SemanticCache( name=f"test_cache_{worker_id}", redis_url=redis_url, + vectorizer=hf_vectorizer, filterable_fields=[{"name": "some_tag", "type": "tag"}], overwrite=True, ) @@ -870,10 +872,11 @@ def test_cache_index_overwrite(redis_url, worker_id): assert len(response) == 1 -def test_no_key_collision_on_identical_prompts(redis_url, worker_id): +def test_no_key_collision_on_identical_prompts(redis_url, worker_id, hf_vectorizer): private_cache = SemanticCache( name=f"private_cache_{worker_id}", redis_url=redis_url, + vectorizer=hf_vectorizer, filterable_fields=[ {"name": "user_id", "type": "tag"}, {"name": "zip_code", "type": "numeric"}, @@ -1000,9 +1003,11 @@ def test_deprecated_dtype_argument(redis_url, worker_id): @pytest.mark.asyncio -async def test_cache_async_context_manager(redis_url, worker_id): +async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer): async with SemanticCache( - name=f"test_cache_async_context_manager_{worker_id}", redis_url=redis_url + name=f"test_cache_async_context_manager_{worker_id}", + redis_url=redis_url, + vectorizer=hf_vectorizer, ) as cache: await cache.astore("test prompt", "test response") assert cache._aindex @@ -1010,11 +1015,14 @@ async def test_cache_async_context_manager(redis_url, worker_id): @pytest.mark.asyncio -async def test_cache_async_context_manager_with_exception(redis_url, worker_id): +async def test_cache_async_context_manager_with_exception( + redis_url, worker_id, hf_vectorizer +): try: async with SemanticCache( name=f"test_cache_async_context_manager_with_exception_{worker_id}", redis_url=redis_url, + vectorizer=hf_vectorizer, ) as cache: await cache.astore("test prompt", "test response") raise ValueError("test") @@ -1024,18 +1032,22 @@ async def test_cache_async_context_manager_with_exception(redis_url, worker_id): @pytest.mark.asyncio -async def test_cache_async_disconnect(redis_url, worker_id): +async def test_cache_async_disconnect(redis_url, worker_id, hf_vectorizer): cache = SemanticCache( - name=f"test_cache_async_disconnect_{worker_id}", redis_url=redis_url + name=f"test_cache_async_disconnect_{worker_id}", + redis_url=redis_url, + vectorizer=hf_vectorizer, ) await cache.astore("test prompt", "test response") await cache.adisconnect() assert cache._aindex is None -def test_cache_disconnect(redis_url, worker_id): +def test_cache_disconnect(redis_url, worker_id, hf_vectorizer): cache = SemanticCache( - name=f"test_cache_disconnect_{worker_id}", redis_url=redis_url + name=f"test_cache_disconnect_{worker_id}", + redis_url=redis_url, + vectorizer=hf_vectorizer, ) cache.store("test prompt", "test response") cache.disconnect() diff --git a/tests/integration/test_message_history.py b/tests/integration/test_message_history.py index d04cb0ab..c23e691e 100644 --- a/tests/integration/test_message_history.py +++ b/tests/integration/test_message_history.py @@ -6,7 +6,6 @@ from redisvl.exceptions import RedisModuleVersionError from redisvl.extensions.constants import ID_FIELD_NAME from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory -from 
redisvl.utils.vectorize.text.huggingface import HFTextVectorizer @pytest.fixture @@ -22,8 +21,10 @@ def standard_history(app_name, client): @pytest.fixture -def semantic_history(app_name, client): - history = SemanticMessageHistory(app_name, redis_client=client, overwrite=True) +def semantic_history(app_name, client, hf_vectorizer): + history = SemanticMessageHistory( + app_name, redis_client=client, overwrite=True, vectorizer=hf_vectorizer + ) yield history history.clear() history.delete() @@ -296,19 +297,24 @@ def test_standard_clear(standard_history): # test semantic message history -def test_semantic_specify_client(client): +def test_semantic_specify_client(client, hf_vectorizer): history = SemanticMessageHistory( - name="test_app", session_tag="abc", redis_client=client, overwrite=True + name="test_app", + session_tag="abc", + redis_client=client, + overwrite=True, + vectorizer=hf_vectorizer, ) assert isinstance(history._index.client, type(client)) -def test_semantic_bad_connection_info(): +def test_semantic_bad_connection_info(hf_vectorizer): with pytest.raises(ConnectionError): SemanticMessageHistory( name="test_app", session_tag="abc", redis_url="redis://localhost:6389", + vectorizer=hf_vectorizer, ) @@ -573,7 +579,7 @@ def test_different_vector_dtypes(): pytest.skip("Not using a late enough version of Redis") -def test_bad_dtype_connecting_to_exiting_history(redis_url): +def test_bad_dtype_connecting_to_exiting_history(redis_url, hf_vectorizer): try: history = SemanticMessageHistory( name="float64 history", dtype="float64", redis_url=redis_url diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 96deea26..82310622 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -1,3 +1,4 @@ +import uuid from datetime import timedelta import pytest @@ -144,13 +145,14 @@ def sorted_range_query(): @pytest.fixture def index(sample_data, redis_url, worker_id): + unique_id = str(uuid.uuid4())[:8] # Use first 8 chars of UUID for brevity # construct a search index from the schema index = SearchIndex.from_dict( { "index": { - "name": "user_index", - "prefix": f"v1_{worker_id}", + "name": f"user_index_{worker_id}_{unique_id}", + "prefix": f"v1_{worker_id}_{unique_id}", "storage_type": "hash", }, "fields": [ @@ -196,13 +198,14 @@ def hash_preprocess(item: dict) -> dict: @pytest.fixture def L2_index(sample_data, redis_url, worker_id): + unique_id = str(uuid.uuid4())[:8] # Use first 8 chars of UUID for brevity # construct a search index from the schema index = SearchIndex.from_dict( { "index": { - "name": "L2_index", - "prefix": f"L2_index_{worker_id}", + "name": f"L2_index_{worker_id}_{unique_id}", + "prefix": f"L2_index_{worker_id}_{unique_id}", "storage_type": "hash", }, "fields": [ diff --git a/tests/integration/test_redis_cluster_support.py b/tests/integration/test_redis_cluster_support.py new file mode 100644 index 00000000..c7134eca --- /dev/null +++ b/tests/integration/test_redis_cluster_support.py @@ -0,0 +1,191 @@ +"""Tests for Redis Cluster support in RedisVL.""" + +import pytest +from redis import Redis +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.cluster import RedisCluster + +from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache +from redisvl.extensions.router.semantic import Route, SemanticRouter +from redisvl.index import SearchIndex +from redisvl.index.index import AsyncSearchIndex +from redisvl.query.query import TextQuery +from redisvl.redis.connection import 
RedisConnectionFactory +from redisvl.schema import IndexSchema + + +@pytest.mark.requires_cluster +def test_sync_client_validation(redis_url, redis_cluster_url): + """Test validation of sync Redis client types.""" + # Test regular Redis client + redis_client = Redis.from_url(redis_url) + RedisConnectionFactory.validate_sync_redis(redis_client) + + # Test with RedisCluster client type + cluster_client = RedisCluster.from_url(redis_cluster_url) + RedisConnectionFactory.validate_sync_redis(cluster_client) + + +@pytest.mark.requires_cluster +@pytest.mark.asyncio +async def test_async_client_validation(redis_cluster_url): + """Test validation of async Redis client types.""" + async_cluster_client = await RedisConnectionFactory._get_aredis_connection( + redis_cluster_url + ) + await RedisConnectionFactory.validate_async_redis(async_cluster_client) + + +@pytest.mark.requires_cluster +@pytest.mark.asyncio +async def test_sync_to_async_conversion_rejects_cluster_client(redis_cluster_url): + """Test that sync-to-async conversion rejects RedisCluster clients.""" + cluster_client = RedisCluster.from_url(redis_cluster_url) + with pytest.raises( + ValueError, match="RedisCluster is not supported for sync-to-async conversion." + ): + RedisConnectionFactory.sync_to_async_redis(cluster_client) + + +@pytest.mark.requires_cluster +def test_search_index_cluster_client(redis_cluster_url): + """Test that SearchIndex correctly accepts RedisCluster clients.""" + # Create a simple schema + schema = IndexSchema.from_dict( + { + "index": {"name": "test_cluster_index", "prefix": "test_cluster"}, + "fields": [ + {"name": "name", "type": "text"}, + {"name": "age", "type": "numeric"}, + ], + } + ) + + cluster_client = RedisCluster.from_url(redis_cluster_url) + index = SearchIndex(schema=schema, redis_client=cluster_client) + index.create(overwrite=True) + index.load([{"name": "test1", "age": 30}]) + results = index.query(TextQuery("test1", "name")) + assert results[0]["name"] == "test1" + index.delete(drop=True) + + +@pytest.mark.requires_cluster +@pytest.mark.asyncio +async def test_async_search_index_client(redis_cluster_url): + """Test that AsyncSearchIndex correctly handles AsyncRedis clients.""" + # Create a simple schema + schema = IndexSchema.from_dict( + { + "index": {"name": "async_test_index", "prefix": "async_test"}, + "fields": [ + {"name": "name", "type": "text"}, + {"name": "age", "type": "numeric"}, + ], + } + ) + + # Test with AsyncRedis client + cluster_client = AsyncRedisCluster.from_url(redis_cluster_url) + index = AsyncSearchIndex(schema=schema, redis_client=cluster_client) + try: + await index.create(overwrite=True) + await index.load([{"name": "async_test", "age": 25}]) + results = await index.query(TextQuery("async_test", "name")) + assert results[0]["name"] == "async_test" + await index.delete(drop=True) + finally: + # Manually close the cluster client to prevent connection leaks + await cluster_client.aclose() + + +@pytest.mark.requires_cluster +@pytest.mark.asyncio +async def test_embeddings_cache_cluster_async(redis_cluster_url): + """Test that EmbeddingsCache correctly handles AsyncRedisCluster clients.""" + cluster_client = RedisConnectionFactory.get_async_redis_cluster_connection( + redis_cluster_url + ) + cache = EmbeddingsCache(async_redis_client=cluster_client) + + try: + await cache.aset( + text="hey", + model_name="test", + embedding=[1, 2, 3], + ) + result = await cache.aget("hey", "test") + assert result is not None + assert result["embedding"] == [1, 2, 3] + await cache.aclear() + 
finally: + # Manually close the cluster client to prevent connection leaks + await cluster_client.aclose() + + +@pytest.mark.requires_cluster +def test_embeddings_cache_cluster_sync(redis_cluster_url): + """Test that EmbeddingsCache correctly handles RedisCluster clients.""" + cluster_client = RedisCluster.from_url(redis_cluster_url) + cache = EmbeddingsCache(redis_client=cluster_client) + + for i in range(100): + cache.set( + text=f"hey_{i}", + model_name="test", + embedding=[1, 2, 3], + ) + result = cache.get("hey_0", "test") + assert result is not None + assert result["embedding"] == [1, 2, 3] + cache.clear() + + cache.mset( + [ + {"text": "hey_0", "model_name": "test", "embedding": [1, 2, 3]}, + {"text": "hey_1", "model_name": "test", "embedding": [1, 2, 3]}, + ] + ) + result = cache.mget(["hey_0", "hey_1"], "test") + assert result[0] is not None + assert result[1] is not None + assert result[0]["embedding"] == [1, 2, 3] + assert result[1]["embedding"] == [1, 2, 3] + cache.clear() + + +@pytest.mark.requires_cluster +def test_semantic_router_cluster_client(redis_cluster_url, hf_vectorizer): + """Test that SemanticRouter works correctly with RedisCluster clients.""" + routes = [ + Route( + name="General Inquiry", + references=["What are your hours?", "Tell me about your services."], + ), + Route( + name="Technical Support", + references=[ + "I have an issue with my account.", + "My product is broken.", + ], + ), + ] + client = RedisCluster.from_url(redis_cluster_url) + + router_name = "test_cluster_router" + router = SemanticRouter( + name=router_name, + routes=routes, + vectorizer=hf_vectorizer, + redis_client=client, + overwrite=True, + ) + + query_text = "I need help with my login." + matched_route = router(query_text) + + assert matched_route is not None + assert matched_route.name == "Technical Support" + + if router._index and router._index.exists(): + router._index.delete(drop=True) diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 97d32173..2c162502 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -13,7 +13,6 @@ RoutingConfig, ) from redisvl.redis.connection import compare_versions -from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer def get_base_path(): @@ -39,13 +38,14 @@ def routes(): @pytest.fixture -def semantic_router(client, routes): +def semantic_router(client, routes, hf_vectorizer): router = SemanticRouter( name=f"test-router-{str(ULID())}", routes=routes, routing_config=RoutingConfig(max_k=2), redis_client=client, overwrite=False, + vectorizer=hf_vectorizer, ) yield router router.clear() diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py new file mode 100644 index 00000000..8c31e8ac --- /dev/null +++ b/tests/unit/test_error_handling.py @@ -0,0 +1,618 @@ +""" +Unit tests for error handling improvements in RedisVL. + +This module tests the enhanced error handling behavior introduced for: +1. Redis error handling in index operations +2. CROSSSLOT error detection and messaging +3. Connection kwargs validation in BaseCache +4. Router config error handling +5. 
Cluster compatibility validation +""" + +import asyncio +from collections.abc import Mapping +from unittest.mock import MagicMock, Mock, patch + +import pytest +import redis.exceptions +from redis import Redis +from redis.asyncio import Redis as AsyncRedis +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.cluster import RedisCluster + +from redisvl.exceptions import RedisSearchError +from redisvl.extensions.cache.base import BaseCache +from redisvl.extensions.router.semantic import SemanticRouter +from redisvl.schema import StorageType + + +class TestRedisErrorHandling: + """Test enhanced Redis error handling in index operations.""" + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_redis_error_in_create_method(self, mock_validate): + """Test that Redis errors are caught and re-raised with context.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema + schema = Mock(spec=IndexSchema) + schema.redis_fields = ["test_field"] + schema.index = Mock() + schema.index.name = "test_index" + schema.index.prefix = "test:" + schema.index.storage_type = StorageType.HASH + + # Create a mock Redis client that raises RedisError + mock_client = Mock(spec=Redis) + mock_client.ft.return_value.create_index.side_effect = ( + redis.exceptions.RedisError("Connection failed") + ) + mock_client.execute_command.return_value = [] + + index = SearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + index.create() + + error_msg = str(exc_info.value) + assert "Failed to create index 'test_index' on Redis" in error_msg + assert "Connection failed" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_unexpected_error_in_create_method(self, mock_validate): + """Test that unexpected errors are caught and re-raised with context.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema + schema = Mock(spec=IndexSchema) + schema.redis_fields = ["test_field"] + schema.index = Mock() + schema.index.name = "test_index" + schema.index.prefix = "test:" + schema.index.storage_type = StorageType.HASH + + # Create a mock Redis client that raises unexpected error + mock_client = Mock(spec=Redis) + mock_client.ft.return_value.create_index.side_effect = ValueError( + "Unexpected error" + ) + mock_client.execute_command.return_value = [] + + index = SearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + index.create() + + error_msg = str(exc_info.value) + assert "Unexpected error creating index 'test_index'" in error_msg + assert "Unexpected error" in error_msg + + +class TestCrossSlotErrorHandling: + """Test CROSSSLOT error detection and helpful messaging.""" + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_crossslot_error_in_search(self, mock_validate): + """Test that CROSSSLOT errors in search provide helpful guidance.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and index + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=Redis) + crossslot_error = redis.exceptions.ResponseError( + "CROSSSLOT Keys in request don't hash to the same slot" + ) + 
mock_client.ft.return_value.search.side_effect = crossslot_error + + index = SearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + index.search("test query") + + error_msg = str(exc_info.value) + assert "Cross-slot error during search" in error_msg + assert "hash tags" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_crossslot_error_in_aggregate(self, mock_validate): + """Test that CROSSSLOT errors in aggregate provide helpful guidance.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and index + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=Redis) + crossslot_error = redis.exceptions.ResponseError( + "CROSSSLOT Keys in request don't hash to the same slot" + ) + mock_client.ft.return_value.aggregate.side_effect = crossslot_error + + index = SearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + index.aggregate("test query") + + error_msg = str(exc_info.value) + assert "Cross-slot error during aggregation" in error_msg + assert "hash tags" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_other_redis_error_in_search(self, mock_validate): + """Test that other Redis errors are handled with generic message.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and index + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=Redis) + other_error = redis.exceptions.ResponseError("Some other error") + mock_client.ft.return_value.search.side_effect = other_error + + index = SearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + index.search("test query") + + error_msg = str(exc_info.value) + assert "Error while searching" in error_msg + assert "Some other error" in error_msg + + +class TestConnectionKwargsValidation: + """Test improved connection_kwargs validation in BaseCache.""" + + @pytest.mark.asyncio + async def test_connection_kwargs_type_error(self): + """Test that invalid connection_kwargs type raises TypeError with helpful message.""" + cache = BaseCache( + name="test_cache", + connection_kwargs="not_a_dict", # type: ignore + ) + + with pytest.raises(TypeError) as exc_info: + await cache._get_async_redis_client() + + error_msg = str(exc_info.value) + assert "Expected `connection_kwargs` to be a dictionary" in error_msg + assert "{'decode_responses': True}" in error_msg + assert "got type: str" in error_msg + + @pytest.mark.asyncio + async def test_connection_kwargs_valid_dict(self): + """Test that valid connection_kwargs work correctly.""" + cache = BaseCache( + name="test_cache", connection_kwargs={"decode_responses": True} + ) + + # Mock the RedisConnectionFactory to avoid actual connection + with patch( + "redisvl.extensions.cache.base.RedisConnectionFactory" + ) as mock_factory: + mock_client = Mock() + mock_factory.get_async_redis_connection.return_value = mock_client + + result = await cache._get_async_redis_client() + assert result == mock_client + mock_factory.get_async_redis_connection.assert_called_once() + + +class TestRouterConfigErrorHandling: + """Test 
improved router config error handling.""" + + def test_router_config_invalid_type_error(self): + """Test that invalid router config shows actual received value.""" + # This simulates the error that would be raised in SemanticRouter.from_existing + invalid_router_dict = "not_a_dict" + + with pytest.raises(ValueError) as exc_info: + if not isinstance(invalid_router_dict, dict): + raise ValueError( + f"No valid router config found for test_router. Received: {invalid_router_dict!r}" + ) + + error_msg = str(exc_info.value) + assert "Received: 'not_a_dict'" in error_msg + + def test_router_config_none_error(self): + """Test error message when router config is None.""" + invalid_router_dict = None + + with pytest.raises(ValueError) as exc_info: + if not isinstance(invalid_router_dict, dict): + raise ValueError( + f"No valid router config found for test_router. Received: {invalid_router_dict!r}" + ) + + error_msg = str(exc_info.value) + assert "Received: None" in error_msg + + def test_router_config_numeric_error(self): + """Test error message when router config is a number.""" + invalid_router_dict = 42 + + with pytest.raises(ValueError) as exc_info: + if not isinstance(invalid_router_dict, dict): + raise ValueError( + f"No valid router config found for test_router. Received: {invalid_router_dict!r}" + ) + + error_msg = str(exc_info.value) + assert "Received: 42" in error_msg + + +class TestClusterCompatibilityValidation: + """Test cluster compatibility validation for drop_documents.""" + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_drop_documents_cluster_validation_success(self, mock_validate): + """Test that documents with same hash tag work in cluster.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + mock_cluster_client = Mock(spec=RedisCluster) + mock_cluster_client.delete.return_value = 2 + + index = SearchIndex(schema=schema, redis_client=mock_cluster_client) + + # These IDs will create keys with the same hash tag + ids = ["{user123}:doc1", "{user123}:doc2"] + + result = index.drop_documents(ids) + assert result == 2 + mock_cluster_client.delete.assert_called_once() + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_drop_documents_cluster_validation_failure(self, mock_validate): + """Test that documents with different hash tags fail in cluster.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + mock_cluster_client = Mock(spec=RedisCluster) + + index = SearchIndex(schema=schema, redis_client=mock_cluster_client) + + # These IDs will create keys with different hash tags + ids = ["{user123}:doc1", "{user456}:doc2"] + + with pytest.raises(ValueError) as exc_info: + index.drop_documents(ids) + + error_msg = str(exc_info.value) + assert "All keys must share a hash tag when using Redis Cluster" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_drop_documents_non_cluster_no_validation(self, mock_validate): + """Test that non-cluster clients don't perform 
hash tag validation.""" + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and regular Redis client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=Redis) + mock_client.delete.return_value = 2 + + index = SearchIndex(schema=schema, redis_client=mock_client) + + # These IDs would fail in cluster, but should work in regular Redis + ids = ["{user123}:doc1", "{user456}:doc2"] + + result = index.drop_documents(ids) + assert result == 2 + mock_client.delete.assert_called_once() + + +class TestAsyncErrorHandling: + """Test error handling in async methods.""" + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_async_redis") + @pytest.mark.asyncio + async def test_async_crossslot_error_in_search(self, mock_validate): + """Test that CROSSSLOT errors in async search provide helpful guidance.""" + from redisvl.index import AsyncSearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=AsyncRedis) + crossslot_error = redis.exceptions.ResponseError( + "CROSSSLOT Keys in request don't hash to the same slot" + ) + mock_client.ft.return_value.search.side_effect = crossslot_error + + index = AsyncSearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + await index.search("test query") + + error_msg = str(exc_info.value) + assert "Cross-slot error during search" in error_msg + assert "hash tags" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_async_redis") + @pytest.mark.asyncio + async def test_async_crossslot_error_in_aggregate(self, mock_validate): + """Test that CROSSSLOT errors in async aggregate provide helpful guidance.""" + from redisvl.index import AsyncSearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + schema.index.storage_type = StorageType.HASH + + mock_client = Mock(spec=AsyncRedis) + crossslot_error = redis.exceptions.ResponseError( + "CROSSSLOT Keys in request don't hash to the same slot" + ) + mock_client.ft.return_value.aggregate.side_effect = crossslot_error + + index = AsyncSearchIndex(schema=schema, redis_client=mock_client) + + with pytest.raises(RedisSearchError) as exc_info: + await index.aggregate("test query") + + error_msg = str(exc_info.value) + assert "Cross-slot error during aggregation" in error_msg + assert "hash tags" in error_msg + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_async_redis") + @pytest.mark.asyncio + async def test_async_drop_documents_cluster_validation(self, mock_validate): + """Test async drop_documents cluster validation.""" + from unittest.mock import AsyncMock + + from redisvl.index import AsyncSearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and async cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + # Create a mock that properly inherits from RedisCluster for isinstance check + mock_cluster_client = Mock() + mock_cluster_client.__class__ = 
AsyncRedisCluster + mock_cluster_client.delete = AsyncMock(return_value=2) + + index = AsyncSearchIndex(schema=schema, redis_client=mock_cluster_client) + + # These IDs will create keys with different hash tags + ids = ["{user123}:doc1", "{user456}:doc2"] + + with pytest.raises(ValueError) as exc_info: + await index.drop_documents(ids) + + error_msg = str(exc_info.value) + assert "All keys must share a hash tag when using Redis Cluster" in error_msg + + +class TestClusterOperationsErrorHandling: + """Test error handling for cluster operations.""" + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_clear_individual_key_deletion_errors(self, mock_validate): + """Test clear method handles individual key deletion errors in cluster.""" + from unittest.mock import patch + + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + mock_cluster_client = Mock(spec=RedisCluster) + mock_cluster_client.delete.side_effect = [ + 1, # First key deletion succeeds + redis.exceptions.RedisError("Some cluster error"), # Second fails + 1, # Third succeeds + ] + + # Mock the paginate method to return test data + with patch.object(SearchIndex, "paginate") as mock_paginate: + mock_paginate.return_value = [ + [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}] + ] + + # Create index with mocked client + index = SearchIndex(schema) + index._SearchIndex__redis_client = mock_cluster_client + + # Test that clear handles individual key deletion errors + with patch("redisvl.index.index.logger") as mock_logger: + result = index.clear() + + # Should have attempted to delete all 3 keys + assert mock_cluster_client.delete.call_count == 3 + # Should have logged the error for the failed key + mock_logger.warning.assert_called_once_with( + "Failed to delete key test:key2: Some cluster error" + ) + # Should return count of successfully deleted keys (2 out of 3) + assert result == 2 + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_async_redis") + @pytest.mark.asyncio + async def test_async_clear_individual_key_deletion_errors(self, mock_validate): + """Test async clear method handles individual key deletion errors in cluster.""" + from unittest.mock import AsyncMock, patch + + from redisvl.index import AsyncSearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and async cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.prefix = "test" + schema.index.key_separator = ":" + schema.index.storage_type = StorageType.HASH + + mock_cluster_client = Mock(spec=AsyncRedisCluster) + mock_cluster_client.delete = AsyncMock( + side_effect=[ + 1, # First key deletion succeeds + redis.exceptions.RedisError("Some cluster error"), # Second fails + 1, # Third succeeds + ] + ) + + # Mock the paginate method to return test data + async def mock_paginate_generator(*args, **kwargs): + yield [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}] + + with patch.object(AsyncSearchIndex, "paginate", mock_paginate_generator): + # Create index with mocked client + index = AsyncSearchIndex(schema) + index._redis_client = mock_cluster_client + + # Test that clear handles individual key deletion errors + with patch("redisvl.index.index.logger") as mock_logger: + result = await 
index.clear() + + # Should have attempted to delete all 3 keys + assert mock_cluster_client.delete.call_count == 3 + # Should have logged the error for the failed key + mock_logger.warning.assert_called_once_with( + "Failed to delete key test:key2: Some cluster error" + ) + # Should return count of successfully deleted keys (2 out of 3) + assert result == 2 + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_delete_cluster_compatibility(self, mock_validate): + """Test delete method uses clear() for cluster compatibility when drop=True.""" + from unittest.mock import Mock, patch + + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + + mock_cluster_client = Mock(spec=RedisCluster) + mock_cluster_client.get_default_node.return_value = Mock() + + # Create index with mocked client + index = SearchIndex(schema) + index._SearchIndex__redis_client = mock_cluster_client + + # Test that delete() calls clear() first when drop=True in cluster + with patch.object(index, "clear") as mock_clear: + index.delete(drop=True) + + # Should have called clear() first + mock_clear.assert_called_once() + # Should have called execute_command with just the index name (no DD flag) + mock_cluster_client.execute_command.assert_called_once_with( + "FT.DROPINDEX", + "test_index", + target_nodes=[mock_cluster_client.get_default_node.return_value], + ) + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_sync_redis") + def test_delete_non_cluster_standard_behavior(self, mock_validate): + """Test delete method uses standard behavior for non-cluster Redis.""" + from unittest.mock import Mock + + from redisvl.index import SearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and regular Redis client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + + mock_redis_client = Mock(spec=Redis) + + # Create index with mocked client + index = SearchIndex(schema) + index._SearchIndex__redis_client = mock_redis_client + + # Test that delete() uses standard behavior for non-cluster + index.delete(drop=True) + + # Should have called execute_command with DD flag + mock_redis_client.execute_command.assert_called_once_with( + "FT.DROPINDEX", "test_index", "DD" + ) + + @patch("redisvl.redis.connection.RedisConnectionFactory.validate_async_redis") + @pytest.mark.asyncio + async def test_async_delete_cluster_compatibility(self, mock_validate): + """Test async delete method uses clear() for cluster compatibility when drop=True.""" + from unittest.mock import AsyncMock, Mock, patch + + from redisvl.index import AsyncSearchIndex + from redisvl.schema import IndexSchema + + # Create a mock schema and async cluster client + schema = Mock(spec=IndexSchema) + schema.index = Mock() + schema.index.name = "test_index" + + mock_cluster_client = Mock(spec=AsyncRedisCluster) + mock_cluster_client.get_default_node.return_value = Mock() + mock_cluster_client.execute_command = AsyncMock() + + # Create index with mocked client + index = AsyncSearchIndex(schema) + index._redis_client = mock_cluster_client + + # Test that delete() calls clear() first when drop=True in cluster + with patch.object(index, "clear", new_callable=AsyncMock) as mock_clear: + await index.delete(drop=True) + + # Should have called clear() first + mock_clear.assert_called_once() + # Should have 
called execute_command with just the index name (no DD flag) + mock_cluster_client.execute_command.assert_called_once_with( + "FT.DROPINDEX", + "test_index", + target_nodes=[mock_cluster_client.get_default_node.return_value], + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 668674d8..d9770f14 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -6,6 +6,7 @@ import pytest from redisvl.redis.utils import ( + _keys_share_hash_tag, array_to_buffer, buffer_to_array, convert_bytes, @@ -740,9 +741,10 @@ def test_attribute_error_with_direct_module_access(self): # Create a simple object with no __getattr__ method class SimpleObject: - pass + def __init__(self, value): + self.value = value - obj = SimpleObject() + obj = SimpleObject(42) # Directly set the _module attribute to our simple object # This bypasses the normal import mechanism @@ -757,3 +759,78 @@ class SimpleObject: "module 'test_direct_module' has no attribute 'nonexistent_attribute'" in str(excinfo.value) ) + + +# Hash tag validation tests for Redis Cluster compatibility +def test_keys_share_hash_tag_same_tags(): + """Test that keys with the same hash tag are considered compatible.""" + keys = ["prefix:{tag1}:key1", "prefix:{tag1}:key2", "prefix:{tag1}:key3"] + assert _keys_share_hash_tag(keys) is True + + +def test_keys_share_hash_tag_different_tags(): + """Test that keys with different hash tags are considered incompatible.""" + keys = ["prefix:{tag1}:key1", "prefix:{tag2}:key2"] + assert _keys_share_hash_tag(keys) is False + + +def test_keys_share_hash_tag_no_tags(): + """Test that keys without hash tags are considered compatible.""" + keys = ["prefix:key1", "prefix:key2", "prefix:key3"] + assert _keys_share_hash_tag(keys) is True + + +def test_keys_share_hash_tag_mixed_tags_and_no_tags(): + """Test that mixing keys with and without hash tags is incompatible.""" + keys = ["prefix:{tag1}:key1", "prefix:key2"] + assert _keys_share_hash_tag(keys) is False + + +def test_keys_share_hash_tag_empty_list(): + """Test that an empty list of keys is considered compatible.""" + assert _keys_share_hash_tag([]) is True + + +def test_keys_share_hash_tag_single_key(): + """Test that a single key is always compatible.""" + assert _keys_share_hash_tag(["prefix:{tag1}:key1"]) is True + assert _keys_share_hash_tag(["prefix:key1"]) is True + + +def test_keys_share_hash_tag_complex_tags(): + """Test with complex hash tag patterns.""" + keys_same = [ + "user:{user123}:profile", + "user:{user123}:settings", + "user:{user123}:history", + ] + assert _keys_share_hash_tag(keys_same) is True + + keys_different = ["user:{user123}:profile", "user:{user456}:profile"] + assert _keys_share_hash_tag(keys_different) is False + + +def test_keys_share_hash_tag_malformed_tags(): + """Test with malformed hash tags (missing closing brace).""" + keys = [ + "prefix:{tag1:key1", # Missing closing brace + "prefix:{tag1:key2", # Missing closing brace + ] + # These should be treated as no hash tags (empty string) + assert _keys_share_hash_tag(keys) is True + + +def test_keys_share_hash_tag_nested_braces(): + """Test with nested braces in hash tags.""" + keys_same = ["prefix:{{nested}tag}:key1", "prefix:{{nested}tag}:key2"] + assert _keys_share_hash_tag(keys_same) is True + + keys_different = ["prefix:{{nested}tag}:key1", "prefix:{{other}tag}:key2"] + assert _keys_share_hash_tag(keys_different) is False + + +def test_keys_share_hash_tag_multiple_braces(): + """Test with multiple sets of braces in a key.""" + keys = 
["prefix:{tag1}:middle:{tag2}:key1", "prefix:{tag1}:middle:{tag2}:key2"] + # Should use the first hash tag found + assert _keys_share_hash_tag(keys) is True