Skip to content

Commit b70f4a0

Browse files
authored
Fix type annotations for wrapped functions (#5)
Fix type annotations for wrapped functions: - Use a `ParamSpec` to capture the wrapped function arguments and a separate `TypeVar` for the return type, so that the annotation can accurately wrap just the return value in a `Coroutine` annotation. - Add explicit type hints for functions that use @overload (which are not compatible with `ParamSpec`). - Add typing-extensions as a dependency (used for `ParamSpec` on older Python versions) - Add a few simple tests to verify that type hints now are producing correct output. Closes #4
1 parent dad22ff commit b70f4a0

File tree

5 files changed

+191
-12
lines changed

5 files changed

+191
-12
lines changed

aioshutil/__init__.py

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,28 @@
22
"""
33
Asynchronous shutil module.
44
"""
5+
from __future__ import annotations
6+
57
import asyncio
68
import shutil
79
from functools import partial, wraps
8-
from typing import Any, Awaitable, Callable, TypeVar, cast
10+
from typing import (
11+
TYPE_CHECKING,
12+
Any,
13+
Callable,
14+
Coroutine,
15+
Optional,
16+
Sequence,
17+
TypeVar,
18+
Union,
19+
overload,
20+
)
21+
22+
try:
23+
from typing import ParamSpec, TypeAlias # type: ignore
24+
except ImportError:
25+
# Python versions < 3.10
26+
from typing_extensions import ParamSpec, TypeAlias
927

1028
__all__ = [
1129
"copyfileobj",
@@ -35,42 +53,183 @@
3553
"SameFileError",
3654
]
3755

38-
T = TypeVar("T", bound=Callable[..., Any])
56+
P = ParamSpec("P")
57+
R = TypeVar("R")
58+
59+
if TYPE_CHECKING: # pragma: no cover
60+
# type hints for wrapped functions with overloads (which are incompatible
61+
# with ParamSpec).
62+
63+
import sys
64+
from os import PathLike
65+
66+
StrPath: TypeAlias = Union[str, PathLike[str]]
67+
BytesPath: TypeAlias = Union[bytes, PathLike[bytes]]
68+
StrOrBytesPath: TypeAlias = Union[str, bytes, PathLike[str], PathLike[bytes]]
69+
_PathReturn: TypeAlias = Any
70+
_StrPathT = TypeVar("_StrPathT", bound=StrPath)
71+
72+
@overload
73+
async def copy(
74+
src: StrPath, dst: StrPath, *, follow_symlinks: bool = ...
75+
) -> _PathReturn:
76+
...
77+
78+
@overload
79+
async def copy(
80+
src: BytesPath, dst: BytesPath, *, follow_symlinks: bool = ...
81+
) -> _PathReturn:
82+
...
83+
84+
async def copy(src, dst, *, follow_symlinks=...):
85+
...
86+
87+
@overload
88+
async def copy2(
89+
src: StrPath, dst: StrPath, *, follow_symlinks: bool = ...
90+
) -> _PathReturn:
91+
...
92+
93+
@overload
94+
async def copy2(
95+
src: BytesPath, dst: BytesPath, *, follow_symlinks: bool = ...
96+
) -> _PathReturn:
97+
...
98+
99+
async def copy2(src, dst, *, follow_symlinks=...):
100+
...
101+
102+
@overload
103+
async def register_archive_format(
104+
name: str,
105+
function: Callable[..., object],
106+
extra_args: Sequence[tuple[str, Any] | list[Any]],
107+
description: str = ...,
108+
) -> None:
109+
...
110+
111+
@overload
112+
async def register_archive_format(
113+
name: str,
114+
function: Callable[[str, str], object],
115+
extra_args: None = ...,
116+
description: str = ...,
117+
) -> None:
118+
...
119+
120+
async def register_archive_format(name, function, extra_args=..., description=...):
121+
...
122+
123+
@overload
124+
async def register_unpack_format(
125+
name: str,
126+
extensions: list[str],
127+
function: Callable[..., object],
128+
extra_args: Sequence[tuple[str, Any]],
129+
description: str = ...,
130+
) -> None:
131+
...
132+
133+
@overload
134+
async def register_unpack_format(
135+
name: str,
136+
extensions: list[str],
137+
function: Callable[[str, str], object],
138+
extra_args: None = ...,
139+
description: str = ...,
140+
) -> None:
141+
...
142+
143+
async def register_unpack_format(
144+
name, extensions, function, extra_args=..., description=...
145+
):
146+
...
147+
148+
@overload
149+
async def chown(
150+
path: StrOrBytesPath, user: Union[str, int], group: None = ...
151+
) -> None:
152+
...
153+
154+
@overload
155+
async def chown(
156+
path: StrOrBytesPath, user: None = ..., *, group: Union[str, int]
157+
) -> None:
158+
...
159+
160+
@overload
161+
async def chown(path: StrOrBytesPath, user: None, group: Union[str, int]) -> None:
162+
...
163+
164+
@overload
165+
async def chown(
166+
path: StrOrBytesPath, user: Union[str, int], group: Union[str, int]
167+
) -> None:
168+
...
169+
170+
async def chown(path, user=..., group=...):
171+
...
172+
173+
if sys.version_info >= (3, 8):
174+
175+
@overload
176+
async def which(
177+
cmd: _StrPathT, mode: int = ..., path: Optional[StrPath] = ...
178+
) -> Union[str, _StrPathT, None]:
179+
...
180+
181+
@overload
182+
async def which(
183+
cmd: bytes, mode: int = ..., path: Optional[StrPath] = ...
184+
) -> Optional[bytes]:
185+
...
186+
187+
async def which(
188+
cmd, mode=..., path=...
189+
) -> Union[bytes, str, StrPath, PathLike[str], None]:
190+
...
191+
192+
else:
193+
194+
async def which(
195+
cmd: _StrPathT, mode: int = ..., path: StrPath | None = ...
196+
) -> str | _StrPathT | None:
197+
...
39198

40199

41-
def sync_to_async(func: T):
200+
def sync_to_async(func: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]:
42201
@wraps(func)
43-
async def run_in_executor(*args, **kwargs):
202+
async def run_in_executor(*args: P.args, **kwargs: P.kwargs) -> R:
44203
loop = asyncio.get_event_loop()
45204
pfunc = partial(func, *args, **kwargs)
46205
return await loop.run_in_executor(None, pfunc)
47206

48-
return cast(Awaitable[T], run_in_executor)
207+
return run_in_executor
49208

50209

51210
rmtree = sync_to_async(shutil.rmtree)
52211
copyfile = sync_to_async(shutil.copyfile)
53212
copyfileobj = sync_to_async(shutil.copyfileobj)
54213
copymode = sync_to_async(shutil.copymode)
55214
copystat = sync_to_async(shutil.copystat)
56-
copy = sync_to_async(shutil.copy)
57-
copy2 = sync_to_async(shutil.copy2)
215+
copy = sync_to_async(shutil.copy) # type: ignore # noqa: F811
216+
copy2 = sync_to_async(shutil.copy2) # type: ignore # noqa: F811
58217
copytree = sync_to_async(shutil.copytree)
59218
move = sync_to_async(shutil.move)
60219
Error = shutil.Error
61220
SpecialFileError = shutil.SpecialFileError
62221
ExecError = shutil.ExecError
63222
make_archive = sync_to_async(shutil.make_archive)
64223
get_archive_formats = sync_to_async(shutil.get_archive_formats)
65-
register_archive_format = sync_to_async(shutil.register_archive_format)
224+
register_archive_format = sync_to_async(shutil.register_archive_format) # type: ignore # noqa: F811
66225
unregister_archive_format = sync_to_async(shutil.unregister_archive_format)
67226
get_unpack_formats = sync_to_async(shutil.get_unpack_formats)
68-
register_unpack_format = sync_to_async(shutil.register_unpack_format)
227+
register_unpack_format = sync_to_async(shutil.register_unpack_format) # type: ignore # noqa: F811
69228
unregister_unpack_format = sync_to_async(shutil.unregister_unpack_format)
70229
unpack_archive = sync_to_async(shutil.unpack_archive)
71230
ignore_patterns = sync_to_async(shutil.ignore_patterns)
72-
chown = sync_to_async(shutil.chown)
73-
which = sync_to_async(shutil.which)
231+
chown = sync_to_async(shutil.chown) # type: ignore # noqa: F811
232+
which = sync_to_async(shutil.which) # type: ignore # noqa: F811
74233
get_terminal_size = sync_to_async(shutil.get_terminal_size)
75234
SameFileError = shutil.SameFileError
76235

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pre-commit
22
pytest
33
pytest-asyncio
44
pytest-cov
5+
pytest-mypy-plugins

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@
3434
"License :: OSI Approved :: BSD License",
3535
],
3636
setup_requires=["setuptools_scm"],
37+
install_requires=["typing-extensions;python_version<'3.10'"],
3738
)

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
def pytest_collection_modifyitems(items):
66
for item in items:
7-
item.add_marker(pytest.mark.asyncio)
7+
# Only Python tests, not the typing tests
8+
if isinstance(item, pytest.Function):
9+
item.add_marker(pytest.mark.asyncio)

tests/test_typehints.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
- case: decorator_produces_coroutine
2+
regex: yes
3+
main: |
4+
from aioshutil import rmtree
5+
reveal_type(rmtree)
6+
out: |
7+
main:2: note: Revealed type is "def \(.*?\) -> typing\.Coroutine\[Any, Any, \w+\]"
8+
skip: sys.version_info < (3, 10)
9+
- case: copy_overload_typehint
10+
regex: yes
11+
main: |
12+
from aioshutil import copy
13+
reveal_type(copy)
14+
out: |
15+
main:2: note: Revealed type is "Overload\(def \(.*\) -> typing\.Coroutine\[Any, Any, \w+\], def \(.*\) -> typing\.Coroutine\[Any, Any, \w+\]\)"
16+
skip: sys.version_info < (3, 10)

0 commit comments

Comments
 (0)