Skip to content

Commit 5d713d6

Browse files
update ranx imports and fix for mypy
1 parent 85f929b commit 5d713d6

File tree

7 files changed

+3318
-2996
lines changed

7 files changed

+3318
-2996
lines changed

poetry.lock

Lines changed: 3269 additions & 2970 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,33 @@ python = ">=3.9,<3.14"
2424
numpy = ">=1.26.0,<3"
2525
pyyaml = ">=5.4,<7.0"
2626
redis = ">=5.0,<7.0"
27-
pydantic = "^2"
27+
pydantic = ">=2,<3"
2828
tenacity = ">=8.2.2"
2929
ml-dtypes = ">=0.4.0,<1.0.0"
30-
python-ulid = "^3.0.0"
30+
python-ulid = ">=3.0.0"
31+
jsonpath-ng = ">=1.5.0"
3132
nltk = { version = "^3.8.1", optional = true }
32-
jsonpath-ng = "^1.5.0"
3333
openai = { version = "^1.13.0", optional = true }
34-
sentence-transformers = { version = "^3.4.0", optional = true }
3534
google-cloud-aiplatform = { version = "^1.26", optional = true }
3635
protobuf = { version = "^5.29.1", optional = true }
3736
cohere = { version = ">=4.44", optional = true }
3837
mistralai = { version = ">=1.0.0", optional = true }
3938
voyageai = { version = ">=0.2.2", optional = true }
40-
ranx = { version = "^0.3.0", python=">=3.10", optional = true }
41-
boto3 = {version = "1.36.0", optional = true, extras = ["bedrock"]}
42-
scipy = {version = "^1.15", optional = true, python=">=3.10"}
43-
39+
sentence-transformers = { version = "^3.4.0", optional = true }
40+
scipy = [
41+
{ version = ">=1.9.0,<1.14", python = "<3.10", optional = true },
42+
{ version = ">=1.14.0,<1.16", python = ">=3.10", optional = true }
43+
]
44+
ranx = {version = "^0.3.0", optional = true}
4445

4546
[tool.poetry.extras]
4647
openai = ["openai"]
47-
sentence-transformers = ["sentence-transformers", "scipy"]
48-
vertexai = ["google-cloud-aiplatform", "protobuf"]
48+
vertexai = ["google-cloud-aiplatform"]
4949
cohere = ["cohere"]
5050
mistralai = ["mistralai"]
5151
voyageai = ["voyageai"]
52-
ranx = ["ranx", "scipy"]
53-
bedrock = ["boto3"]
5452
nltk = ["nltk"]
53+
sentence-transformers = ["sentence-transformers"]
5554

5655
[tool.poetry.group.dev.dependencies]
5756
black = "^25.1.0"
@@ -61,7 +60,7 @@ pytest = "^8.1.1"
6160
pytest-asyncio = "^0.23.6"
6261
pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}
6362
pre-commit = "^4.1.0"
64-
mypy = "1.9.0"
63+
mypy = "^1.11.0"
6564
nbval = "^0.11.0"
6665
types-pyyaml = "*"
6766
types-pyopenssl = "*"

redisvl/redis/connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,9 @@ def sync_to_async_redis(
357357
"RedisCluster is not supported for sync-to-async conversion."
358358
)
359359

360+
# At this point, redis_client is guaranteed to be Redis type
361+
assert isinstance(redis_client, Redis) # Type narrowing for MyPy
362+
360363
# pick the right connection class
361364
connection_class: Type[AsyncAbstractConnection] = (
362365
AsyncSSLConnection

redisvl/utils/optimize/cache.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
from typing import Any, Callable, Dict, List
1+
from typing import TYPE_CHECKING, Any, Callable, Dict, List
22

33
from redisvl.utils.utils import lazy_import
44

5+
if TYPE_CHECKING:
6+
from ranx import Qrels, Run, evaluate
7+
else:
8+
Qrels = lazy_import("ranx.Qrels")
9+
Run = lazy_import("ranx.Run")
10+
evaluate = lazy_import("ranx.evaluate")
11+
512
np = lazy_import("numpy")
6-
from ranx import Qrels, Run, evaluate
713

814
from redisvl.extensions.cache.llm.semantic import SemanticCache
915
from redisvl.query import RangeQuery
@@ -12,7 +18,7 @@
1218
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1319

1420

15-
def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run:
21+
def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> "Run":
1622
"""Format observed data for evaluation with ranx"""
1723
run_dict: Dict[str, Dict[str, int]] = {}
1824

@@ -32,7 +38,7 @@ def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run:
3238

3339

3440
def _eval_cache(
35-
test_data: List[LabeledData], threshold: float, qrels: Qrels, metric: str
41+
test_data: List[LabeledData], threshold: float, qrels: "Qrels", metric: str
3642
) -> float:
3743
"""Formats run data and evaluates supported metric"""
3844
run = _generate_run_cache(test_data, threshold)

redisvl/utils/optimize/router.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
import random
2-
from typing import Any, Callable, Dict, List
2+
from typing import TYPE_CHECKING, Any, Callable, Dict, List
33

44
from redisvl.utils.utils import lazy_import
55

6+
if TYPE_CHECKING:
7+
from ranx import Qrels, Run, evaluate
8+
else:
9+
Qrels = lazy_import("ranx.Qrels")
10+
Run = lazy_import("ranx.Run")
11+
evaluate = lazy_import("ranx.evaluate")
12+
613
np = lazy_import("numpy")
7-
from ranx import Qrels, Run, evaluate
814

915
from redisvl.extensions.router.semantic import SemanticRouter
1016
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
1117
from redisvl.utils.optimize.schema import LabeledData
1218
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1319

1420

15-
def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> Run:
21+
def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> "Run":
1622
"""Format router results into format for ranx Run"""
1723
run_dict: Dict[Any, Any] = {}
1824

@@ -28,7 +34,10 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -
2834

2935

3036
def _eval_router(
31-
router: SemanticRouter, test_data: List[LabeledData], qrels: Qrels, eval_metric: str
37+
router: SemanticRouter,
38+
test_data: List[LabeledData],
39+
qrels: "Qrels",
40+
eval_metric: str,
3241
) -> float:
3342
"""Evaluate acceptable metric given run and qrels data"""
3443
run = _generate_run_router(test_data, router)
@@ -58,7 +67,7 @@ def _router_random_search(
5867
def _random_search_opt_router(
5968
router: SemanticRouter,
6069
test_data: List[LabeledData],
61-
qrels: Qrels,
70+
qrels: "Qrels",
6271
eval_metric: EvalMetric,
6372
**kwargs: Any,
6473
):

redisvl/utils/optimize/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
from typing import List
1+
from typing import TYPE_CHECKING, List
22

33
from redisvl.utils.utils import lazy_import
44

5+
if TYPE_CHECKING:
6+
from ranx import Qrels
7+
else:
8+
Qrels = lazy_import("ranx.Qrels")
9+
510
np = lazy_import("numpy")
6-
from ranx import Qrels
711

812
from redisvl.utils.optimize.schema import LabeledData
913

1014
NULL_RESPONSE_KEY = "no_match"
1115

1216

13-
def _format_qrels(test_data: List[LabeledData]) -> Qrels:
17+
def _format_qrels(test_data: List[LabeledData]) -> "Qrels":
1418
"""Utility function for creating qrels for evaluation with ranx"""
1519
qrels_dict = {}
1620

tests/unit/test_threshold_optimizer_utility.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
if sys.version_info.major == 3 and sys.version_info.minor < 10:
66
pytest.skip("Test requires Python 3.10 or higher", allow_module_level=True)
77

8-
from ranx import evaluate
8+
from redisvl.utils.utils import lazy_import
9+
10+
evaluate = lazy_import("ranx.evaluate")
911

1012
from redisvl.utils.optimize import LabeledData
1113
from redisvl.utils.optimize.cache import _generate_run_cache

0 commit comments

Comments
 (0)