Skip to content

Commit ce50550

Browse files
committed
Use ParamSpec to specify api function types
1 parent bc83517 commit ce50550

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

.mypy.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[mypy]
22
strict_optional = true
3-
plugins = duet.typing
43
show_error_codes = true
54
warn_unused_ignores = true
65

duet/api.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,21 @@
3232
TypeVar,
3333
)
3434

35+
try:
36+
from typing import ParamSpec
37+
except ImportError:
38+
from typing_extensions import ParamSpec # type: ignore[assignment]
39+
3540
import duet.impl as impl
3641
from duet.aitertools import aenumerate, aiter, AnyIterable, AsyncCollector
3742
from duet.futuretools import AwaitableFuture
3843

44+
P = ParamSpec("P")
3945
T = TypeVar("T")
4046
U = TypeVar("U")
4147

4248

43-
def run(func: Callable[..., Awaitable[T]], *args, **kwds) -> T:
49+
def run(func: Callable[P, Awaitable[T]], *args: P.args, **kwds: P.kwargs) -> T:
4450
"""Run an async function to completion.
4551
4652
Args:
@@ -72,7 +78,7 @@ def run(func: Callable[..., Awaitable[T]], *args, **kwds) -> T:
7278
scheduler.cleanup_signals()
7379

7480

75-
def sync(f: Callable[..., Awaitable[T]]) -> Callable[..., T]:
81+
def sync(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
7682
"""Decorator that adds a sync version of async function or method."""
7783
if isinstance(f, classmethod):
7884
raise TypeError(f"duet.sync cannot be applied to classmethod {f.__func__}")
@@ -113,7 +119,7 @@ def wrapped(self, *args, **kw):
113119
def wrapped(*args, **kw):
114120
return run(f, *args, **kw)
115121

116-
return wrapped
122+
return wrapped # type: ignore[return-value]
117123

118124

119125
def awaitable(value):
@@ -375,12 +381,14 @@ def __init__(
375381
def cancel(self) -> None:
376382
self._main_task.interrupt(self._main_task, CancelledError())
377383

378-
def spawn(self, func: Callable[..., Awaitable[Any]], *args, **kwds) -> None:
384+
def spawn(self, func: Callable[P, Awaitable[Any]], *args: P.args, **kwds: P.kwargs) -> None:
379385
"""Starts a background task that will run the given function."""
380386
task = self._scheduler.spawn(self._run(func, *args, **kwds), main_task=self._main_task)
381387
self._tasks.add(task)
382388

383-
async def _run(self, func: Callable[..., Awaitable[Any]], *args, **kwds) -> None:
389+
async def _run(
390+
self, func: Callable[P, Awaitable[Any]], *args: P.args, **kwds: P.kwargs
391+
) -> None:
384392
task = impl.current_task()
385393
try:
386394
await func(*args, **kwds)
@@ -513,7 +521,7 @@ def scope(self) -> Scope:
513521
def limiter(self) -> Limiter:
514522
pass
515523

516-
def spawn(self, func: Callable[..., Awaitable[Any]], *args, **kwds) -> None:
524+
def spawn(self, func: Callable[P, Awaitable[Any]], *args: P.args, **kwds: P.kwargs) -> None:
517525
"""Starts a background task that will run the given function."""
518526
self.scope.spawn(func, *args, **kwds)
519527

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
typing-extensions >= 3.10.0; python_version <= '3.7'
1+
typing-extensions >= 4.0.0; python_version < '3.10'

0 commit comments

Comments
 (0)