@@ -2919,75 +2919,116 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
2919
2919
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
2920
2920
"""
2921
2921
result : Type | None = None
2922
- sub_result : Type | None = None
2922
+ sub_result : Type
2923
2923
2924
2924
# Check each consecutive operand pair and their operator
2925
2925
for left , right , operator in zip (e .operands , e .operands [1 :], e .operators ):
2926
2926
left_type = self .accept (left )
2927
2927
2928
- method_type : mypy .types .Type | None = None
2929
-
2930
2928
if operator == "in" or operator == "not in" :
2929
+ # This case covers both iterables and containers, which have different meanings.
2930
+ # For a container, the in operator calls the __contains__ method.
2931
+ # For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
2932
+ # We allow `in` for a union of containers and iterables as long as at least one of them matches the
2933
+ # type of the left operand, as the operation will simply return False if the union's container/iterator
2934
+ # type doesn't match the left operand.
2935
+
2931
2936
# If the right operand has partial type, look it up without triggering
2932
2937
# a "Need type annotation ..." message, as it would be noise.
2933
2938
right_type = self .find_partial_type_ref_fast_path (right )
2934
2939
if right_type is None :
2935
2940
right_type = self .accept (right ) # Validate the right operand
2936
2941
2937
- # Keep track of whether we get type check errors (these won't be reported, they
2938
- # are just to verify whether something is valid typing wise).
2939
- with self .msg .filter_errors (save_filtered_errors = True ) as local_errors :
2940
- _ , method_type = self .check_method_call_by_name (
2941
- method = "__contains__" ,
2942
- base_type = right_type ,
2943
- args = [left ],
2944
- arg_kinds = [ARG_POS ],
2945
- context = e ,
2946
- )
2942
+ right_type = get_proper_type (right_type )
2943
+ item_types : Sequence [Type ] = [right_type ]
2944
+ if isinstance (right_type , UnionType ):
2945
+ item_types = list (right_type .items )
2947
2946
2948
2947
sub_result = self .bool_type ()
2949
- # Container item type for strict type overlap checks. Note: we need to only
2950
- # check for nominal type, because a usual "Unsupported operands for in"
2951
- # will be reported for types incompatible with __contains__().
2952
- # See testCustomContainsCheckStrictEquality for an example.
2953
- cont_type = self .chk .analyze_container_item_type (right_type )
2954
- if isinstance (right_type , PartialType ):
2955
- # We don't really know if this is an error or not, so just shut up.
2956
- pass
2957
- elif (
2958
- local_errors .has_new_errors ()
2959
- and
2960
- # is_valid_var_arg is True for any Iterable
2961
- self .is_valid_var_arg (right_type )
2962
- ):
2963
- _ , itertype = self .chk .analyze_iterable_item_type (right )
2964
- method_type = CallableType (
2965
- [left_type ],
2966
- [nodes .ARG_POS ],
2967
- [None ],
2968
- self .bool_type (),
2969
- self .named_type ("builtins.function" ),
2970
- )
2971
- if not is_subtype (left_type , itertype ):
2972
- self .msg .unsupported_operand_types ("in" , left_type , right_type , e )
2973
- # Only show dangerous overlap if there are no other errors.
2974
- elif (
2975
- not local_errors .has_new_errors ()
2976
- and cont_type
2977
- and self .dangerous_comparison (
2978
- left_type , cont_type , original_container = right_type , prefer_literal = False
2979
- )
2980
- ):
2981
- self .msg .dangerous_comparison (left_type , cont_type , "container" , e )
2982
- else :
2983
- self .msg .add_errors (local_errors .filtered_errors ())
2948
+
2949
+ container_types : list [Type ] = []
2950
+ iterable_types : list [Type ] = []
2951
+ failed_out = False
2952
+ encountered_partial_type = False
2953
+
2954
+ for item_type in item_types :
2955
+ # Keep track of whether we get type check errors (these won't be reported, they
2956
+ # are just to verify whether something is valid typing wise).
2957
+ with self .msg .filter_errors (save_filtered_errors = True ) as container_errors :
2958
+ _ , method_type = self .check_method_call_by_name (
2959
+ method = "__contains__" ,
2960
+ base_type = item_type ,
2961
+ args = [left ],
2962
+ arg_kinds = [ARG_POS ],
2963
+ context = e ,
2964
+ original_type = right_type ,
2965
+ )
2966
+ # Container item type for strict type overlap checks. Note: we need to only
2967
+ # check for nominal type, because a usual "Unsupported operands for in"
2968
+ # will be reported for types incompatible with __contains__().
2969
+ # See testCustomContainsCheckStrictEquality for an example.
2970
+ cont_type = self .chk .analyze_container_item_type (item_type )
2971
+
2972
+ if isinstance (item_type , PartialType ):
2973
+ # We don't really know if this is an error or not, so just shut up.
2974
+ encountered_partial_type = True
2975
+ pass
2976
+ elif (
2977
+ container_errors .has_new_errors ()
2978
+ and
2979
+ # is_valid_var_arg is True for any Iterable
2980
+ self .is_valid_var_arg (item_type )
2981
+ ):
2982
+ # it's not a container, but it is an iterable
2983
+ with self .msg .filter_errors (save_filtered_errors = True ) as iterable_errors :
2984
+ _ , itertype = self .chk .analyze_iterable_item_type_without_expression (
2985
+ item_type , e
2986
+ )
2987
+ if iterable_errors .has_new_errors ():
2988
+ self .msg .add_errors (iterable_errors .filtered_errors ())
2989
+ failed_out = True
2990
+ else :
2991
+ method_type = CallableType (
2992
+ [left_type ],
2993
+ [nodes .ARG_POS ],
2994
+ [None ],
2995
+ self .bool_type (),
2996
+ self .named_type ("builtins.function" ),
2997
+ )
2998
+ e .method_types .append (method_type )
2999
+ iterable_types .append (itertype )
3000
+ elif not container_errors .has_new_errors () and cont_type :
3001
+ container_types .append (cont_type )
3002
+ e .method_types .append (method_type )
3003
+ else :
3004
+ self .msg .add_errors (container_errors .filtered_errors ())
3005
+ failed_out = True
3006
+
3007
+ if not encountered_partial_type and not failed_out :
3008
+ iterable_type = UnionType .make_union (iterable_types )
3009
+ if not is_subtype (left_type , iterable_type ):
3010
+ if len (container_types ) == 0 :
3011
+ self .msg .unsupported_operand_types ("in" , left_type , right_type , e )
3012
+ else :
3013
+ container_type = UnionType .make_union (container_types )
3014
+ if self .dangerous_comparison (
3015
+ left_type ,
3016
+ container_type ,
3017
+ original_container = right_type ,
3018
+ prefer_literal = False ,
3019
+ ):
3020
+ self .msg .dangerous_comparison (
3021
+ left_type , container_type , "container" , e
3022
+ )
3023
+
2984
3024
elif operator in operators .op_methods :
2985
3025
method = operators .op_methods [operator ]
2986
3026
2987
3027
with ErrorWatcher (self .msg .errors ) as w :
2988
3028
sub_result , method_type = self .check_op (
2989
3029
method , left_type , right , e , allow_reverse = True
2990
3030
)
3031
+ e .method_types .append (method_type )
2991
3032
2992
3033
# Only show dangerous overlap if there are no other errors. See
2993
3034
# testCustomEqCheckStrictEquality for an example.
@@ -3007,12 +3048,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
3007
3048
left_type = try_getting_literal (left_type )
3008
3049
right_type = try_getting_literal (right_type )
3009
3050
self .msg .dangerous_comparison (left_type , right_type , "identity" , e )
3010
- method_type = None
3051
+ e . method_types . append ( None )
3011
3052
else :
3012
3053
raise RuntimeError (f"Unknown comparison operator { operator } " )
3013
3054
3014
- e .method_types .append (method_type )
3015
-
3016
3055
# Determine type of boolean-and of result and sub_result
3017
3056
if result is None :
3018
3057
result = sub_result
0 commit comments