Skip to content

Use Literal to improve SpooledTemporaryFile #3526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 5, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions stdlib/3/tempfile.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def TemporaryFile(
prefix: Optional[AnyStr] = ...,
dir: Optional[_DirT[AnyStr]] = ...,
) -> IO[Any]: ...

@overload
def NamedTemporaryFile(
mode: Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+", "rt", "wt", "at", "xt", "r+t", "w+t", "a+t", "x+t"],
Expand Down Expand Up @@ -93,17 +92,48 @@ def NamedTemporaryFile(
# It does not actually derive from IO[AnyStr], but it does implement the
# protocol.
class SpooledTemporaryFile(IO[AnyStr]):
def __init__(self, max_size: int = ..., mode: str = ...,
buffering: int = ..., encoding: Optional[str] = ...,
newline: Optional[str] = ..., suffix: Optional[str] = ...,
prefix: Optional[str] = ..., dir: Optional[str] = ...
) -> None: ...
# bytes needs to go first, as default mode is to open as bytes
@overload
def __init__(
self: SpooledTemporaryFile[bytes],
max_size: int = ...,
mode: Literal["rb", "wb", "ab", "xb", "r+b", "w+b", "a+b", "x+b"] = ...,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also just "b".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mode "b" and "" both cause errors in Python 2.7 and 3.7. I haven't tested on other Python versions, but I don't see why they would be different.

Trying to instantiate a TemporaryFile (done when a SpooledTemporaryFile rolls over) with a mode string not containing exactly one of rwax and more than one + will raise ValueError: Must have exactly one of create/read/write/append mode and at most one plus on Python 3.7.4.

Python 2.7.16 has a similar error: ValueError: mode string must begin with one of 'r', 'w', 'a' or 'U', not 'b'.

Confusingly, the SpooledTemporaryFile constructor accepts any mode, but an exception is raised only when it creates a TemporaryFile internally. This can be forced by using the rollover() function.

Test case:

import tempfile
# this line works fine
with tempfile.SpooledTemporaryFile(mode="b") as tmpfile:
    # this line fails, since the invalid mode is checked here
    tmpfile.rollover()

buffering: int = ...,
encoding: Optional[str] = ...,
newline: Optional[str] = ...,
suffix: Optional[str] = ...,
prefix: Optional[str] = ...,
dir: Optional[str] = ...,
) -> None: ...
@overload
def __init__(
self: SpooledTemporaryFile[str],
max_size: int = ...,
mode: Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+", "rt", "wt", "at", "xt", "r+t", "w+t", "a+t", "x+t"] = ...,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also just "".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same error as above on Python 3.7.4, Python 2.7.16 has it's own ValueError: empty mode string error. It is only seen after the SpooledTemporaryFile creates a TemporaryFile internally, which can be forced using ~SpooledTemporaryFile.rollover().

import tempfile

with tempfile.SpooledTemporaryFile(mode="") as tmpfile:
    tmpfile.rollover() # empty mode fails on this line, not at constructor

buffering: int = ...,
encoding: Optional[str] = ...,
newline: Optional[str] = ...,
suffix: Optional[str] = ...,
prefix: Optional[str] = ...,
dir: Optional[str] = ...,
) -> None: ...
@overload
def __init__(
self,
max_size: int = ...,
mode: str = ...,
buffering: int = ...,
encoding: Optional[str] = ...,
newline: Optional[str] = ...,
suffix: Optional[str] = ...,
prefix: Optional[str] = ...,
dir: Optional[str] = ...,
) -> None: ...
def rollover(self) -> None: ...
def __enter__(self: _S) -> _S: ...
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]: ...

def __exit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> Optional[bool]: ...
# These methods are copied from the abstract methods of IO, because
# SpooledTemporaryFile implements IO.
# See also https://github.com/python/typeshed/pull/2452#issuecomment-420657918.
Expand All @@ -127,25 +157,24 @@ class SpooledTemporaryFile(IO[AnyStr]):

class TemporaryDirectory(Generic[AnyStr]):
name: str
def __init__(self, suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ...,
dir: Optional[_DirT[AnyStr]] = ...) -> None: ...
def __init__(
self, suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ..., dir: Optional[_DirT[AnyStr]] = ...
) -> None: ...
def cleanup(self) -> None: ...
def __enter__(self) -> AnyStr: ...
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> None: ...
def __exit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None: ...

def mkstemp(suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ..., dir: Optional[_DirT[AnyStr]] = ...,
text: bool = ...) -> Tuple[int, AnyStr]: ...
def mkstemp(
suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ..., dir: Optional[_DirT[AnyStr]] = ..., text: bool = ...
) -> Tuple[int, AnyStr]: ...
@overload
def mkdtemp() -> str: ...
@overload
def mkdtemp(suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ...,
dir: Optional[_DirT[AnyStr]] = ...) -> AnyStr: ...
def mkdtemp(suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ..., dir: Optional[_DirT[AnyStr]] = ...) -> AnyStr: ...
def mktemp(suffix: Optional[AnyStr] = ..., prefix: Optional[AnyStr] = ..., dir: Optional[_DirT[AnyStr]] = ...) -> AnyStr: ...

def gettempdirb() -> bytes: ...
def gettempprefixb() -> bytes: ...

def gettempdir() -> str: ...
def gettempprefix() -> str: ...