Skip to content

Fix type hinting #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 170 additions & 11 deletions aioshutil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,28 @@
"""
Asynchronous shutil module.
"""
from __future__ import annotations

import asyncio
import shutil
from functools import partial, wraps
from typing import Any, Awaitable, Callable, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Optional,
Sequence,
TypeVar,
Union,
overload,
)

try:
from typing import ParamSpec, TypeAlias # type: ignore
except ImportError:
# Python versions < 3.10
from typing_extensions import ParamSpec, TypeAlias

__all__ = [
"copyfileobj",
Expand Down Expand Up @@ -35,42 +53,183 @@
"SameFileError",
]

T = TypeVar("T", bound=Callable[..., Any])
P = ParamSpec("P")
R = TypeVar("R")

if TYPE_CHECKING: # pragma: no cover
# type hints for wrapped functions with overloads (which are incompatible
# with ParamSpec).

import sys
from os import PathLike

StrPath: TypeAlias = Union[str, PathLike[str]]
BytesPath: TypeAlias = Union[bytes, PathLike[bytes]]
StrOrBytesPath: TypeAlias = Union[str, bytes, PathLike[str], PathLike[bytes]]
_PathReturn: TypeAlias = Any
_StrPathT = TypeVar("_StrPathT", bound=StrPath)

@overload
async def copy(
src: StrPath, dst: StrPath, *, follow_symlinks: bool = ...
) -> _PathReturn:
...

@overload
async def copy(
src: BytesPath, dst: BytesPath, *, follow_symlinks: bool = ...
) -> _PathReturn:
...

async def copy(src, dst, *, follow_symlinks=...):
...

@overload
async def copy2(
src: StrPath, dst: StrPath, *, follow_symlinks: bool = ...
) -> _PathReturn:
...

@overload
async def copy2(
src: BytesPath, dst: BytesPath, *, follow_symlinks: bool = ...
) -> _PathReturn:
...

async def copy2(src, dst, *, follow_symlinks=...):
...

@overload
async def register_archive_format(
name: str,
function: Callable[..., object],
extra_args: Sequence[tuple[str, Any] | list[Any]],
description: str = ...,
) -> None:
...

@overload
async def register_archive_format(
name: str,
function: Callable[[str, str], object],
extra_args: None = ...,
description: str = ...,
) -> None:
...

async def register_archive_format(name, function, extra_args=..., description=...):
...

@overload
async def register_unpack_format(
name: str,
extensions: list[str],
function: Callable[..., object],
extra_args: Sequence[tuple[str, Any]],
description: str = ...,
) -> None:
...

@overload
async def register_unpack_format(
name: str,
extensions: list[str],
function: Callable[[str, str], object],
extra_args: None = ...,
description: str = ...,
) -> None:
...

async def register_unpack_format(
name, extensions, function, extra_args=..., description=...
):
...

@overload
async def chown(
path: StrOrBytesPath, user: Union[str, int], group: None = ...
) -> None:
...

@overload
async def chown(
path: StrOrBytesPath, user: None = ..., *, group: Union[str, int]
) -> None:
...

@overload
async def chown(path: StrOrBytesPath, user: None, group: Union[str, int]) -> None:
...

@overload
async def chown(
path: StrOrBytesPath, user: Union[str, int], group: Union[str, int]
) -> None:
...

async def chown(path, user=..., group=...):
...

if sys.version_info >= (3, 8):

@overload
async def which(
cmd: _StrPathT, mode: int = ..., path: Optional[StrPath] = ...
) -> Union[str, _StrPathT, None]:
...

@overload
async def which(
cmd: bytes, mode: int = ..., path: Optional[StrPath] = ...
) -> Optional[bytes]:
...

async def which(
cmd, mode=..., path=...
) -> Union[bytes, str, StrPath, PathLike[str], None]:
...

else:

async def which(
cmd: _StrPathT, mode: int = ..., path: StrPath | None = ...
) -> str | _StrPathT | None:
...


def sync_to_async(func: T):
def sync_to_async(func: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def run_in_executor(*args, **kwargs):
async def run_in_executor(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_event_loop()
pfunc = partial(func, *args, **kwargs)
return await loop.run_in_executor(None, pfunc)

return cast(Awaitable[T], run_in_executor)
return run_in_executor


rmtree = sync_to_async(shutil.rmtree)
copyfile = sync_to_async(shutil.copyfile)
copyfileobj = sync_to_async(shutil.copyfileobj)
copymode = sync_to_async(shutil.copymode)
copystat = sync_to_async(shutil.copystat)
copy = sync_to_async(shutil.copy)
copy2 = sync_to_async(shutil.copy2)
copy = sync_to_async(shutil.copy) # type: ignore # noqa: F811
copy2 = sync_to_async(shutil.copy2) # type: ignore # noqa: F811
copytree = sync_to_async(shutil.copytree)
move = sync_to_async(shutil.move)
Error = shutil.Error
SpecialFileError = shutil.SpecialFileError
ExecError = shutil.ExecError
make_archive = sync_to_async(shutil.make_archive)
get_archive_formats = sync_to_async(shutil.get_archive_formats)
register_archive_format = sync_to_async(shutil.register_archive_format)
register_archive_format = sync_to_async(shutil.register_archive_format) # type: ignore # noqa: F811
unregister_archive_format = sync_to_async(shutil.unregister_archive_format)
get_unpack_formats = sync_to_async(shutil.get_unpack_formats)
register_unpack_format = sync_to_async(shutil.register_unpack_format)
register_unpack_format = sync_to_async(shutil.register_unpack_format) # type: ignore # noqa: F811
unregister_unpack_format = sync_to_async(shutil.unregister_unpack_format)
unpack_archive = sync_to_async(shutil.unpack_archive)
ignore_patterns = sync_to_async(shutil.ignore_patterns)
chown = sync_to_async(shutil.chown)
which = sync_to_async(shutil.which)
chown = sync_to_async(shutil.chown) # type: ignore # noqa: F811
which = sync_to_async(shutil.which) # type: ignore # noqa: F811
get_terminal_size = sync_to_async(shutil.get_terminal_size)
SameFileError = shutil.SameFileError

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pre-commit
pytest
pytest-asyncio
pytest-cov
pytest-mypy-plugins
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@
"License :: OSI Approved :: BSD License",
],
setup_requires=["setuptools_scm"],
install_requires=["typing-extensions;python_version<'3.10'"],
)
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@

def pytest_collection_modifyitems(items):
for item in items:
item.add_marker(pytest.mark.asyncio)
# Only Python tests, not the typing tests
if isinstance(item, pytest.Function):
item.add_marker(pytest.mark.asyncio)
16 changes: 16 additions & 0 deletions tests/test_typehints.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
- case: decorator_produces_coroutine
regex: yes
main: |
from aioshutil import rmtree
reveal_type(rmtree)
out: |
main:2: note: Revealed type is "def \(.*?\) -> typing\.Coroutine\[Any, Any, \w+\]"
skip: sys.version_info < (3, 10)
- case: copy_overload_typehint
regex: yes
main: |
from aioshutil import copy
reveal_type(copy)
out: |
main:2: note: Revealed type is "Overload\(def \(.*\) -> typing\.Coroutine\[Any, Any, \w+\], def \(.*\) -> typing\.Coroutine\[Any, Any, \w+\]\)"
skip: sys.version_info < (3, 10)