Skip to content

Commit 51fbd62

Browse files
committed
Support non-generic decorators in import cycles
Infer more decorated function signatures during semantic analysis. Also add extra test cases for import cycles in general. Partially addresses #1303.
1 parent b9bc14a commit 51fbd62

File tree

2 files changed

+177
-7
lines changed

2 files changed

+177
-7
lines changed

mypy/semanal.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from mypy.lex import lex
7777
from mypy.parsetype import parse_type
7878
from mypy.sametypes import is_same_type
79+
from mypy.erasetype import erase_typevars
7980
from mypy import defaults
8081

8182

@@ -2426,8 +2427,12 @@ def visit_class_def(self, tdef: ClassDef) -> None:
24262427
def visit_decorator(self, dec: Decorator) -> None:
24272428
"""Try to infer the type of the decorated function.
24282429
2429-
This helps us resolve forward references to decorated
2430-
functions during type checking.
2430+
This lets us resolve references to decorated functions during
2431+
type checking when there are cyclic imports, as otherwise the
2432+
type might not be available when we need it.
2433+
2434+
This basically uses a simple special-purpose type inference
2435+
engine just for decorators.
24312436
"""
24322437
super().visit_decorator(dec)
24332438
if dec.var.is_property:
@@ -2453,13 +2458,19 @@ def visit_decorator(self, dec: Decorator) -> None:
24532458
decorator_preserves_type = False
24542459
break
24552460
if decorator_preserves_type:
2456-
# No non-special decorators left. We can trivially infer the type
2461+
# No non-identity decorators left. We can trivially infer the type
24572462
# of the function here.
24582463
dec.var.type = function_type(dec.func, self.builtin_type('function'))
2459-
if dec.decorators and returns_any_if_called(dec.decorators[0]):
2460-
# The outermost decorator will return Any so we know the type of the
2461-
# decorated function.
2462-
dec.var.type = AnyType()
2464+
if dec.decorators:
2465+
if returns_any_if_called(dec.decorators[0]):
2466+
# The outermost decorator will return Any so we know the type of the
2467+
# decorated function.
2468+
dec.var.type = AnyType()
2469+
sig = find_fixed_callable_return(dec.decorators[0])
2470+
if sig:
2471+
# The outermost decorator always returns the same kind of function,
2472+
# so we know that this is the type of the decoratored function.
2473+
dec.var.type = sig
24632474

24642475
def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
24652476
self.analyze(s.type)
@@ -2673,3 +2684,22 @@ def returns_any_if_called(expr: Node) -> bool:
26732684
elif isinstance(expr, CallExpr):
26742685
return returns_any_if_called(expr.callee)
26752686
return False
2687+
2688+
2689+
def find_fixed_callable_return(expr: Node) -> Optional[CallableType]:
2690+
if isinstance(expr, RefExpr):
2691+
if isinstance(expr.node, FuncDef):
2692+
typ = expr.node.type
2693+
if typ:
2694+
if isinstance(typ, CallableType) and has_no_typevars(typ.ret_type):
2695+
return typ.ret_type
2696+
elif isinstance(expr, CallExpr):
2697+
t = find_fixed_callable_return(expr.callee)
2698+
if t:
2699+
if isinstance(t.ret_type, CallableType):
2700+
return t.ret_type
2701+
return None
2702+
2703+
2704+
def has_no_typevars(typ: Type) -> bool:
2705+
return is_same_type(typ, erase_typevars(typ))

mypy/test/data/check-functions.test

+140
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,146 @@ def g(): pass
863863
def dec(f): pass
864864

865865

866+
-- Decorator functions in import cycles
867+
-- ------------------------------------
868+
869+
870+
[case testDecoratorWithIdentityTypeInImportCycle]
871+
import a
872+
873+
[file a.py]
874+
import b
875+
from d import dec
876+
@dec
877+
def f(x: int) -> None: pass
878+
b.g(1) # E
879+
880+
[file b.py]
881+
import a
882+
from d import dec
883+
@dec
884+
def g(x: str) -> None: pass
885+
a.f('')
886+
887+
[file d.py]
888+
from typing import TypeVar
889+
T = TypeVar('T')
890+
def dec(f: T) -> T: return f
891+
892+
[out]
893+
tmp/a.py:1: note: In module imported here,
894+
main:1: note: ... from here:
895+
tmp/b.py:5: error: Argument 1 to "f" has incompatible type "str"; expected "int"
896+
main:1: note: In module imported here:
897+
tmp/a.py:5: error: Argument 1 to "g" has incompatible type "int"; expected "str"
898+
899+
[case testDecoratorWithNoAnnotationInImportCycle]
900+
import a
901+
902+
[file a.py]
903+
import b
904+
from d import dec
905+
@dec
906+
def f(x: int) -> None: pass
907+
b.g(1, z=4)
908+
909+
[file b.py]
910+
import a
911+
from d import dec
912+
@dec
913+
def g(x: str) -> None: pass
914+
a.f('', y=2)
915+
916+
[file d.py]
917+
def dec(f): return f
918+
919+
[case testDecoratorWithFixedReturnTypeInImportCycle]
920+
import a
921+
922+
[file a.py]
923+
import b
924+
from d import dec
925+
@dec
926+
def f(x: int) -> str: pass
927+
b.g(1)()
928+
929+
[file b.py]
930+
import a
931+
from d import dec
932+
@dec
933+
def g(x: int) -> str: pass
934+
a.f(1)()
935+
936+
[file d.py]
937+
from typing import Callable
938+
def dec(f: Callable[[int], str]) -> Callable[[int], str]: return f
939+
940+
[out]
941+
tmp/a.py:1: note: In module imported here,
942+
main:1: note: ... from here:
943+
tmp/b.py:5: error: "str" not callable
944+
main:1: note: In module imported here:
945+
tmp/a.py:5: error: "str" not callable
946+
947+
[case testDecoratorWithCallAndFixedReturnTypeInImportCycle]
948+
import a
949+
950+
[file a.py]
951+
import b
952+
from d import dec
953+
@dec()
954+
def f(x: int) -> str: pass
955+
b.g(1)()
956+
957+
[file b.py]
958+
import a
959+
from d import dec
960+
@dec()
961+
def g(x: int) -> str: pass
962+
a.f(1)()
963+
964+
[file d.py]
965+
from typing import Callable
966+
def dec() -> Callable[[Callable[[int], str]], Callable[[int], str]]: pass
967+
968+
[out]
969+
tmp/a.py:1: note: In module imported here,
970+
main:1: note: ... from here:
971+
tmp/b.py:5: error: "str" not callable
972+
main:1: note: In module imported here:
973+
tmp/a.py:5: error: "str" not callable
974+
975+
[case testDecoratorWithCallAndFixedReturnTypeInImportCycleAndDecoratorArgs]
976+
import a
977+
978+
[file a.py]
979+
import b
980+
from d import dec
981+
@dec(1)
982+
def f(x: int) -> str: pass
983+
b.g(1)()
984+
985+
[file b.py]
986+
import a
987+
from d import dec
988+
@dec(1)
989+
def g(x: int) -> str: pass
990+
a.f(1)()
991+
992+
[file d.py]
993+
from typing import Callable
994+
def dec(x: str) -> Callable[[Callable[[int], str]], Callable[[int], str]]: pass
995+
996+
[out]
997+
tmp/a.py:1: note: In module imported here,
998+
main:1: note: ... from here:
999+
tmp/b.py:3: error: Argument 1 to "dec" has incompatible type "int"; expected "str"
1000+
tmp/b.py:5: error: "str" not callable
1001+
main:1: note: In module imported here:
1002+
tmp/a.py:3: error: Argument 1 to "dec" has incompatible type "int"; expected "str"
1003+
tmp/a.py:5: error: "str" not callable
1004+
1005+
8661006
-- Conditional function definition
8671007
-- -------------------------------
8681008

0 commit comments

Comments
 (0)