Skip to content

Fixing errors reported by mypy. #3666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
self._with_scores: bool = False
self._scorer: Optional[str] = None
self._filters: List = list()
self._ids: Optional[List[str]] = None
self._ids: Optional[Tuple[str]] = None
self._slop: int = -1
self._timeout: Optional[float] = None
self._in_order: bool = False
Expand Down Expand Up @@ -81,7 +81,7 @@ def return_field(
self._return_fields += ("AS", as_field)
return self

def _mk_field_list(self, fields: List[str]) -> List:
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
if not fields:
return []
return [fields] if isinstance(fields, str) else list(fields)
Expand Down Expand Up @@ -126,7 +126,7 @@ def summarize(

def highlight(
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
) -> None:
) -> "Query":
"""
Apply specified markup to matched term(s) within the returned field(s).

Expand Down Expand Up @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
self._scorer = scorer
return self

def get_args(self) -> List[str]:
def get_args(self) -> List[Union[str, int, float]]:
"""Format the redis arguments for this query and return them."""
args = [self._query_string]
args: List[Union[str, int, float]] = [self._query_string]
args += self._get_args_tags()
args += self._summarize_fields + self._highlight_fields
args += ["LIMIT", self._offset, self._num]
return args

def _get_args_tags(self) -> List[str]:
args = []
def _get_args_tags(self) -> List[Union[str, int, float]]:
args: List[Union[str, int, float]] = []
if self._no_content:
args.append("NOCONTENT")
if self._fields:
Expand Down Expand Up @@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
self._with_scores = True
return self

def limit_fields(self, *fields: List[str]) -> "Query":
def limit_fields(self, *fields: str) -> "Query":
"""
Limit the search to specific TEXT fields only.

- **fields**: A list of strings, case sensitive field names
- **fields**: Each element should be a string, case sensitive field name
from the defined schema.
"""
self._fields = fields
self._fields = list(fields)
return self

def add_filter(self, flt: "Filter") -> "Query":
Expand Down Expand Up @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":


class Filter:
def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
self.args = [keyword, field] + list(args)


Expand Down
Loading