diff --git a/mypy/checker.py b/mypy/checker.py index ca3e46b4e158..bea8813c1984 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -39,7 +39,8 @@ from mypy.sametypes import is_same_type from mypy.messages import ( MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq, - format_type, format_type_bare, format_type_distinctly, SUGGESTED_TEST_FIXTURES + format_type, format_type_bare, format_type_distinctly, SUGGESTED_TEST_FIXTURES, + temp_message_builder ) import mypy.checkexpr from mypy.checkmember import ( @@ -3499,20 +3500,40 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) - iterator = echk.check_method_call_by_name('__iter__', iterable, [], [], expr)[0] + iter_msg_builder = self.msg.clean_copy() + iterator = echk.check_method_call_by_name( + '__iter__', iterable, [], [], expr, iter_msg_builder)[0] if isinstance(iterable, TupleType): joined = UninhabitedType() # type: Type for item in iterable.items: joined = join_types(joined, item) return iterator, joined + + if iter_msg_builder.is_errors(): + # We couldn't find __iter__ so let's try __getitem__ + getitem_msg_builder = temp_message_builder() + arg = self.temp_node(self.named_type("builtins.int"), expr) + getitem_type = echk.check_method_call_by_name( + '__getitem__', + iterable, + [arg], + [nodes.ARG_POS], + expr, + getitem_msg_builder + )[0] + + if not getitem_msg_builder.is_errors(): + # We found __getitem__ + return self.named_generic_type("typing.Iterator", [getitem_type]), getitem_type + + self.msg.add_errors(iter_msg_builder) + # Non-tuple iterable. + if self.options.python_version[0] >= 3: + nextmethod = '__next__' else: - # Non-tuple iterable. - if self.options.python_version[0] >= 3: - nextmethod = '__next__' - else: - nextmethod = 'next' - return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0] + nextmethod = 'next' + return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0] def analyze_container_item_type(self, typ: Type) -> Optional[Type]: """Check if a type is a nominal container of a union of such. diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index f13d2bc597da..f446f2a9a048 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -1940,6 +1940,17 @@ reveal_type(list(c for c in C)) # N: Revealed type is "builtins.list[__main__.C reveal_type(list(C)) # N: Revealed type is "builtins.list[__main__.C*]" [builtins fixtures/list.pyi] +[case testIterableGetItemOnClass] +class A: + def __getitem__(self, x: int) -> int: pass + +class B: + def __getitem__(self, x: str) -> str: pass + +reveal_type(list(a for a in A())) # N: Revealed type is "builtins.list[builtins.int*]" +list(b for b in B()) # E: "B" has no attribute "__iter__" (not iterable) +[builtins fixtures/list.pyi] + [case testClassesGetattrWithProtocols] from typing import Protocol