Skip to content

Accurate overloads for ZipFile.__init__ #12119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions stdlib/@tests/test_cases/check_zipfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations

import io
import pathlib
import zipfile
from typing import Literal

###
# Tests for `zipfile.ZipFile`
###

p = pathlib.Path("test.zip")


class CustomPathObj:
    """Minimal path-like object that wraps a string path.

    Defining ``__fspath__`` is what lets instances be passed wherever an
    ``os.PathLike`` is accepted.
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def __fspath__(self) -> str:
        """Return the wrapped path string."""
        return self.path


class NonPathObj:
    """Object that stores a path string but deliberately lacks ``__fspath__``.

    Used as a negative case: it is not path-like.
    """

    def __init__(self, path: str) -> None:
        self.path = path


class ReadableObj:
    """Stub object exposing only ``seek`` and ``read`` (no ``tell``, no ``write``)."""

    def seek(self, offset: int, whence: int = 0) -> int:
        """Ignore the arguments and report position 0."""
        return 0

    def read(self, n: int | None = -1) -> bytes:
        """Ignore ``n`` and return a fixed payload."""
        return b"test"


class TellableObj:
    """Stub object exposing only ``tell``."""

    def tell(self) -> int:
        """Always report position 0."""
        return 0


class WriteableObj:
    """Stub object exposing only ``close``, ``write`` and ``flush``."""

    def close(self) -> None:
        """Do nothing."""
        pass

    def write(self, b: bytes) -> int:
        """Discard ``b`` and return its length, as if every byte was written."""
        return len(b)

    def flush(self) -> None:
        """Do nothing."""
        pass


class ReadTellableObj(ReadableObj):
    """``ReadableObj`` extended with ``tell``: offers ``seek``, ``read`` and ``tell``."""

    def tell(self) -> int:
        """Always report position 0."""
        return 0


class SeekTellObj:
    """Stub object exposing ``seek`` and ``tell`` but neither ``read`` nor ``write``."""

    def seek(self, offset: int, whence: int = 0) -> int:
        """Ignore the arguments and report position 0."""
        return 0

    def tell(self) -> int:
        """Always report position 0."""
        return 0


def write_zip(mode: Literal["r", "w", "x", "a"]) -> None:
    """Exercise the ``zipfile.ZipFile`` constructor with each kind of ``file`` argument.

    Every case opens a ``ZipFile`` either with the caller-supplied ``mode`` or
    with a fixed literal mode, then calls ``writestr`` on it. Calls carrying a
    ``# type: ignore`` marker are the combinations that the accompanying
    comments say should be rejected; the unmarked calls are the ones expected
    to be accepted.

    Args:
        mode: One of the four ``ZipFile`` modes; used for the "any mode" cases.
    """
    # Test any mode with `pathlib.Path`
    with zipfile.ZipFile(p, mode) as z:
        z.writestr("test.txt", "test")

    # Test any mode with `str` path
    with zipfile.ZipFile("test.zip", mode) as z:
        z.writestr("test.txt", "test")

    # Test any mode with `os.PathLike` object
    with zipfile.ZipFile(CustomPathObj("test.zip"), mode) as z:
        z.writestr("test.txt", "test")

    # Non-PathLike object should raise an error
    with zipfile.ZipFile(NonPathObj("test.zip"), mode) as z:  # type: ignore
        z.writestr("test.txt", "test")

    # An IO[bytes]-like object should work for any mode.
    io_obj = io.BytesIO(b"This is a test")
    with zipfile.ZipFile(io_obj, mode) as z:
        z.writestr("test.txt", "test")

    # Readable object should not work for an arbitrary mode.
    with zipfile.ZipFile(ReadableObj(), mode) as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Readable object should work for "r" mode.
    with zipfile.ZipFile(ReadableObj(), "r") as z:
        z.writestr("test.txt", "test")

    # Readable/tellable object should work for "a" mode.
    with zipfile.ZipFile(ReadTellableObj(), "a") as z:
        z.writestr("test.txt", "test")

    # If it doesn't have a `tell` method, it should raise an error.
    with zipfile.ZipFile(ReadableObj(), "a") as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Readable object should not work for "w" mode.
    with zipfile.ZipFile(ReadableObj(), "w") as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Tellable object should not work for an arbitrary mode.
    with zipfile.ZipFile(TellableObj(), mode) as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Tellable object shouldn't work for "w" mode,
    # because `__del__` will call close.
    with zipfile.ZipFile(TellableObj(), "w") as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Writeable object should not work for an arbitrary mode.
    with zipfile.ZipFile(WriteableObj(), mode) as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Writeable object should work for "w" mode.
    with zipfile.ZipFile(WriteableObj(), "w") as z:
        z.writestr("test.txt", "test")

    # Seekable and tellable object should not work for an arbitrary mode.
    with zipfile.ZipFile(SeekTellObj(), mode) as z:  # type: ignore
        z.writestr("test.txt", "test")

    # Seekable and tellable object shouldn't work for "w" mode,
    # because `__del__` will call close.
    with zipfile.ZipFile(SeekTellObj(), "w") as z:  # type: ignore
        z.writestr("test.txt", "test")
80 changes: 77 additions & 3 deletions stdlib/zipfile/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ class ZipExtFile(io.BufferedIOBase):
# Structural type for any object with a positional-only `write(str)` method.
class _Writer(Protocol):
    def write(self, s: str, /) -> object: ...

# Structural type for a binary source offering `seek` and `read`;
# used as the `file` type of the ZipFile.__init__ overload for mode "r".
class _ZipReadable(Protocol):
    def seek(self, offset: int, whence: int = 0, /) -> int: ...
    def read(self, n: int = -1, /) -> bytes: ...

# Structural type for any object that can report its position via `tell`.
class _ZipTellable(Protocol):
    def tell(self) -> int: ...

# Combination of _ZipReadable and _ZipTellable (seek + read + tell);
# used as the `file` type of the ZipFile.__init__ overload for mode "a".
class _ZipReadableTellable(_ZipReadable, _ZipTellable, Protocol): ...

# Structural type for a binary sink offering `flush`, `close` and `write`;
# used as the `file` type of the ZipFile.__init__ overload for modes "w" and "x".
class _ZipWritable(Protocol):
    def flush(self) -> None: ...
    def close(self) -> None: ...
    def write(self, b: bytes, /) -> int: ...

class ZipFile:
filename: str | None
debug: int
Expand All @@ -106,24 +120,50 @@ class ZipFile:
compresslevel: int | None # undocumented
mode: _ZipFileMode # undocumented
pwd: bytes | None # undocumented
# metadata_encoding is new in 3.11
if sys.version_info >= (3, 11):
@overload
def __init__(
self,
file: StrPath | IO[bytes],
mode: _ZipFileMode = "r",
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
metadata_encoding: str | None = None,
) -> None: ...
Comment on lines 126 to +136
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't all following overloads shadowed by this one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No - for example, IO[bytes] is broader than _ZipReadable, which will lead to the second overload getting chosen if you just pass an object that is readable.

Copy link
Contributor Author

@max-muoto max-muoto Jul 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for _ZipFileMode, if file doesn't match we'll end up skipping over the overload. The tests show that this is working properly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IO[bytes] is broader than _ZipReadable

But that's exactly the problem. It's my understanding that overloads are processed in order (although the typing spec doesn't mention that), so IO[bytes] always matches before _ZipReadable can match.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've opened python/typing#1803 for clarification.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is that a problem though? _ZipReadable will get matched if IO[bytes] isn't fully fulfilled, or am I misunderstanding something?

# metadata_encoding is only allowed for read mode
@overload
def __init__(
self,
file: StrPath | _ZipReadable,
mode: Literal["r"] = "r",
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
metadata_encoding: str | None,
metadata_encoding: str | None = None,
) -> None: ...
@overload
def __init__(
self,
file: StrPath | IO[bytes],
mode: _ZipFileMode = "r",
file: StrPath | _ZipWritable,
mode: Literal["w", "x"] = ...,
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
metadata_encoding: None = None,
) -> None: ...
@overload
def __init__(
self,
file: StrPath | _ZipReadableTellable,
mode: Literal["a"] = ...,
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
Expand All @@ -132,6 +172,7 @@ class ZipFile:
metadata_encoding: None = None,
) -> None: ...
else:
@overload
def __init__(
self,
file: StrPath | IO[bytes],
Expand All @@ -142,6 +183,39 @@ class ZipFile:
*,
strict_timestamps: bool = True,
) -> None: ...
@overload
def __init__(
self,
file: StrPath | _ZipReadable,
mode: Literal["r"] = "r",
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
) -> None: ...
@overload
def __init__(
self,
file: StrPath | _ZipWritable,
mode: Literal["w", "x"] = ...,
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
) -> None: ...
@overload
def __init__(
self,
file: StrPath | _ZipReadableTellable,
mode: Literal["a"] = ...,
compression: int = 0,
allowZip64: bool = True,
compresslevel: int | None = None,
*,
strict_timestamps: bool = True,
) -> None: ...

def __enter__(self) -> Self: ...
def __exit__(
Expand Down