Skip to content

Commit dc0b63f

Browse files
Accurate overloads for ZipFile.__init__ (#12119)
Co-authored-by: Shantanu <[email protected]>
1 parent e2aea1e commit dc0b63f

File tree

2 files changed

+208
-3
lines changed

2 files changed

+208
-3
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from __future__ import annotations
2+
3+
import io
4+
import pathlib
5+
import zipfile
6+
from typing import Literal
7+
8+
###
9+
# Tests for `zipfile.ZipFile`
10+
###
11+
12+
p = pathlib.Path("test.zip")
13+
14+
15+
class CustomPathObj:
16+
def __init__(self, path: str) -> None:
17+
self.path = path
18+
19+
def __fspath__(self) -> str:
20+
return self.path
21+
22+
23+
class NonPathObj:
24+
def __init__(self, path: str) -> None:
25+
self.path = path
26+
27+
28+
class ReadableObj:
29+
def seek(self, offset: int, whence: int = 0) -> int:
30+
return 0
31+
32+
def read(self, n: int | None = -1) -> bytes:
33+
return b"test"
34+
35+
36+
class TellableObj:
37+
def tell(self) -> int:
38+
return 0
39+
40+
41+
class WriteableObj:
42+
def close(self) -> None:
43+
pass
44+
45+
def write(self, b: bytes) -> int:
46+
return len(b)
47+
48+
def flush(self) -> None:
49+
pass
50+
51+
52+
class ReadTellableObj(ReadableObj):
53+
def tell(self) -> int:
54+
return 0
55+
56+
57+
class SeekTellObj:
58+
def seek(self, offset: int, whence: int = 0) -> int:
59+
return 0
60+
61+
def tell(self) -> int:
62+
return 0
63+
64+
65+
def write_zip(mode: Literal["r", "w", "x", "a"]) -> None:
66+
# Test any mode with `pathlib.Path`
67+
with zipfile.ZipFile(p, mode) as z:
68+
z.writestr("test.txt", "test")
69+
70+
# Test any mode with `str` path
71+
with zipfile.ZipFile("test.zip", mode) as z:
72+
z.writestr("test.txt", "test")
73+
74+
# Test any mode with `os.PathLike` object
75+
with zipfile.ZipFile(CustomPathObj("test.zip"), mode) as z:
76+
z.writestr("test.txt", "test")
77+
78+
# Non-PathLike object should raise an error
79+
with zipfile.ZipFile(NonPathObj("test.zip"), mode) as z: # type: ignore
80+
z.writestr("test.txt", "test")
81+
82+
# IO[bytes] like-obj should work for any mode.
83+
io_obj = io.BytesIO(b"This is a test")
84+
with zipfile.ZipFile(io_obj, mode) as z:
85+
z.writestr("test.txt", "test")
86+
87+
# Readable object should not work for any mode.
88+
with zipfile.ZipFile(ReadableObj(), mode) as z: # type: ignore
89+
z.writestr("test.txt", "test")
90+
91+
# Readable object should work for "r" mode.
92+
with zipfile.ZipFile(ReadableObj(), "r") as z:
93+
z.writestr("test.txt", "test")
94+
95+
# Readable/tellable object should work for "a" mode.
96+
with zipfile.ZipFile(ReadTellableObj(), "a") as z:
97+
z.writestr("test.txt", "test")
98+
99+
# If it doesn't have 'tell' method, it should raise an error.
100+
with zipfile.ZipFile(ReadableObj(), "a") as z: # type: ignore
101+
z.writestr("test.txt", "test")
102+
103+
# Readable object should not work for "w" mode.
104+
with zipfile.ZipFile(ReadableObj(), "w") as z: # type: ignore
105+
z.writestr("test.txt", "test")
106+
107+
# Tellable object should not work for any mode.
108+
with zipfile.ZipFile(TellableObj(), mode) as z: # type: ignore
109+
z.writestr("test.txt", "test")
110+
111+
# Tellable object shouldn't work for "w" mode.
112+
# As `__del__` will call close.
113+
with zipfile.ZipFile(TellableObj(), "w") as z: # type: ignore
114+
z.writestr("test.txt", "test")
115+
116+
# Writeable object should not work for any mode.
117+
with zipfile.ZipFile(WriteableObj(), mode) as z: # type: ignore
118+
z.writestr("test.txt", "test")
119+
120+
# Writeable object should work for "w" mode.
121+
with zipfile.ZipFile(WriteableObj(), "w") as z:
122+
z.writestr("test.txt", "test")
123+
124+
# Seekable and Tellable object should not work for any mode.
125+
with zipfile.ZipFile(SeekTellObj(), mode) as z: # type: ignore
126+
z.writestr("test.txt", "test")
127+
128+
# Seekable and Tellable object shouldn't work for "w" mode.
129+
# Cause `__del__` will call close.
130+
with zipfile.ZipFile(SeekTellObj(), "w") as z: # type: ignore
131+
z.writestr("test.txt", "test")

stdlib/zipfile/__init__.pyi

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ class ZipExtFile(io.BufferedIOBase):
9494
class _Writer(Protocol):
9595
def write(self, s: str, /) -> object: ...
9696

97+
class _ZipReadable(Protocol):
98+
def seek(self, offset: int, whence: int = 0, /) -> int: ...
99+
def read(self, n: int = -1, /) -> bytes: ...
100+
101+
class _ZipTellable(Protocol):
102+
def tell(self) -> int: ...
103+
104+
class _ZipReadableTellable(_ZipReadable, _ZipTellable, Protocol): ...
105+
106+
class _ZipWritable(Protocol):
107+
def flush(self) -> None: ...
108+
def close(self) -> None: ...
109+
def write(self, b: bytes, /) -> int: ...
110+
97111
class ZipFile:
98112
filename: str | None
99113
debug: int
@@ -106,24 +120,50 @@ class ZipFile:
106120
compresslevel: int | None # undocumented
107121
mode: _ZipFileMode # undocumented
108122
pwd: bytes | None # undocumented
123+
# metadata_encoding is new in 3.11
109124
if sys.version_info >= (3, 11):
110125
@overload
111126
def __init__(
112127
self,
113128
file: StrPath | IO[bytes],
129+
mode: _ZipFileMode = "r",
130+
compression: int = 0,
131+
allowZip64: bool = True,
132+
compresslevel: int | None = None,
133+
*,
134+
strict_timestamps: bool = True,
135+
metadata_encoding: str | None = None,
136+
) -> None: ...
137+
# metadata_encoding is only allowed for read mode
138+
@overload
139+
def __init__(
140+
self,
141+
file: StrPath | _ZipReadable,
114142
mode: Literal["r"] = "r",
115143
compression: int = 0,
116144
allowZip64: bool = True,
117145
compresslevel: int | None = None,
118146
*,
119147
strict_timestamps: bool = True,
120-
metadata_encoding: str | None,
148+
metadata_encoding: str | None = None,
121149
) -> None: ...
122150
@overload
123151
def __init__(
124152
self,
125-
file: StrPath | IO[bytes],
126-
mode: _ZipFileMode = "r",
153+
file: StrPath | _ZipWritable,
154+
mode: Literal["w", "x"] = ...,
155+
compression: int = 0,
156+
allowZip64: bool = True,
157+
compresslevel: int | None = None,
158+
*,
159+
strict_timestamps: bool = True,
160+
metadata_encoding: None = None,
161+
) -> None: ...
162+
@overload
163+
def __init__(
164+
self,
165+
file: StrPath | _ZipReadableTellable,
166+
mode: Literal["a"] = ...,
127167
compression: int = 0,
128168
allowZip64: bool = True,
129169
compresslevel: int | None = None,
@@ -132,6 +172,7 @@ class ZipFile:
132172
metadata_encoding: None = None,
133173
) -> None: ...
134174
else:
175+
@overload
135176
def __init__(
136177
self,
137178
file: StrPath | IO[bytes],
@@ -142,6 +183,39 @@ class ZipFile:
142183
*,
143184
strict_timestamps: bool = True,
144185
) -> None: ...
186+
@overload
187+
def __init__(
188+
self,
189+
file: StrPath | _ZipReadable,
190+
mode: Literal["r"] = "r",
191+
compression: int = 0,
192+
allowZip64: bool = True,
193+
compresslevel: int | None = None,
194+
*,
195+
strict_timestamps: bool = True,
196+
) -> None: ...
197+
@overload
198+
def __init__(
199+
self,
200+
file: StrPath | _ZipWritable,
201+
mode: Literal["w", "x"] = ...,
202+
compression: int = 0,
203+
allowZip64: bool = True,
204+
compresslevel: int | None = None,
205+
*,
206+
strict_timestamps: bool = True,
207+
) -> None: ...
208+
@overload
209+
def __init__(
210+
self,
211+
file: StrPath | _ZipReadableTellable,
212+
mode: Literal["a"] = ...,
213+
compression: int = 0,
214+
allowZip64: bool = True,
215+
compresslevel: int | None = None,
216+
*,
217+
strict_timestamps: bool = True,
218+
) -> None: ...
145219

146220
def __enter__(self) -> Self: ...
147221
def __exit__(

0 commit comments

Comments
 (0)