diff --git a/stdlib/re.pyi b/stdlib/re.pyi index 4962ab8edad9..a6ec585d0b1c 100644 --- a/stdlib/re.pyi +++ b/stdlib/re.pyi @@ -65,7 +65,7 @@ class Match(Generic[AnyStr]): @property def re(self) -> Pattern[AnyStr]: ... @overload - def expand(self: Match[str], template: str) -> str: ... + def expand(self, template: AnyStr) -> AnyStr: ... # type: ignore[misc] @overload def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @@ -113,32 +113,32 @@ class Pattern(Generic[AnyStr]): @property def pattern(self) -> AnyStr: ... @overload - def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... # type: ignore[misc] @overload def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload - def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... # type: ignore[misc] @overload def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload - def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... + def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... # type: ignore[misc] @overload def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... @overload - def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ... + def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ... @overload def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ... # return type depends on the number of groups in the pattern @overload - def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ... + def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[AnyStr]: ... @overload def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ... @overload - def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ... + def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ... # type: ignore[misc] @overload def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... @overload - def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ... + def sub(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ... # type: ignore[misc] @overload def sub( self: Pattern[bytes], @@ -147,7 +147,7 @@ class Pattern(Generic[AnyStr]): count: int = ..., ) -> bytes: ... @overload - def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ... + def subn(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ... # type: ignore[misc] @overload def subn( self: Pattern[bytes], diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py new file mode 100644 index 000000000000..398c505d5962 --- /dev/null +++ b/test_cases/stdlib/check_re.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import re +import typing as t +from typing_extensions import assert_type + + +def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None: + assert_type(str_pat.search("x"), t.Optional[t.Match[str]]) + assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]]) + + +def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: + """See issue #9591""" + match = pattern.search(string) + if match is None: + raise ValueError(f"'{string!r}' does not match {pattern!r}") + return match + + +def check_no_ReadableBuffer_false_negatives() -> None: + re.compile("foo").search(bytearray(b"foo")) # type: ignore diff --git a/test_cases/stdlib/typing/check_pattern.py b/test_cases/stdlib/typing/check_pattern.py deleted file mode 100644 index ec5c1c4f6141..000000000000 --- a/test_cases/stdlib/typing/check_pattern.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from typing import Match, Optional, Pattern -from typing_extensions import assert_type - - -def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None: - assert_type(str_pat.search("x"), Optional[Match[str]]) - assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]]) - assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]])