Skip to content

Commit 8f77254

Browse files
More type-checking fixes
1 parent 54d8cb9 commit 8f77254

File tree

8 files changed

+251
-106
lines changed

8 files changed

+251
-106
lines changed

tanjun/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ async def hello(ctx: tanjun.abc.Context, user: hikari.User | None) -> None:
115115
"InteractionAcceptsEnum",
116116
"LazyConstant",
117117
"MenuCommand",
118+
"MenuHooks",
118119
"MessageAcceptsEnum",
119120
"MessageCommand",
120121
"MessageCommandGroup",
@@ -153,7 +154,6 @@ async def hello(ctx: tanjun.abc.Context, user: hikari.User | None) -> None:
153154
"dependencies",
154155
"errors",
155156
"hooks",
156-
"MenuHooks",
157157
"inject",
158158
"inject_lc",
159159
"injected",
@@ -297,9 +297,9 @@ async def hello(ctx: tanjun.abc.Context, user: hikari.User | None) -> None:
297297
from .errors import TooManyArgumentsError
298298
from .hooks import AnyHooks
299299
from .hooks import Hooks
300+
from .hooks import MenuHooks
300301
from .hooks import MessageHooks
301302
from .hooks import SlashHooks
302-
from .hooks import MenuHooks
303303
from .injecting import as_self_injecting
304304
from .parsing import ShlexParser
305305
from .parsing import with_argument

tanjun/_internal/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
_P = typing_extensions.ParamSpec("_P")
5959

60-
_ContextT = typing.TypeVar("_ContextT", bound=abc.Context)
60+
_ContextT = typing.TypeVar("_ContextT", bound=tanjun.Context)
6161
_TreeT = dict[
6262
typing.Union[str, "_IndexKeys"],
6363
typing.Union["_TreeT", list[tuple[list[str], tanjun.MessageCommand[typing.Any]]]],

tanjun/abc.py

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,17 @@
8484

8585
_AutocompleteValueT = typing.TypeVar("_AutocompleteValueT", int, str, float)
8686
_BaseSlashCommandT = typing.TypeVar("_BaseSlashCommandT", bound="BaseSlashCommand")
87-
_CommandCallbackSigT = typing.TypeVar("_CommandCallbackSigT", bound="CommandCallbackSig")
88-
_ErrorHookSigT = typing.TypeVar("_ErrorHookSigT", bound="ErrorHookSig")
89-
_HookSigT = typing.TypeVar("_HookSigT", bound="HookSig")
87+
88+
_AnyErrorHookSigT = typing.TypeVar("_AnyErrorHookSigT", bound="ErrorHookSig[typing.Any]")
89+
_MenuErrorHookSigT = typing.TypeVar("_MenuErrorHookSigT", bound="ErrorHookSig[MenuContext]")
90+
_MessageErrorHookSigT = typing.TypeVar("_MessageErrorHookSigT", bound="ErrorHookSig[MessageContext]")
91+
_SlashErrorHookSigT = typing.TypeVar("_SlashErrorHookSigT", bound="ErrorHookSig[SlashContext]")
92+
93+
_AnyHookSigT = typing.TypeVar("_AnyHookSigT", bound="HookSig[typing.Any]")
94+
_MenuHookSigT = typing.TypeVar("_MenuHookSigT", bound="HookSig[MenuContext]")
95+
_MessageHookSigT = typing.TypeVar("_MessageHookSigT", bound="HookSig[MessageContext]")
96+
_SlashHookSigT = typing.TypeVar("_SlashHookSigT", bound="HookSig[SlashContext]")
97+
9098
_ListenerCallbackSigT = typing.TypeVar("_ListenerCallbackSigT", bound="ListenerCallbackSig")
9199
_MenuCommandT = typing.TypeVar("_MenuCommandT", bound="MenuCommand[typing.Any, typing.Any]")
92100
_MessageCommandT = typing.TypeVar("_MessageCommandT", bound="MessageCommand[typing.Any]")
@@ -133,12 +141,14 @@
133141
SlashCheckSig = _CheckSig["SlashContext", ...]
134142

135143

136-
_CommandCallbackSig = collections.Callable[typing_extensions.Concatenate[_ContextT_contra, _P], collections.Coroutine[typing.Any, typing.Any, None]]
144+
_CommandCallbackSig = collections.Callable[
145+
typing_extensions.Concatenate[_ContextT_contra, _P], collections.Coroutine[typing.Any, typing.Any, None]
146+
]
137147

138-
_MenuValueT = typing.TypeVar("_MenuValueT", hikari.User, hikari.InteractionMember)
148+
_MenuValueT = typing.TypeVar("_MenuValueT", hikari.Message, hikari.InteractionMember)
139149
_ManuCallbackSig = collections.Callable[
140150
typing_extensions.Concatenate[_ContextT_contra, _MenuValueT, _P],
141-
collections.Coroutine[typing.Any, typing.Any, None]
151+
collections.Coroutine[typing.Any, typing.Any, None],
142152
]
143153
MenuCallbackSig = _ManuCallbackSig["MenuContext", _MenuValueT, ...]
144154
"""Type hint of a context menu command callback.
@@ -2155,8 +2165,23 @@ def add_on_error(self, callback: ErrorHookSig[_ContextT_contra], /) -> Self:
21552165
The hook object to enable method chaining.
21562166
"""
21572167

2168+
@typing.overload
2169+
@abc.abstractmethod
2170+
def with_on_error(self: MenuHooks, callback: _MenuErrorHookSigT, /) -> _MenuErrorHookSigT:
2171+
...
2172+
2173+
@typing.overload
2174+
@abc.abstractmethod
2175+
def with_on_error(self: MessageHooks, callback: _MessageErrorHookSigT, /) -> _MessageErrorHookSigT:
2176+
...
2177+
2178+
@typing.overload
2179+
@abc.abstractmethod
2180+
def with_on_error(self: SlashHooks, callback: _SlashErrorHookSigT, /) -> _SlashErrorHookSigT:
2181+
...
2182+
21582183
@abc.abstractmethod
2159-
def with_on_error(self, callback: _ErrorHookSigT, /) -> _ErrorHookSigT:
2184+
def with_on_error(self, callback: _AnyErrorHookSigT, /) -> _AnyErrorHookSigT:
21602185
"""Add an error callback to this hook object through a decorator call.
21612186
21622187
!!! note
@@ -2219,8 +2244,23 @@ def add_on_parser_error(self, callback: HookSig[_ContextT_contra], /) -> Self:
22192244
The hook object to enable method chaining.
22202245
"""
22212246

2247+
@typing.overload
2248+
@abc.abstractmethod
2249+
def with_on_parser_error(self: MenuHooks, callback: _MenuHookSigT, /) -> _MenuHookSigT:
2250+
...
2251+
2252+
@typing.overload
2253+
@abc.abstractmethod
2254+
def with_on_parser_error(self: MessageHooks, callback: _MessageHookSigT, /) -> _MessageHookSigT:
2255+
...
2256+
2257+
@typing.overload
2258+
@abc.abstractmethod
2259+
def with_on_parser_error(self: SlashHooks, callback: _SlashHookSigT, /) -> _SlashHookSigT:
2260+
...
2261+
22222262
@abc.abstractmethod
2223-
def with_on_parser_error(self, callback: _HookSigT, /) -> _HookSigT:
2263+
def with_on_parser_error(self, callback: _AnyHookSigT, /) -> _AnyHookSigT:
22242264
"""Add a parser error callback to this hook object through a decorator call.
22252265
22262266
Examples
@@ -2267,8 +2307,23 @@ def add_post_execution(self, callback: HookSig[_ContextT_contra], /) -> Self:
22672307
The hook object to enable method chaining.
22682308
"""
22692309

2310+
@typing.overload
2311+
@abc.abstractmethod
2312+
def with_post_execution(self: MenuHooks, callback: _MenuHookSigT, /) -> _MenuHookSigT:
2313+
...
2314+
2315+
@typing.overload
2316+
@abc.abstractmethod
2317+
def with_post_execution(self: MessageHooks, callback: _MessageHookSigT, /) -> _MessageHookSigT:
2318+
...
2319+
2320+
@typing.overload
2321+
@abc.abstractmethod
2322+
def with_post_execution(self: SlashHooks, callback: _SlashHookSigT, /) -> _SlashHookSigT:
2323+
...
2324+
22702325
@abc.abstractmethod
2271-
def with_post_execution(self, callback: _HookSigT, /) -> _HookSigT:
2326+
def with_post_execution(self, callback: _AnyHookSigT, /) -> _AnyHookSigT:
22722327
"""Add a post-execution callback to this hook object through a decorator call.
22732328
22742329
Examples
@@ -2315,8 +2370,23 @@ def add_pre_execution(self, callback: HookSig[_ContextT_contra], /) -> Self:
23152370
The hook object to enable method chaining.
23162371
"""
23172372

2373+
@typing.overload
2374+
@abc.abstractmethod
2375+
def with_pre_execution(self: MenuHooks, callback: _MenuHookSigT, /) -> _MenuHookSigT:
2376+
...
2377+
2378+
@typing.overload
2379+
@abc.abstractmethod
2380+
def with_pre_execution(self: MessageHooks, callback: _MessageHookSigT, /) -> _MessageHookSigT:
2381+
...
2382+
2383+
@typing.overload
2384+
@abc.abstractmethod
2385+
def with_pre_execution(self: SlashHooks, callback: _SlashHookSigT, /) -> _SlashHookSigT:
2386+
...
2387+
23182388
@abc.abstractmethod
2319-
def with_pre_execution(self, callback: _HookSigT, /) -> _HookSigT:
2389+
def with_pre_execution(self, callback: _AnyHookSigT, /) -> _AnyHookSigT:
23202390
"""Add a pre-execution callback to this hook object through a decorator call.
23212391
23222392
Examples
@@ -2363,8 +2433,23 @@ def add_on_success(self, callback: HookSig[_ContextT_contra], /) -> Self:
23632433
The hook object to enable method chaining.
23642434
"""
23652435

2436+
@typing.overload
2437+
@abc.abstractmethod
2438+
def with_on_success(self: MenuHooks, callback: _MenuHookSigT, /) -> _MenuHookSigT:
2439+
...
2440+
2441+
@typing.overload
2442+
@abc.abstractmethod
2443+
def with_on_success(self: MessageHooks, callback: _MessageHookSigT, /) -> _MessageHookSigT:
2444+
...
2445+
2446+
@typing.overload
2447+
@abc.abstractmethod
2448+
def with_on_success(self: SlashHooks, callback: _SlashHookSigT, /) -> _SlashHookSigT:
2449+
...
2450+
23662451
@abc.abstractmethod
2367-
def with_on_success(self, callback: _HookSigT, /) -> _HookSigT:
2452+
def with_on_success(self, callback: _AnyHookSigT, /) -> _AnyHookSigT:
23682453
"""Add a success callback to this hook object through a decorator call.
23692454
23702455
Examples

tanjun/commands/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class PartialCommand(tanjun.ExecutableCommand[_ContextT], components.AbstractCom
5555
__slots__ = ("_checks", "_component", "_hooks", "_metadata")
5656

5757
def __init__(self) -> None:
58-
self._checks: list[tanjun.AnyCheckSig] = []
58+
self._checks: list[tanjun.CheckSig[_ContextT]] = []
5959
self._component: typing.Optional[tanjun.Component] = None
6060
self._hooks: typing.Optional[tanjun.Hooks[_ContextT]] = None
6161
self._metadata: dict[typing.Any, typing.Any] = {}
@@ -98,15 +98,15 @@ def set_metadata(self, key: typing.Any, value: typing.Any, /) -> Self:
9898
self._metadata[key] = value
9999
return self
100100

101-
def add_check(self, *checks: tanjun.AnyCheckSig) -> Self:
101+
def add_check(self, *checks: tanjun.CheckSig[_ContextT]) -> Self:
102102
# <<inherited docstring from tanjun.abc.ExecutableCommand>>.
103103
for check in checks:
104104
if check not in self._checks:
105105
self._checks.append(check)
106106

107107
return self
108108

109-
def remove_check(self, check: tanjun.AnyCheckSig, /) -> Self:
109+
def remove_check(self, check: tanjun.CheckSig[_ContextT], /) -> Self:
110110
# <<inherited docstring from tanjun.abc.ExecutableCommand>>.
111111
self._checks.remove(check)
112112
return self

0 commit comments

Comments
 (0)