Skip to content

Commit 1d247ea

Browse files
authored
Fix bug with in operator used with a union of Container and Iterable (#14384)
Fixes #4954. Modifies analysis of `in` comparison expressions. Previously, mypy would check the right operand of an `in` expression to see if it was a union of `Container`s, and then if it was a union of `Iterable`s, but would fail on unions of both `Container`s and `Iterable`s.
1 parent b2cf9d1 commit 1d247ea

File tree

3 files changed

+127
-51
lines changed

3 files changed

+127
-51
lines changed

mypy/checker.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4500,6 +4500,26 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
45004500
# Non-tuple iterable.
45014501
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]
45024502

4503+
def analyze_iterable_item_type_without_expression(
4504+
self, type: Type, context: Context
4505+
) -> tuple[Type, Type]:
4506+
"""Analyse iterable type and return iterator and iterator item types."""
4507+
echk = self.expr_checker
4508+
iterable = get_proper_type(type)
4509+
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]
4510+
4511+
if isinstance(iterable, TupleType):
4512+
joined: Type = UninhabitedType()
4513+
for item in iterable.items:
4514+
joined = join_types(joined, item)
4515+
return iterator, joined
4516+
else:
4517+
# Non-tuple iterable.
4518+
return (
4519+
iterator,
4520+
echk.check_method_call_by_name("__next__", iterator, [], [], context)[0],
4521+
)
4522+
45034523
def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
45044524
"""Try to infer native int item type from arguments to range(...).
45054525

mypy/checkexpr.py

Lines changed: 90 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,75 +2919,116 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29192919
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
29202920
"""
29212921
result: Type | None = None
2922-
sub_result: Type | None = None
2922+
sub_result: Type
29232923

29242924
# Check each consecutive operand pair and their operator
29252925
for left, right, operator in zip(e.operands, e.operands[1:], e.operators):
29262926
left_type = self.accept(left)
29272927

2928-
method_type: mypy.types.Type | None = None
2929-
29302928
if operator == "in" or operator == "not in":
2929+
# This case covers both iterables and containers, which have different meanings.
2930+
# For a container, the in operator calls the __contains__ method.
2931+
# For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
2932+
# We allow `in` for a union of containers and iterables as long as at least one of them matches the
2933+
# type of the left operand, as the operation will simply return False if the union's container/iterator
2934+
# type doesn't match the left operand.
2935+
29312936
# If the right operand has partial type, look it up without triggering
29322937
# a "Need type annotation ..." message, as it would be noise.
29332938
right_type = self.find_partial_type_ref_fast_path(right)
29342939
if right_type is None:
29352940
right_type = self.accept(right) # Validate the right operand
29362941

2937-
# Keep track of whether we get type check errors (these won't be reported, they
2938-
# are just to verify whether something is valid typing wise).
2939-
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
2940-
_, method_type = self.check_method_call_by_name(
2941-
method="__contains__",
2942-
base_type=right_type,
2943-
args=[left],
2944-
arg_kinds=[ARG_POS],
2945-
context=e,
2946-
)
2942+
right_type = get_proper_type(right_type)
2943+
item_types: Sequence[Type] = [right_type]
2944+
if isinstance(right_type, UnionType):
2945+
item_types = list(right_type.items)
29472946

29482947
sub_result = self.bool_type()
2949-
# Container item type for strict type overlap checks. Note: we need to only
2950-
# check for nominal type, because a usual "Unsupported operands for in"
2951-
# will be reported for types incompatible with __contains__().
2952-
# See testCustomContainsCheckStrictEquality for an example.
2953-
cont_type = self.chk.analyze_container_item_type(right_type)
2954-
if isinstance(right_type, PartialType):
2955-
# We don't really know if this is an error or not, so just shut up.
2956-
pass
2957-
elif (
2958-
local_errors.has_new_errors()
2959-
and
2960-
# is_valid_var_arg is True for any Iterable
2961-
self.is_valid_var_arg(right_type)
2962-
):
2963-
_, itertype = self.chk.analyze_iterable_item_type(right)
2964-
method_type = CallableType(
2965-
[left_type],
2966-
[nodes.ARG_POS],
2967-
[None],
2968-
self.bool_type(),
2969-
self.named_type("builtins.function"),
2970-
)
2971-
if not is_subtype(left_type, itertype):
2972-
self.msg.unsupported_operand_types("in", left_type, right_type, e)
2973-
# Only show dangerous overlap if there are no other errors.
2974-
elif (
2975-
not local_errors.has_new_errors()
2976-
and cont_type
2977-
and self.dangerous_comparison(
2978-
left_type, cont_type, original_container=right_type, prefer_literal=False
2979-
)
2980-
):
2981-
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
2982-
else:
2983-
self.msg.add_errors(local_errors.filtered_errors())
2948+
2949+
container_types: list[Type] = []
2950+
iterable_types: list[Type] = []
2951+
failed_out = False
2952+
encountered_partial_type = False
2953+
2954+
for item_type in item_types:
2955+
# Keep track of whether we get type check errors (these won't be reported, they
2956+
# are just to verify whether something is valid typing wise).
2957+
with self.msg.filter_errors(save_filtered_errors=True) as container_errors:
2958+
_, method_type = self.check_method_call_by_name(
2959+
method="__contains__",
2960+
base_type=item_type,
2961+
args=[left],
2962+
arg_kinds=[ARG_POS],
2963+
context=e,
2964+
original_type=right_type,
2965+
)
2966+
# Container item type for strict type overlap checks. Note: we need to only
2967+
# check for nominal type, because a usual "Unsupported operands for in"
2968+
# will be reported for types incompatible with __contains__().
2969+
# See testCustomContainsCheckStrictEquality for an example.
2970+
cont_type = self.chk.analyze_container_item_type(item_type)
2971+
2972+
if isinstance(item_type, PartialType):
2973+
# We don't really know if this is an error or not, so just shut up.
2974+
encountered_partial_type = True
2975+
pass
2976+
elif (
2977+
container_errors.has_new_errors()
2978+
and
2979+
# is_valid_var_arg is True for any Iterable
2980+
self.is_valid_var_arg(item_type)
2981+
):
2982+
# it's not a container, but it is an iterable
2983+
with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors:
2984+
_, itertype = self.chk.analyze_iterable_item_type_without_expression(
2985+
item_type, e
2986+
)
2987+
if iterable_errors.has_new_errors():
2988+
self.msg.add_errors(iterable_errors.filtered_errors())
2989+
failed_out = True
2990+
else:
2991+
method_type = CallableType(
2992+
[left_type],
2993+
[nodes.ARG_POS],
2994+
[None],
2995+
self.bool_type(),
2996+
self.named_type("builtins.function"),
2997+
)
2998+
e.method_types.append(method_type)
2999+
iterable_types.append(itertype)
3000+
elif not container_errors.has_new_errors() and cont_type:
3001+
container_types.append(cont_type)
3002+
e.method_types.append(method_type)
3003+
else:
3004+
self.msg.add_errors(container_errors.filtered_errors())
3005+
failed_out = True
3006+
3007+
if not encountered_partial_type and not failed_out:
3008+
iterable_type = UnionType.make_union(iterable_types)
3009+
if not is_subtype(left_type, iterable_type):
3010+
if len(container_types) == 0:
3011+
self.msg.unsupported_operand_types("in", left_type, right_type, e)
3012+
else:
3013+
container_type = UnionType.make_union(container_types)
3014+
if self.dangerous_comparison(
3015+
left_type,
3016+
container_type,
3017+
original_container=right_type,
3018+
prefer_literal=False,
3019+
):
3020+
self.msg.dangerous_comparison(
3021+
left_type, container_type, "container", e
3022+
)
3023+
29843024
elif operator in operators.op_methods:
29853025
method = operators.op_methods[operator]
29863026

29873027
with ErrorWatcher(self.msg.errors) as w:
29883028
sub_result, method_type = self.check_op(
29893029
method, left_type, right, e, allow_reverse=True
29903030
)
3031+
e.method_types.append(method_type)
29913032

29923033
# Only show dangerous overlap if there are no other errors. See
29933034
# testCustomEqCheckStrictEquality for an example.
@@ -3007,12 +3048,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
30073048
left_type = try_getting_literal(left_type)
30083049
right_type = try_getting_literal(right_type)
30093050
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
3010-
method_type = None
3051+
e.method_types.append(None)
30113052
else:
30123053
raise RuntimeError(f"Unknown comparison operator {operator}")
30133054

3014-
e.method_types.append(method_type)
3015-
30163055
# Determine type of boolean-and of result and sub_result
30173056
if result is None:
30183057
result = sub_result

test-data/unit/check-unions.test

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,3 +1202,20 @@ def foo(
12021202
yield i
12031203
foo([1])
12041204
[builtins fixtures/list.pyi]
1205+
1206+
[case testUnionIterableContainer]
1207+
from typing import Iterable, Container, Union
1208+
1209+
i: Iterable[str]
1210+
c: Container[str]
1211+
u: Union[Iterable[str], Container[str]]
1212+
ni: Union[Iterable[str], int]
1213+
nc: Union[Container[str], int]
1214+
1215+
'x' in i
1216+
'x' in c
1217+
'x' in u
1218+
'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]")
1219+
'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]")
1220+
[builtins fixtures/tuple.pyi]
1221+
[typing fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)