Skip to content

Commit 65623b7

Browse files
Extending the tests to validate more supported scorers for hybrid search. Adding experimental_method annotation to hybrid_search commands. (#3939)
* Extending the tests to validate more supported scorers for hybrid search. Adding experimental_method annotation to hybrid_search commands. * Update redis/commands/search/hybrid_query.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Applying review comments about defined decorators when used with async functions * Applying review comments --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 2b2f2cf commit 65623b7

File tree

6 files changed

+417
-62
lines changed

6 files changed

+417
-62
lines changed

redis/commands/search/commands.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
HybridQuery,
1212
)
1313
from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult
14-
from redis.utils import deprecated_function
14+
from redis.utils import deprecated_function, experimental_method
1515

1616
from ..helpers import get_protocol_version
1717
from ._util import to_string
@@ -560,6 +560,7 @@ def search(
560560
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
561561
)
562562

563+
@experimental_method()
563564
def hybrid_search(
564565
self,
565566
query: HybridQuery,
@@ -1053,6 +1054,7 @@ async def search(
10531054
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
10541055
)
10551056

1057+
@experimental_method()
10561058
async def hybrid_search(
10571059
self,
10581060
query: HybridQuery,

redis/commands/search/hybrid_query.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ def __init__(
2525
2626
Args:
2727
query_string: The query string.
28-
scorer: The scorer to use. Allowed values are "TFIDF" or "BM25".
28+
scorer: Scoring algorithm for text search query.
29+
Allowed values are "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE",
30+
"BM25", "BM25STD", "BM25STD.TANH", "HAMMING", etc.
31+
For more information about supported scoring algorithms, see
32+
https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
2933
yield_score_as: The name of the field to yield the score as.
3034
"""
3135
self._query_string = query_string
@@ -39,9 +43,10 @@ def query_string(self) -> str:
3943
def scorer(self, scorer: str) -> "HybridSearchQuery":
4044
"""
4145
Scoring algorithm for text search query.
42-
Allowed values are "TFIDF", "DISMAX", "DOCSCORE", "BM25", etc.
46+
Allowed values are "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE", "BM25",
47+
"BM25STD", "BM25STD.TANH", "HAMMING", etc.
4348
44-
For more information about supported scroring algorithms,
49+
For more information about supported scoring algorithms,
4550
see https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
4651
"""
4752
self._scorer = scorer

redis/utils.py

Lines changed: 117 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import inspect
23
import logging
34
import textwrap
45
import warnings
@@ -125,12 +126,22 @@ def deprecated_function(reason="", version="", name=None):
125126
"""
126127

127128
def decorator(func):
128-
@wraps(func)
129-
def wrapper(*args, **kwargs):
130-
warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
131-
return func(*args, **kwargs)
129+
if inspect.iscoroutinefunction(func):
130+
# Create async wrapper for async functions
131+
@wraps(func)
132+
async def async_wrapper(*args, **kwargs):
133+
warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
134+
return await func(*args, **kwargs)
135+
136+
return async_wrapper
137+
else:
138+
# Create regular wrapper for sync functions
139+
@wraps(func)
140+
def wrapper(*args, **kwargs):
141+
warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
142+
return func(*args, **kwargs)
132143

133-
return wrapper
144+
return wrapper
134145

135146
return decorator
136147

@@ -158,47 +169,73 @@ def warn_deprecated_arg_usage(
158169
C = TypeVar("C", bound=Callable)
159170

160171

172+
def _get_filterable_args(
173+
func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
174+
) -> dict:
175+
"""
176+
Extract arguments from function call that should be checked for deprecation/experimental warnings.
177+
Excludes 'self' and any explicitly allowed args.
178+
"""
179+
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
180+
filterable_args = dict(zip(arg_names, args))
181+
filterable_args.update(kwargs)
182+
filterable_args.pop("self", None)
183+
if allowed_args:
184+
for allowed_arg in allowed_args:
185+
filterable_args.pop(allowed_arg, None)
186+
return filterable_args
187+
188+
161189
def deprecated_args(
162-
args_to_warn: list = ["*"],
163-
allowed_args: list = [],
190+
args_to_warn: Optional[List[str]] = None,
191+
allowed_args: Optional[List[str]] = None,
164192
reason: str = "",
165193
version: str = "",
166194
) -> Callable[[C], C]:
167195
"""
168196
Decorator to mark specified args of a function as deprecated.
169197
If '*' is in args_to_warn, all arguments will be marked as deprecated.
170198
"""
199+
if args_to_warn is None:
200+
args_to_warn = ["*"]
201+
if allowed_args is None:
202+
allowed_args = []
203+
204+
def _check_deprecated_args(func, filterable_args):
205+
"""Check and warn about deprecated arguments."""
206+
for arg in args_to_warn:
207+
if arg == "*" and len(filterable_args) > 0:
208+
warn_deprecated_arg_usage(
209+
list(filterable_args.keys()),
210+
func.__name__,
211+
reason,
212+
version,
213+
stacklevel=5,
214+
)
215+
elif arg in filterable_args:
216+
warn_deprecated_arg_usage(
217+
arg, func.__name__, reason, version, stacklevel=5
218+
)
171219

172220
def decorator(func: C) -> C:
173-
@wraps(func)
174-
def wrapper(*args, **kwargs):
175-
# Get function argument names
176-
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
177-
178-
provided_args = dict(zip(arg_names, args))
179-
provided_args.update(kwargs)
180-
181-
provided_args.pop("self", None)
182-
for allowed_arg in allowed_args:
183-
provided_args.pop(allowed_arg, None)
184-
185-
for arg in args_to_warn:
186-
if arg == "*" and len(provided_args) > 0:
187-
warn_deprecated_arg_usage(
188-
list(provided_args.keys()),
189-
func.__name__,
190-
reason,
191-
version,
192-
stacklevel=3,
193-
)
194-
elif arg in provided_args:
195-
warn_deprecated_arg_usage(
196-
arg, func.__name__, reason, version, stacklevel=3
197-
)
198-
199-
return func(*args, **kwargs)
200-
201-
return wrapper
221+
if inspect.iscoroutinefunction(func):
222+
223+
@wraps(func)
224+
async def async_wrapper(*args, **kwargs):
225+
filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
226+
_check_deprecated_args(func, filterable_args)
227+
return await func(*args, **kwargs)
228+
229+
return async_wrapper
230+
else:
231+
232+
@wraps(func)
233+
def wrapper(*args, **kwargs):
234+
filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
235+
_check_deprecated_args(func, filterable_args)
236+
return func(*args, **kwargs)
237+
238+
return wrapper
202239

203240
return decorator
204241

@@ -368,12 +405,22 @@ def experimental_method() -> Callable[[C], C]:
368405
"""
369406

370407
def decorator(func: C) -> C:
371-
@wraps(func)
372-
def wrapper(*args, **kwargs):
373-
warn_experimental(func.__name__, stacklevel=2)
374-
return func(*args, **kwargs)
408+
if inspect.iscoroutinefunction(func):
409+
# Create async wrapper for async functions
410+
@wraps(func)
411+
async def async_wrapper(*args, **kwargs):
412+
warn_experimental(func.__name__, stacklevel=2)
413+
return await func(*args, **kwargs)
414+
415+
return async_wrapper
416+
else:
417+
# Create regular wrapper for sync functions
418+
@wraps(func)
419+
def wrapper(*args, **kwargs):
420+
warn_experimental(func.__name__, stacklevel=2)
421+
return func(*args, **kwargs)
375422

376-
return wrapper
423+
return wrapper
377424

378425
return decorator
379426

@@ -393,32 +440,45 @@ def warn_experimental_arg_usage(
393440

394441

395442
def experimental_args(
396-
args_to_warn: list = ["*"],
443+
args_to_warn: Optional[List[str]] = None,
397444
) -> Callable[[C], C]:
398445
"""
399446
Decorator to mark specified args of a function as experimental.
447+
If '*' is in args_to_warn, all arguments will be marked as experimental.
400448
"""
449+
if args_to_warn is None:
450+
args_to_warn = ["*"]
451+
452+
def _check_experimental_args(func, filterable_args):
453+
"""Check and warn about experimental arguments."""
454+
for arg in args_to_warn:
455+
if arg == "*" and len(filterable_args) > 0:
456+
warn_experimental_arg_usage(
457+
list(filterable_args.keys()), func.__name__, stacklevel=4
458+
)
459+
elif arg in filterable_args:
460+
warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)
401461

402462
def decorator(func: C) -> C:
403-
@wraps(func)
404-
def wrapper(*args, **kwargs):
405-
# Get function argument names
406-
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
463+
if inspect.iscoroutinefunction(func):
407464

408-
provided_args = dict(zip(arg_names, args))
409-
provided_args.update(kwargs)
465+
@wraps(func)
466+
async def async_wrapper(*args, **kwargs):
467+
filterable_args = _get_filterable_args(func, args, kwargs)
468+
if len(filterable_args) > 0:
469+
_check_experimental_args(func, filterable_args)
470+
return await func(*args, **kwargs)
410471

411-
provided_args.pop("self", None)
472+
return async_wrapper
473+
else:
412474

413-
if len(provided_args) == 0:
475+
@wraps(func)
476+
def wrapper(*args, **kwargs):
477+
filterable_args = _get_filterable_args(func, args, kwargs)
478+
if len(filterable_args) > 0:
479+
_check_experimental_args(func, filterable_args)
414480
return func(*args, **kwargs)
415481

416-
for arg in args_to_warn:
417-
if arg in provided_args:
418-
warn_experimental_arg_usage(arg, func.__name__, stacklevel=3)
419-
420-
return func(*args, **kwargs)
421-
422-
return wrapper
482+
return wrapper
423483

424484
return decorator

0 commit comments

Comments
 (0)