Skip to content

Commit 3c5f368

Browse files
authored
Allow function arguments as base classes (#14135)
Fixes #5865 Looks quite easy and safe, unless I am missing something. Most changes in the diff are just moving stuff around. Previously we only applied argument types before type checking, but it looks like we can totally do this in semantic analyzer. I also enable variable annotated as `type` (or equivalently `Type[Any]`), this use case was mentioned in the comments. This PR also accidentally fixes two additional bugs, one related to type variables with values vs walrus operator, another one for type variables with values vs self types. I include test cases for those as well.
1 parent c660354 commit 3c5f368

10 files changed

+108
-32
lines changed

mypy/checker.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@
184184
LiteralType,
185185
NoneType,
186186
Overloaded,
187-
ParamSpecType,
188187
PartialType,
189188
ProperType,
190189
StarType,
@@ -203,14 +202,14 @@
203202
UnboundType,
204203
UninhabitedType,
205204
UnionType,
206-
UnpackType,
207205
flatten_nested_unions,
208206
get_proper_type,
209207
get_proper_types,
210208
is_literal_type,
211209
is_named_instance,
212210
is_optional,
213211
remove_optional,
212+
store_argument_type,
214213
strip_type,
215214
)
216215
from mypy.typetraverser import TypeTraverserVisitor
@@ -1174,30 +1173,8 @@ def check_func_def(
11741173
if ctx.line < 0:
11751174
ctx = typ
11761175
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
1177-
if typ.arg_kinds[i] == nodes.ARG_STAR:
1178-
if isinstance(arg_type, ParamSpecType):
1179-
pass
1180-
elif isinstance(arg_type, UnpackType):
1181-
if isinstance(get_proper_type(arg_type.type), TupleType):
1182-
# Instead of using Tuple[Unpack[Tuple[...]]], just use
1183-
# Tuple[...]
1184-
arg_type = arg_type.type
1185-
else:
1186-
arg_type = TupleType(
1187-
[arg_type],
1188-
fallback=self.named_generic_type(
1189-
"builtins.tuple", [self.named_type("builtins.object")]
1190-
),
1191-
)
1192-
else:
1193-
# builtins.tuple[T] is typing.Tuple[T, ...]
1194-
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
1195-
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
1196-
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
1197-
arg_type = self.named_generic_type(
1198-
"builtins.dict", [self.str_type(), arg_type]
1199-
)
1200-
item.arguments[i].variable.type = arg_type
1176+
# Need to store arguments again for the expanded item.
1177+
store_argument_type(item, i, typ, self.named_generic_type)
12011178

12021179
# Type check initialization expressions.
12031180
body_is_trivial = is_trivial_body(defn.body)

mypy/semanal.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@
273273
get_proper_types,
274274
invalid_recursive_alias,
275275
is_named_instance,
276+
store_argument_type,
276277
)
277278
from mypy.typevars import fill_typevars
278279
from mypy.util import (
@@ -1315,7 +1316,10 @@ def analyze_function_body(self, defn: FuncItem) -> None:
13151316
# Bind the type variables again to visit the body.
13161317
if defn.type:
13171318
a = self.type_analyzer()
1318-
a.bind_function_type_variables(cast(CallableType, defn.type), defn)
1319+
typ = cast(CallableType, defn.type)
1320+
a.bind_function_type_variables(typ, defn)
1321+
for i in range(len(typ.arg_types)):
1322+
store_argument_type(defn, i, typ, self.named_type)
13191323
self.function_stack.append(defn)
13201324
with self.enter(defn):
13211325
for arg in defn.arguments:
@@ -2018,7 +2022,9 @@ def analyze_base_classes(
20182022
continue
20192023

20202024
try:
2021-
base = self.expr_to_analyzed_type(base_expr, allow_placeholder=True)
2025+
base = self.expr_to_analyzed_type(
2026+
base_expr, allow_placeholder=True, allow_type_any=True
2027+
)
20222028
except TypeTranslationError:
20232029
name = self.get_name_repr_of_expr(base_expr)
20242030
if isinstance(base_expr, CallExpr):
@@ -6139,7 +6145,11 @@ def accept(self, node: Node) -> None:
61396145
report_internal_error(err, self.errors.file, node.line, self.errors, self.options)
61406146

61416147
def expr_to_analyzed_type(
6142-
self, expr: Expression, report_invalid_types: bool = True, allow_placeholder: bool = False
6148+
self,
6149+
expr: Expression,
6150+
report_invalid_types: bool = True,
6151+
allow_placeholder: bool = False,
6152+
allow_type_any: bool = False,
61436153
) -> Type | None:
61446154
if isinstance(expr, CallExpr):
61456155
# This is a legacy syntax intended mostly for Python 2, we keep it for
@@ -6164,7 +6174,10 @@ def expr_to_analyzed_type(
61646174
return TupleType(info.tuple_type.items, fallback=fallback)
61656175
typ = self.expr_to_unanalyzed_type(expr)
61666176
return self.anal_type(
6167-
typ, report_invalid_types=report_invalid_types, allow_placeholder=allow_placeholder
6177+
typ,
6178+
report_invalid_types=report_invalid_types,
6179+
allow_placeholder=allow_placeholder,
6180+
allow_type_any=allow_type_any,
61686181
)
61696182

61706183
def analyze_type_expr(self, expr: Expression) -> None:
@@ -6188,6 +6201,7 @@ def type_analyzer(
61886201
allow_param_spec_literals: bool = False,
61896202
report_invalid_types: bool = True,
61906203
prohibit_self_type: str | None = None,
6204+
allow_type_any: bool = False,
61916205
) -> TypeAnalyser:
61926206
if tvar_scope is None:
61936207
tvar_scope = self.tvar_scope
@@ -6204,6 +6218,7 @@ def type_analyzer(
62046218
allow_required=allow_required,
62056219
allow_param_spec_literals=allow_param_spec_literals,
62066220
prohibit_self_type=prohibit_self_type,
6221+
allow_type_any=allow_type_any,
62076222
)
62086223
tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic())
62096224
tpan.global_scope = not self.type and not self.function_stack
@@ -6224,6 +6239,7 @@ def anal_type(
62246239
allow_param_spec_literals: bool = False,
62256240
report_invalid_types: bool = True,
62266241
prohibit_self_type: str | None = None,
6242+
allow_type_any: bool = False,
62276243
third_pass: bool = False,
62286244
) -> Type | None:
62296245
"""Semantically analyze a type.
@@ -6260,6 +6276,7 @@ def anal_type(
62606276
allow_param_spec_literals=allow_param_spec_literals,
62616277
report_invalid_types=report_invalid_types,
62626278
prohibit_self_type=prohibit_self_type,
6279+
allow_type_any=allow_type_any,
62636280
)
62646281
tag = self.track_incomplete_refs()
62656282
typ = typ.accept(a)

mypy/stubtest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def _verify_final(
354354
) -> Iterator[Error]:
355355
try:
356356

357-
class SubClass(runtime): # type: ignore[misc,valid-type]
357+
class SubClass(runtime): # type: ignore[misc]
358358
pass
359359

360360
except TypeError:

mypy/treetransform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def visit_super_expr(self, node: SuperExpr) -> SuperExpr:
550550
return new
551551

552552
def visit_assignment_expr(self, node: AssignmentExpr) -> AssignmentExpr:
553-
return AssignmentExpr(node.target, node.value)
553+
return AssignmentExpr(self.expr(node.target), self.expr(node.value))
554554

555555
def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr:
556556
new = UnaryExpr(node.op, self.expr(node.expr))

mypy/typeanal.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def __init__(
201201
allow_param_spec_literals: bool = False,
202202
report_invalid_types: bool = True,
203203
prohibit_self_type: str | None = None,
204+
allow_type_any: bool = False,
204205
) -> None:
205206
self.api = api
206207
self.lookup_qualified = api.lookup_qualified
@@ -237,6 +238,8 @@ def __init__(
237238
# Names of type aliases encountered while analysing a type will be collected here.
238239
self.aliases_used: set[str] = set()
239240
self.prohibit_self_type = prohibit_self_type
241+
# Allow variables typed as Type[Any] and type (useful for base classes).
242+
self.allow_type_any = allow_type_any
240243

241244
def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
242245
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
@@ -730,6 +733,11 @@ def analyze_unbound_type_without_type_info(
730733
return AnyType(
731734
TypeOfAny.from_unimported_type, missing_import_name=typ.missing_import_name
732735
)
736+
elif self.allow_type_any:
737+
if isinstance(typ, Instance) and typ.type.fullname == "builtins.type":
738+
return AnyType(TypeOfAny.special_form)
739+
if isinstance(typ, TypeType) and isinstance(typ.item, AnyType):
740+
return AnyType(TypeOfAny.from_another_any, source_any=typ.item)
733741
# Option 2:
734742
# Unbound type variable. Currently these may be still valid,
735743
# for example when defining a generic type alias.

mypy/types.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import (
88
TYPE_CHECKING,
99
Any,
10+
Callable,
1011
ClassVar,
1112
Dict,
1213
Iterable,
@@ -29,6 +30,7 @@
2930
ArgKind,
3031
FakeInfo,
3132
FuncDef,
33+
FuncItem,
3234
SymbolNode,
3335
)
3436
from mypy.state import state
@@ -3402,3 +3404,29 @@ def callable_with_ellipsis(any_type: AnyType, ret_type: Type, fallback: Instance
34023404
fallback=fallback,
34033405
is_ellipsis_args=True,
34043406
)
3407+
3408+
3409+
def store_argument_type(
3410+
defn: FuncItem, i: int, typ: CallableType, named_type: Callable[[str, list[Type]], Instance]
3411+
) -> None:
3412+
arg_type = typ.arg_types[i]
3413+
if typ.arg_kinds[i] == ARG_STAR:
3414+
if isinstance(arg_type, ParamSpecType):
3415+
pass
3416+
elif isinstance(arg_type, UnpackType):
3417+
if isinstance(get_proper_type(arg_type.type), TupleType):
3418+
# Instead of using Tuple[Unpack[Tuple[...]]], just use
3419+
# Tuple[...]
3420+
arg_type = arg_type.type
3421+
else:
3422+
arg_type = TupleType(
3423+
[arg_type],
3424+
fallback=named_type("builtins.tuple", [named_type("builtins.object", [])]),
3425+
)
3426+
else:
3427+
# builtins.tuple[T] is typing.Tuple[T, ...]
3428+
arg_type = named_type("builtins.tuple", [arg_type])
3429+
elif typ.arg_kinds[i] == ARG_STAR2:
3430+
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
3431+
arg_type = named_type("builtins.dict", [named_type("builtins.str", []), arg_type])
3432+
defn.arguments[i].variable.type = arg_type

test-data/unit/check-classes.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7664,3 +7664,18 @@ class C(B):
76647664
def foo(self) -> int: # E: Signature of "foo" incompatible with supertype "B"
76657665
...
76667666
[builtins fixtures/property.pyi]
7667+
7668+
[case testAllowArgumentAsBaseClass]
7669+
from typing import Any, Type
7670+
7671+
def e(b) -> None:
7672+
class D(b): ...
7673+
7674+
def f(b: Any) -> None:
7675+
class D(b): ...
7676+
7677+
def g(b: Type[Any]) -> None:
7678+
class D(b): ...
7679+
7680+
def h(b: type) -> None:
7681+
class D(b): ...

test-data/unit/check-python38.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,3 +718,19 @@ def f1() -> None:
718718
y = x
719719
z = x
720720
[builtins fixtures/dict.pyi]
721+
722+
[case testNarrowOnSelfInGeneric]
723+
# flags: --strict-optional
724+
from typing import Generic, TypeVar, Optional
725+
726+
T = TypeVar("T", int, str)
727+
728+
class C(Generic[T]):
729+
x: Optional[T]
730+
def meth(self) -> Optional[T]:
731+
if (y := self.x) is not None:
732+
reveal_type(y)
733+
return None
734+
[out]
735+
main:10: note: Revealed type is "builtins.int"
736+
main:10: note: Revealed type is "builtins.str"

test-data/unit/check-selftype.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,3 +1772,16 @@ class D(C): ...
17721772

17731773
reveal_type(D.f) # N: Revealed type is "def [T] (T`-1) -> T`-1"
17741774
reveal_type(D().f) # N: Revealed type is "def () -> __main__.D"
1775+
1776+
[case testTypingSelfOnSuperTypeVarValues]
1777+
from typing import Self, Generic, TypeVar
1778+
1779+
T = TypeVar("T", int, str)
1780+
1781+
class B:
1782+
def copy(self) -> Self: ...
1783+
class C(B, Generic[T]):
1784+
def copy(self) -> Self:
1785+
inst = super().copy()
1786+
reveal_type(inst) # N: Revealed type is "Self`0"
1787+
return inst

test-data/unit/semanal-types.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ def f(x: int) -> None: pass
790790
def f(*args) -> None: pass
791791

792792
x = f
793+
[builtins fixtures/tuple.pyi]
793794
[out]
794795
MypyFile:1(
795796
ImportFrom:1(typing, [overload])
@@ -1032,6 +1033,7 @@ MypyFile:1(
10321033

10331034
[case testVarArgsAndKeywordArgs]
10341035
def g(*x: int, y: str = ''): pass
1036+
[builtins fixtures/tuple.pyi]
10351037
[out]
10361038
MypyFile:1(
10371039
FuncDef:1(

0 commit comments

Comments
 (0)