Skip to content

Commit cc3f5c2

Browse files
committed
Mypy now treats classes with __getitem__ as iterable
1 parent 0c0f071 commit cc3f5c2

9 files changed

+337
-38
lines changed

mypy/checker.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3429,9 +3429,9 @@ def type_is_iterable(self, type: Type) -> bool:
34293429
type = get_proper_type(type)
34303430
if isinstance(type, CallableType) and type.is_type_obj():
34313431
type = type.fallback
3432-
return is_subtype(
3433-
type, self.named_generic_type("typing.Iterable", [AnyType(TypeOfAny.special_form)])
3434-
)
3432+
with self.msg.filter_errors() as iter_errors:
3433+
self.analyze_iterable_item_type(TempNode(type))
3434+
return not iter_errors.has_new_errors()
34353435

34363436
def check_multi_assignment_from_iterable(
34373437
self,
@@ -4247,15 +4247,36 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
42474247
"""Analyse iterable expression and return iterator and iterator item types."""
42484248
echk = self.expr_checker
42494249
iterable = get_proper_type(echk.accept(expr))
4250-
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0]
42514250

4251+
# We first try to find `__iter__` magic method.
4252+
# If it is present, we go on with it.
4253+
# But, python also support iterables with just `__getitem__(index) -> Any` defined.
4254+
# So, we check it in case `__iter__` is missing.
4255+
with self.msg.filter_errors(save_filtered_errors=True) as iter_errors:
4256+
# We save original error to show it later if `__getitem__` is also missing.
4257+
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0]
4258+
if iter_errors.has_new_errors():
4259+
# `__iter__` is missing, try `__getattr__`:
4260+
arg = self.temp_node(AnyType(TypeOfAny.implementation_artifact), expr)
4261+
with self.msg.filter_errors() as getitem_errors:
4262+
getitem_type = echk.check_method_call_by_name(
4263+
"__getitem__", iterable, [arg], [nodes.ARG_POS], expr
4264+
)[0]
4265+
if getitem_errors.has_new_errors(): # Both are missing.
4266+
self.msg.add_errors(iter_errors.filtered_errors())
4267+
return AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error)
4268+
else:
4269+
# We found just `__getitem__`, it does not follow the same
4270+
# semantics as `__iter__`, so: just return what we found.
4271+
return self.named_generic_type("typing.Iterator", [getitem_type]), getitem_type
4272+
4273+
# We found `__iter__`, let's analyze its return type:
42524274
if isinstance(iterable, TupleType):
42534275
joined: Type = UninhabitedType()
42544276
for item in iterable.items:
42554277
joined = join_types(joined, item)
42564278
return iterator, joined
4257-
else:
4258-
# Non-tuple iterable.
4279+
else: # Non-tuple iterable.
42594280
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]
42604281

42614282
def analyze_container_item_type(self, typ: Type) -> Type | None:
@@ -5986,15 +6007,9 @@ def iterable_item_type(self, instance: Instance) -> Type:
59866007
# This relies on 'map_instance_to_supertype' returning 'Iterable[Any]'
59876008
# in case there is no explicit base class.
59886009
return item_type
5989-
# Try also structural typing.
5990-
iter_type = get_proper_type(find_member("__iter__", instance, instance, is_operator=True))
5991-
if iter_type and isinstance(iter_type, CallableType):
5992-
ret_type = get_proper_type(iter_type.ret_type)
5993-
if isinstance(ret_type, Instance):
5994-
iterator = map_instance_to_supertype(
5995-
ret_type, self.lookup_typeinfo("typing.Iterator")
5996-
)
5997-
item_type = iterator.args[0]
6010+
6011+
# Try also structural typing: including `__iter__` and `__getitem__`.
6012+
_, item_type = self.analyze_iterable_item_type(TempNode(instance))
59986013
return item_type
59996014

60006015
def function_type(self, func: FuncBase) -> FunctionLike:

mypy/checkexpr.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,13 +2998,14 @@ def check_method_call_by_name(
29982998
method,
29992999
base_type,
30003000
context,
3001-
False,
3002-
False,
3003-
True,
3004-
self.msg,
3001+
is_lvalue=False,
3002+
is_super=False,
3003+
is_operator=True,
3004+
msg=self.msg,
30053005
original_type=original_type,
30063006
chk=self.chk,
30073007
in_literal_context=self.is_literal_context(),
3008+
suggest_awaitable=False,
30083009
)
30093010
return self.check_method_call(method, base_type, method_type, args, arg_kinds, context)
30103011

@@ -4806,13 +4807,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals
48064807
if is_async_def(subexpr_type) and not has_coroutine_decorator(return_type):
48074808
self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e)
48084809

4809-
any_type = AnyType(TypeOfAny.special_form)
4810-
generic_generator_type = self.chk.named_generic_type(
4811-
"typing.Generator", [any_type, any_type, any_type]
4812-
)
4813-
iter_type, _ = self.check_method_call_by_name(
4814-
"__iter__", subexpr_type, [], [], context=generic_generator_type
4815-
)
4810+
iter_type, _ = self.chk.analyze_iterable_item_type(TempNode(subexpr_type))
4811+
iter_type = get_proper_type(iter_type)
48164812
else:
48174813
if not (is_async_def(subexpr_type) and has_coroutine_decorator(return_type)):
48184814
self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e)

mypy/checkmember.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
chk: mypy.checker.TypeChecker,
9191
self_type: Type | None,
9292
module_symbol_table: SymbolTable | None = None,
93+
suggest_awaitable: bool = True,
9394
) -> None:
9495
self.is_lvalue = is_lvalue
9596
self.is_super = is_super
@@ -100,6 +101,7 @@ def __init__(
100101
self.msg = msg
101102
self.chk = chk
102103
self.module_symbol_table = module_symbol_table
104+
self.suggest_awaitable = suggest_awaitable
103105

104106
def named_type(self, name: str) -> Instance:
105107
return self.chk.named_type(name)
@@ -149,6 +151,7 @@ def analyze_member_access(
149151
in_literal_context: bool = False,
150152
self_type: Type | None = None,
151153
module_symbol_table: SymbolTable | None = None,
154+
suggest_awaitable: bool = True,
152155
) -> Type:
153156
"""Return the type of attribute 'name' of 'typ'.
154157
@@ -183,6 +186,7 @@ def analyze_member_access(
183186
chk=chk,
184187
self_type=self_type,
185188
module_symbol_table=module_symbol_table,
189+
suggest_awaitable=suggest_awaitable,
186190
)
187191
result = _analyze_member_access(name, typ, mx, override_info)
188192
possible_literal = get_proper_type(result)
@@ -260,7 +264,7 @@ def report_missing_attribute(
260264
override_info: TypeInfo | None = None,
261265
) -> Type:
262266
res_type = mx.msg.has_no_attr(original_type, typ, name, mx.context, mx.module_symbol_table)
263-
if may_be_awaitable_attribute(name, typ, mx, override_info):
267+
if mx.suggest_awaitable and may_be_awaitable_attribute(name, typ, mx, override_info):
264268
mx.msg.possible_missing_await(mx.context)
265269
return res_type
266270

mypy/constraints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,8 +750,7 @@ def infer_constraints_from_protocol_members(
750750
"""
751751
res = []
752752
for member in protocol.type.protocol_members:
753-
inst = mypy.subtypes.find_member(member, instance, subtype)
754-
temp = mypy.subtypes.find_member(member, template, subtype)
753+
inst, temp = mypy.subtypes.find_members(member, instance, template, subtype)
755754
if inst is None or temp is None:
756755
return [] # See #11020
757756
# The above is safe since at this point we know that 'instance' is a subtype

mypy/subtypes.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from contextlib import contextmanager
4-
from typing import Any, Callable, Iterator, List, TypeVar, cast
4+
from typing import Any, Callable, Iterator, List, Tuple, TypeVar, cast
55
from typing_extensions import Final, TypeAlias as _TypeAlias
66

77
import mypy.applytype
@@ -14,6 +14,7 @@
1414
# Circular import; done in the function instead.
1515
# import mypy.solve
1616
from mypy.nodes import (
17+
ARG_POS,
1718
ARG_STAR,
1819
ARG_STAR2,
1920
CONTRAVARIANT,
@@ -956,9 +957,10 @@ def f(self) -> A: ...
956957
ignore_names = member != "__call__" # __call__ can be passed kwargs
957958
# The third argument below indicates to what self type is bound.
958959
# We always bind self to the subtype. (Similarly to nominal types).
959-
supertype = get_proper_type(find_member(member, right, left))
960+
supertype, subtype = find_members(member, right, left, left)
961+
supertype = get_proper_type(supertype)
960962
assert supertype is not None
961-
subtype = get_proper_type(find_member(member, left, left))
963+
subtype = get_proper_type(subtype)
962964
# Useful for debugging:
963965
# print(member, 'of', left, 'has type', subtype)
964966
# print(member, 'of', right, 'has type', supertype)
@@ -1012,6 +1014,56 @@ def f(self) -> A: ...
10121014
return True
10131015

10141016

1017+
def find_members(
1018+
name: str, supertype: Instance, subtype: Instance, context: Type
1019+
) -> Tuple[Type | None, Type | None]:
1020+
"""Find types of member by name for two instances.
1021+
1022+
We do it with respect to some special cases, like `Iterable` and `__geitem__`.
1023+
"""
1024+
if name == "__iter__":
1025+
# So, this is a special case: old-style iterbale protocol
1026+
# must be supported even without explicit `__iter__` method.
1027+
# Because all types with `__geitem__` defined have default `__iter__`
1028+
# implementation. See #2220
1029+
# First, we need to find which is one actually `Iterable`:
1030+
if is_named_instance(supertype, "typing.Iterable"):
1031+
left, right = _iterable_special_member(supertype, subtype, context)
1032+
if left is not None and right is not None:
1033+
return left, right
1034+
elif is_named_instance(subtype, "typing.Iterable"):
1035+
left, right = _iterable_special_member(subtype, supertype, context)
1036+
if left is not None and right is not None:
1037+
return right, left
1038+
1039+
# This is not a special case.
1040+
# Falling back to regular `find_member` call:
1041+
return (find_member(name, supertype, context), find_member(name, subtype, context))
1042+
1043+
1044+
def _iterable_special_member(
1045+
iterable: Instance, candidate: Instance, context: Type
1046+
) -> Tuple[Type | None, Type | None]:
1047+
name = "__iter__"
1048+
iterable_method = get_proper_type(find_member(name, iterable, context))
1049+
candidate_method = get_proper_type(find_member("__getitem__", candidate, context))
1050+
if isinstance(iterable_method, CallableType) and isinstance(
1051+
(ret := get_proper_type(iterable_method.ret_type)), Instance
1052+
):
1053+
# We need to transform
1054+
# `__iter__() -> Iterable[ret]` into
1055+
# `__getitem__(Any) -> ret`
1056+
iterable_method = iterable_method.copy_modified(
1057+
arg_names=[None],
1058+
arg_types=[AnyType(TypeOfAny.implementation_artifact)],
1059+
arg_kinds=[ARG_POS],
1060+
ret_type=ret.args[0],
1061+
name="__getitem__",
1062+
)
1063+
return (iterable_method, candidate_method)
1064+
return None, None
1065+
1066+
10151067
def find_member(
10161068
name: str, itype: Instance, subtype: Type, is_operator: bool = False
10171069
) -> Type | None:

0 commit comments

Comments
 (0)