|
1 | 1 | import os
|
2 |
| -from typing import Any, Dict, List, Optional, Type |
| 2 | +from typing import Any, Dict, List, Optional, Tuple, Type |
| 3 | +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse |
3 | 4 | from warnings import warn
|
4 | 5 |
|
5 | 6 | from redis import Redis, RedisCluster
|
|
20 | 21 | from redisvl.version import __version__
|
21 | 22 |
|
22 | 23 |
|
| 24 | +def _strip_cluster_from_url_and_kwargs(url: str, **kwargs) -> Tuple[str, Dict[str, Any]]: |
| 25 | + """ |
| 26 | + Strip the 'cluster' parameter from URL query string and kwargs to prevent |
| 27 | + TypeError when calling AsyncRedisCluster.from_url(). |
| 28 | + |
| 29 | + Args: |
| 30 | + url (str): Redis connection URL |
| 31 | + **kwargs: Additional keyword arguments |
| 32 | + |
| 33 | + Returns: |
| 34 | + Tuple[str, Dict[str, Any]]: Modified URL and kwargs with cluster parameter removed |
| 35 | + """ |
| 36 | + # Create a copy of kwargs to avoid modifying the original |
| 37 | + clean_kwargs = kwargs.copy() |
| 38 | + clean_kwargs.pop("cluster", None) |
| 39 | + |
| 40 | + # Parse the URL and remove cluster parameter from query string |
| 41 | + parsed_url = urlparse(url) |
| 42 | + query_params = parse_qs(parsed_url.query) |
| 43 | + |
| 44 | + # Remove cluster parameter if present (case-insensitive) |
| 45 | + query_params.pop("cluster", None) |
| 46 | + query_params.pop("Cluster", None) |
| 47 | + query_params.pop("CLUSTER", None) |
| 48 | + |
| 49 | + # Rebuild the URL without cluster parameter |
| 50 | + new_query = urlencode(query_params, doseq=True) |
| 51 | + clean_url = urlunparse(( |
| 52 | + parsed_url.scheme, |
| 53 | + parsed_url.netloc, |
| 54 | + parsed_url.path, |
| 55 | + parsed_url.params, |
| 56 | + new_query, |
| 57 | + parsed_url.fragment |
| 58 | + )) |
| 59 | + |
| 60 | + return clean_url, clean_kwargs |
| 61 | + |
| 62 | + |
23 | 63 | def compare_versions(version1: str, version2: str):
|
24 | 64 | """
|
25 | 65 | Compare two Redis version strings numerically.
|
@@ -293,7 +333,9 @@ async def _get_aredis_connection(
|
293 | 333 | url = url or get_address_from_env()
|
294 | 334 |
|
295 | 335 | if is_cluster_url(url, **kwargs):
|
296 |
| - client = AsyncRedisCluster.from_url(url, **kwargs) |
| 336 | + # Strip cluster parameter to prevent TypeError in AsyncRedisCluster.from_url() |
| 337 | + clean_url, clean_kwargs = _strip_cluster_from_url_and_kwargs(url, **kwargs) |
| 338 | + client = AsyncRedisCluster.from_url(clean_url, **clean_kwargs) |
297 | 339 | else:
|
298 | 340 | client = AsyncRedis.from_url(url, **kwargs)
|
299 | 341 |
|
@@ -345,7 +387,9 @@ def get_async_redis_cluster_connection(
|
345 | 387 | ) -> AsyncRedisCluster:
|
346 | 388 | """Creates and returns an asynchronous Redis client for a Redis cluster."""
|
347 | 389 | url = redis_url or get_address_from_env()
|
348 |
| - return AsyncRedisCluster.from_url(url, **kwargs) |
| 390 | + # Strip cluster parameter to prevent TypeError in AsyncRedisCluster.from_url() |
| 391 | + clean_url, clean_kwargs = _strip_cluster_from_url_and_kwargs(url, **kwargs) |
| 392 | + return AsyncRedisCluster.from_url(clean_url, **clean_kwargs) |
349 | 393 |
|
350 | 394 | @staticmethod
|
351 | 395 | def sync_to_async_redis(
|
|
0 commit comments