From 62db7bf82927722c2ce43945bf3d4006764c1804 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 2 Jun 2025 16:26:18 +0300 Subject: [PATCH] Fixing mypy errors in redis/commands/search/query.py --- redis/commands/search/query.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index a8312a2ad2..615e6d10fa 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None: self._with_scores: bool = False self._scorer: Optional[str] = None self._filters: List = list() - self._ids: Optional[List[str]] = None + self._ids: Optional[Tuple[str]] = None self._slop: int = -1 self._timeout: Optional[float] = None self._in_order: bool = False @@ -81,7 +81,7 @@ def return_field( self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields: List[str]) -> List: + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) @@ -126,7 +126,7 @@ def summarize( def highlight( self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None - ) -> None: + ) -> "Query": """ Apply specified markup to matched term(s) within the returned field(s). @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query": self._scorer = scorer return self - def get_args(self) -> List[str]: + def get_args(self) -> List[Union[str, int, float]]: """Format the redis arguments for this query and return them.""" - args = [self._query_string] + args: List[Union[str, int, float]] = [self._query_string] args += self._get_args_tags() args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self) -> List[str]: - args = [] + def _get_args_tags(self) -> List[Union[str, int, float]]: + args: List[Union[str, int, float]] = [] if self._no_content: args.append("NOCONTENT") if self._fields: @@ -288,14 +288,14 @@ def with_scores(self) -> "Query": self._with_scores = True return self - def limit_fields(self, *fields: List[str]) -> "Query": + def limit_fields(self, *fields: str) -> "Query": """ Limit the search to specific TEXT fields only. - - **fields**: A list of strings, case sensitive field names + - **fields**: Each element should be a string, case sensitive field name from the defined schema. """ - self._fields = fields + self._fields = list(fields) return self def add_filter(self, flt: "Filter") -> "Query": @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: self.args = [keyword, field] + list(args)