diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c186ab1434ef..0a9e3856f217 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2202,13 +2202,28 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # will be reported for types incompatible with __contains__(). # See testCustomContainsCheckStrictEquality for an example. cont_type = self.chk.analyze_container_item_type(right_type) + iter_type = right_type + changed_expr = right + proper_right_type = get_proper_type(right_type) + if (local_errors.is_errors() and + isinstance(proper_right_type, UnionType)): + typs = [] # type: List[Type] + for item in proper_right_type.relevant_items(): + temp_errors = self.msg.copy() + temp_errors.disable_count = 0 + self.check_method_call_by_name( + '__contains__', item, [left], [ARG_POS], e, temp_errors) + if temp_errors.is_errors(): + typs.append(item) + iter_type = UnionType.make_union(typs) + changed_expr = TempNode(iter_type) if isinstance(right_type, PartialType): # We don't really know if this is an error or not, so just shut up. pass elif (local_errors.is_errors() and # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type)): - _, itertype = self.chk.analyze_iterable_item_type(right) + self.is_valid_var_arg(iter_type)): + _, itertype = self.chk.analyze_iterable_item_type(changed_expr) method_type = CallableType( [left_type], [nodes.ARG_POS], diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 4eb52be6f8bd..d0127d9c916b 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -406,6 +406,15 @@ class D(Iterable[A]): def __iter__(self) -> Iterator[A]: pass [builtins fixtures/bool.pyi] +[case testInOperatorUnion] +from typing import Iterable, Union, Container +u: Union[Iterable[str], Container[str], Iterable[str]] +'x' in u + +w: Union[Container[str], int] +'x' in w # E: Unsupported right operand type for in ("Union[Container[str], int]") +[typing fixtures/typing-full.pyi] + [case testNotInOperator] from typing import Iterator, Iterable, Any a, b, c, d, e = None, None, None, None, None # type: (A, B, bool, D, Any)