Skip to content

Commit 7f986bd

Browse files
authored
Add more overloads to the re stubs to help out pyright (#9592)
1 parent 18c4661 commit 7f986bd

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

stdlib/re.pyi

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ class Match(Generic[AnyStr]):
6767
@overload
6868
def expand(self: Match[str], template: str) -> str: ...
6969
@overload
70-
def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ...
70+
def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... # type: ignore[misc]
71+
@overload
72+
def expand(self, template: AnyStr) -> AnyStr: ...
7173
# group() returns "AnyStr" or "AnyStr | None", depending on the pattern.
7274
@overload
7375
def group(self, __group: Literal[0] = ...) -> AnyStr: ...
@@ -115,46 +117,62 @@ class Pattern(Generic[AnyStr]):
115117
@overload
116118
def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
117119
@overload
118-
def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
120+
def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc]
121+
@overload
122+
def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
119123
@overload
120124
def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
121125
@overload
122-
def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
126+
def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc]
127+
@overload
128+
def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
123129
@overload
124130
def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ...
125131
@overload
126-
def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ...
132+
def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc]
133+
@overload
134+
def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ...
127135
@overload
128136
def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ...
129137
@overload
130138
def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ...
139+
@overload
140+
def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ...
131141
# return type depends on the number of groups in the pattern
132142
@overload
133143
def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ...
134144
@overload
135145
def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ...
136146
@overload
147+
def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[AnyStr]: ...
148+
@overload
137149
def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ...
138150
@overload
139-
def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ...
151+
def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... # type: ignore[misc]
152+
@overload
153+
def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ...
140154
@overload
141155
def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ...
142156
@overload
143-
def sub(
157+
def sub( # type: ignore[misc]
144158
self: Pattern[bytes],
145159
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
146160
string: ReadableBuffer,
147161
count: int = ...,
148162
) -> bytes: ...
149163
@overload
164+
def sub(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ...
165+
@overload
150166
def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ...
151167
@overload
152-
def subn(
168+
def subn( # type: ignore[misc]
153169
self: Pattern[bytes],
154170
repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer],
155171
string: ReadableBuffer,
156172
count: int = ...,
157173
) -> tuple[bytes, int]: ...
174+
@overload
175+
def subn(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ...
158176
def __copy__(self) -> Pattern[AnyStr]: ...
159177
def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ...
160178
if sys.version_info >= (3, 9):

test_cases/stdlib/check_re.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
import mmap
4+
import re
5+
import typing as t
6+
from typing_extensions import assert_type
7+
8+
9+
def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None:
10+
assert_type(str_pat.search("x"), t.Optional[t.Match[str]])
11+
assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]])
12+
assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]])
13+
assert_type(bytes_pat.search(mmap.mmap(0, 10)), t.Optional[t.Match[bytes]])
14+
15+
16+
def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]:
17+
"""See issue #9591"""
18+
match = pattern.search(string)
19+
if match is None:
20+
raise ValueError(f"'{string!r}' does not match {pattern!r}")
21+
return match
22+
23+
24+
def check_no_ReadableBuffer_false_negatives() -> None:
25+
re.compile("foo").search(bytearray(b"foo")) # type: ignore
26+
re.compile("foo").search(mmap.mmap(0, 10)) # type: ignore

test_cases/stdlib/typing/check_pattern.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

0 commit comments

Comments
 (0)