@@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression
3536
3536
vartype = type_map [expr ]
3537
3537
return self .conditional_callable_type_map (expr , vartype )
3538
3538
elif isinstance (node , ComparisonExpr ):
3539
- operand_types = [coerce_to_literal (type_map [expr ])
3540
- for expr in node .operands if expr in type_map ]
3541
-
3542
- is_not = node .operators == ['is not' ]
3543
- if (is_not or node .operators == ['is' ]) and len (operand_types ) == len (node .operands ):
3544
- if_vars = {} # type: TypeMap
3545
- else_vars = {} # type: TypeMap
3546
-
3547
- for i , expr in enumerate (node .operands ):
3548
- var_type = operand_types [i ]
3549
- other_type = operand_types [1 - i ]
3550
-
3551
- if literal (expr ) == LITERAL_TYPE and is_singleton_type (other_type ):
3552
- # This should only be true at most once: there should be
3553
- # exactly two elements in node.operands and if the 'other type' is
3554
- # a singleton type, it by definition does not need to be narrowed:
3555
- # it already has the most precise type possible so does not need to
3556
- # be narrowed/included in the output map.
3557
- #
3558
- # TODO: Generalize this to handle the case where 'other_type' is
3559
- # a union of singleton types.
3539
+ operand_types = []
3540
+ for expr in node .operands :
3541
+ if expr not in type_map :
3542
+ return {}, {}
3543
+ operand_types .append (coerce_to_literal (type_map [expr ]))
3544
+
3545
+ type_maps = []
3546
+ for i , (operator , left_expr , right_expr ) in enumerate (node .pairwise ()):
3547
+ left_type = operand_types [i ]
3548
+ right_type = operand_types [i + 1 ]
3549
+
3550
+ if_map = {} # type: TypeMap
3551
+ else_map = {} # type: TypeMap
3552
+ if operator in {'in' , 'not in' }:
3553
+ right_item_type = builtin_item_type (right_type )
3554
+ if right_item_type is None or is_optional (right_item_type ):
3555
+ continue
3556
+ if (isinstance (right_item_type , Instance )
3557
+ and right_item_type .type .fullname () == 'builtins.object' ):
3558
+ continue
3559
+
3560
+ if (is_optional (left_type ) and literal (left_expr ) == LITERAL_TYPE
3561
+ and not is_literal_none (left_expr ) and
3562
+ is_overlapping_erased_types (left_type , right_item_type )):
3563
+ if_map , else_map = {left_expr : remove_optional (left_type )}, {}
3564
+ else :
3565
+ continue
3566
+ elif operator in {'==' , '!=' }:
3567
+ if_map , else_map = self .narrow_given_equality (
3568
+ left_expr , left_type , right_expr , right_type , assume_identity = False )
3569
+ elif operator in {'is' , 'is not' }:
3570
+ if_map , else_map = self .narrow_given_equality (
3571
+ left_expr , left_type , right_expr , right_type , assume_identity = True )
3572
+ else :
3573
+ continue
3560
3574
3561
- if isinstance (other_type , LiteralType ) and other_type .is_enum_literal ():
3562
- fallback_name = other_type .fallback .type .fullname ()
3563
- var_type = try_expanding_enum_to_union (var_type , fallback_name )
3575
+ if operator in {'not in' , '!=' , 'is not' }:
3576
+ if_map , else_map = else_map , if_map
3564
3577
3565
- target_type = [TypeRange (other_type , is_upper_bound = False )]
3566
- if_vars , else_vars = conditional_type_map (expr , var_type , target_type )
3567
- break
3578
+ type_maps .append ((if_map , else_map ))
3568
3579
3569
- if is_not :
3570
- if_vars , else_vars = else_vars , if_vars
3571
- return if_vars , else_vars
3572
- # Check for `x == y` where x is of type Optional[T] and y is of type T
3573
- # or a type that overlaps with T (or vice versa).
3574
- elif node .operators == ['==' ]:
3575
- first_type = type_map [node .operands [0 ]]
3576
- second_type = type_map [node .operands [1 ]]
3577
- if is_optional (first_type ) != is_optional (second_type ):
3578
- if is_optional (first_type ):
3579
- optional_type , comp_type = first_type , second_type
3580
- optional_expr = node .operands [0 ]
3581
- else :
3582
- optional_type , comp_type = second_type , first_type
3583
- optional_expr = node .operands [1 ]
3584
- if is_overlapping_erased_types (optional_type , comp_type ):
3585
- return {optional_expr : remove_optional (optional_type )}, {}
3586
- elif node .operators in [['in' ], ['not in' ]]:
3587
- expr = node .operands [0 ]
3588
- left_type = type_map [expr ]
3589
- right_type = builtin_item_type (type_map [node .operands [1 ]])
3590
- right_ok = right_type and (not is_optional (right_type ) and
3591
- (not isinstance (right_type , Instance ) or
3592
- right_type .type .fullname () != 'builtins.object' ))
3593
- if (right_type and right_ok and is_optional (left_type ) and
3594
- literal (expr ) == LITERAL_TYPE and not is_literal_none (expr ) and
3595
- is_overlapping_erased_types (left_type , right_type )):
3596
- if node .operators == ['in' ]:
3597
- return {expr : remove_optional (left_type )}, {}
3598
- if node .operators == ['not in' ]:
3599
- return {}, {expr : remove_optional (left_type )}
3580
+ if len (type_maps ) == 0 :
3581
+ return {}, {}
3582
+ elif len (type_maps ) == 1 :
3583
+ return type_maps [0 ]
3584
+ else :
3585
+ # Comparisons like 'a == b == c is d' is the same thing as
3586
+ # '(a == b) and (b == c) and (c is d)'. So after generating each
3587
+ # individual comparison's typemaps, we "and" them together here.
3588
+ # (Also see comments below where we handle the 'and' OpExpr.)
3589
+ final_if_map , final_else_map = type_maps [0 ]
3590
+ for if_map , else_map in type_maps [1 :]:
3591
+ final_if_map = and_conditional_maps (final_if_map , if_map )
3592
+ final_else_map = or_conditional_maps (final_else_map , else_map )
3593
+ return final_if_map , final_else_map
3600
3594
elif isinstance (node , RefExpr ):
3601
3595
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
3602
3596
# respectively
@@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression
3630
3624
# Not a supported isinstance check
3631
3625
return {}, {}
3632
3626
3627
+ def narrow_given_equality (self ,
3628
+ left_expr : Expression ,
3629
+ left_type : Type ,
3630
+ right_expr : Expression ,
3631
+ right_type : Type ,
3632
+ assume_identity : bool ,
3633
+ ) -> Tuple [TypeMap , TypeMap ]:
3634
+ """Assuming that the given 'left' and 'right' exprs are equal to each other, try
3635
+ producing TypeMaps refining the types of either the left or right exprs (or neither,
3636
+ if we can't learn anything from the comparison).
3637
+
3638
+ For more details about what TypeMaps are, see the docstring in find_isinstance_check.
3639
+
3640
+ If 'assume_identity' is true, assume that this comparison was done using an
3641
+ identity comparison (left_expr is right_expr), not just an equality comparison
3642
+ (left_expr == right_expr). Identity checks are not overridable, so we can infer
3643
+ more information in that case.
3644
+ """
3645
+
3646
+ # For the sake of simplicity, we currently attempt inferring a more precise type
3647
+ # for just one of the two variables.
3648
+ comparisons = [
3649
+ (left_expr , left_type , right_type ),
3650
+ (right_expr , right_type , left_type ),
3651
+ ]
3652
+
3653
+ for expr , expr_type , other_type in comparisons :
3654
+ # The 'expr' isn't an expression that we can refine the type of. Skip
3655
+ # attempting to refine this expr.
3656
+ if literal (expr ) != LITERAL_TYPE :
3657
+ continue
3658
+
3659
+ # Case 1: If the 'other_type' is a singleton (only one value has
3660
+ # the specified type), attempt to narrow 'expr_type' to just that
3661
+ # singleton type.
3662
+ if is_singleton_type (other_type ):
3663
+ if isinstance (other_type , LiteralType ) and other_type .is_enum_literal ():
3664
+ if not assume_identity :
3665
+ # Our checks need to be more conservative if the operand is
3666
+ # '==' or '!=': all bets are off if either of the two operands
3667
+ # has a custom `__eq__` or `__ne__` method.
3668
+ #
3669
+ # So, we permit this check to succeed only if 'other_type' does
3670
+ # not define custom equality logic
3671
+ if not uses_default_equality_checks (expr_type ):
3672
+ continue
3673
+ if not uses_default_equality_checks (other_type .fallback ):
3674
+ continue
3675
+ fallback_name = other_type .fallback .type .fullname ()
3676
+ expr_type = try_expanding_enum_to_union (expr_type , fallback_name )
3677
+
3678
+ target_type = [TypeRange (other_type , is_upper_bound = False )]
3679
+ return conditional_type_map (expr , expr_type , target_type )
3680
+
3681
+ # Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'.
3682
+ #
3683
+ # Note: This check is actually strictly speaking unsafe: stripping away the 'None'
3684
+ # would be unsound in the case where A defines an '__eq__' method that always
3685
+ # returns 'True', for example.
3686
+ #
3687
+ # We implement this check partly for backwards-compatibility reasons and partly
3688
+ # because those kinds of degenerate '__eq__' implementations are probably rare
3689
+ # enough that this is fine in practice.
3690
+ #
3691
+ # We could also probably generalize this block to strip away *any* singleton type,
3692
+ # if we were fine with a bit more unsoundness.
3693
+ if is_optional (expr_type ) and not is_optional (other_type ):
3694
+ if is_overlapping_erased_types (expr_type , other_type ):
3695
+ return {expr : remove_optional (expr_type )}, {}
3696
+
3697
+ return {}, {}
3698
+
3633
3699
#
3634
3700
# Helpers
3635
3701
#
@@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool:
4505
4571
return node_name .startswith ('__' ) and not node_name .endswith ('__' )
4506
4572
4507
4573
4574
+ def uses_default_equality_checks (typ : Type ) -> bool :
4575
+ """Returns 'true' if we know for certain that the given type is using
4576
+ the default __eq__ and __ne__ checks defined in 'builtins.object'.
4577
+ We can use this information to make more aggressive inferences when
4578
+ analyzing things like equality checks.
4579
+
4580
+ When in doubt, this function will conservatively bias towards
4581
+ returning False.
4582
+ """
4583
+ if isinstance (typ , UnionType ):
4584
+ return all (map (uses_default_equality_checks , typ .items ))
4585
+ # TODO: Generalize this so it'll handle other types with fallbacks
4586
+ if isinstance (typ , LiteralType ):
4587
+ typ = typ .fallback
4588
+ if isinstance (typ , Instance ):
4589
+ typeinfo = typ .type
4590
+ eq_sym = typeinfo .get ('__eq__' )
4591
+ ne_sym = typeinfo .get ('__ne__' )
4592
+ if eq_sym is None or ne_sym is None :
4593
+ return False
4594
+ return (eq_sym .fullname == 'builtins.object.__eq__'
4595
+ and ne_sym .fullname == 'builtins.object.__ne__' )
4596
+ else :
4597
+ return False
4598
+
4599
+
4508
4600
def is_singleton_type (typ : Type ) -> bool :
4509
4601
"""Returns 'true' if this type is a "singleton type" -- if there exists
4510
4602
exactly only one runtime value associated with this type.
0 commit comments