@@ -615,24 +615,88 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
615
615
}
616
616
}
617
617
618
+ // TODO `expr_in` and `expr_not_in` should perhaps be unified with `expr_eq` and `expr_ne`,
619
+ // since `eq` and `ne` are equivalent to `in` and `not in` with only one element in the RHS.
618
620
fn evaluate_expr_in ( & mut self , lhs_ty : Type < ' db > , rhs_ty : Type < ' db > ) -> Option < Type < ' db > > {
619
621
if lhs_ty. is_single_valued ( self . db ) || lhs_ty. is_union_of_single_valued ( self . db ) {
620
- if let Type :: StringLiteral ( string_literal) = rhs_ty {
621
- Some ( UnionType :: from_elements (
622
- self . db ,
623
- string_literal
624
- . iter_each_char ( self . db )
625
- . map ( Type :: StringLiteral ) ,
626
- ) )
627
- } else if let Some ( tuple_spec) = rhs_ty. tuple_instance_spec ( self . db ) {
628
- // N.B. Strictly speaking this is unsound, since a tuple subclass might override `__contains__`
629
- // but we'd still apply the narrowing here. This seems unlikely, however, and narrowing is
630
- // generally unsound in numerous ways anyway (attribute narrowing, subscript, narrowing,
631
- // narrowing of globals, etc.). So this doesn't seem worth worrying about too much.
632
- Some ( UnionType :: from_elements ( self . db , tuple_spec. all_elements ( ) ) )
633
- } else {
634
- None
622
+ rhs_ty
623
+ . try_iterate ( self . db )
624
+ . ok ( )
625
+ . map ( |iterable| iterable. homogeneous_element_type ( self . db ) )
626
+ } else if lhs_ty. is_union_with_single_valued ( self . db ) {
627
+ let rhs_values = rhs_ty
628
+ . try_iterate ( self . db )
629
+ . ok ( ) ?
630
+ . homogeneous_element_type ( self . db ) ;
631
+
632
+ let mut builder = UnionBuilder :: new ( self . db ) ;
633
+
634
+ // Add the narrowed values from the RHS first, to keep literals before broader types.
635
+ builder = builder. add ( rhs_values) ;
636
+
637
+ if let Some ( lhs_union) = lhs_ty. into_union ( ) {
638
+ for element in lhs_union. elements ( self . db ) {
639
+ // Keep only the non-single-valued portion of the original type.
640
+ if !element. is_single_valued ( self . db )
641
+ && !element. is_literal_string ( )
642
+ && !element. is_bool ( self . db )
643
+ {
644
+ builder = builder. add ( * element) ;
645
+ }
646
+ }
635
647
}
648
+ Some ( builder. build ( ) )
649
+ } else {
650
+ None
651
+ }
652
+ }
653
+
654
+ fn evaluate_expr_not_in ( & mut self , lhs_ty : Type < ' db > , rhs_ty : Type < ' db > ) -> Option < Type < ' db > > {
655
+ let rhs_values = rhs_ty
656
+ . try_iterate ( self . db )
657
+ . ok ( ) ?
658
+ . homogeneous_element_type ( self . db ) ;
659
+
660
+ if lhs_ty. is_single_valued ( self . db ) || lhs_ty. is_union_of_single_valued ( self . db ) {
661
+ // Exclude the RHS values from the entire (single-valued) LHS domain.
662
+ let complement = IntersectionBuilder :: new ( self . db )
663
+ . add_positive ( lhs_ty)
664
+ . add_negative ( rhs_values)
665
+ . build ( ) ;
666
+ Some ( complement)
667
+ } else if lhs_ty. is_union_with_single_valued ( self . db ) {
668
+ // Split LHS into single-valued portion and the rest. Exclude RHS values from the
669
+ // single-valued portion, keep the rest intact.
670
+ let mut single_builder = UnionBuilder :: new ( self . db ) ;
671
+ let mut rest_builder = UnionBuilder :: new ( self . db ) ;
672
+
673
+ if let Some ( lhs_union) = lhs_ty. into_union ( ) {
674
+ for element in lhs_union. elements ( self . db ) {
675
+ if element. is_single_valued ( self . db )
676
+ || element. is_literal_string ( )
677
+ || element. is_bool ( self . db )
678
+ {
679
+ single_builder = single_builder. add ( * element) ;
680
+ } else {
681
+ rest_builder = rest_builder. add ( * element) ;
682
+ }
683
+ }
684
+ }
685
+
686
+ let single_union = single_builder. build ( ) ;
687
+ let rest_union = rest_builder. build ( ) ;
688
+
689
+ let narrowed_single = IntersectionBuilder :: new ( self . db )
690
+ . add_positive ( single_union)
691
+ . add_negative ( rhs_values)
692
+ . build ( ) ;
693
+
694
+ // Keep order: first literal complement, then broader arms.
695
+ let result = UnionBuilder :: new ( self . db )
696
+ . add ( narrowed_single)
697
+ . add ( rest_union)
698
+ . build ( ) ;
699
+ Some ( result)
636
700
} else {
637
701
None
638
702
}
@@ -660,9 +724,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
660
724
ast:: CmpOp :: Eq => self . evaluate_expr_eq ( lhs_ty, rhs_ty) ,
661
725
ast:: CmpOp :: NotEq => self . evaluate_expr_ne ( lhs_ty, rhs_ty) ,
662
726
ast:: CmpOp :: In => self . evaluate_expr_in ( lhs_ty, rhs_ty) ,
663
- ast:: CmpOp :: NotIn => self
664
- . evaluate_expr_in ( lhs_ty, rhs_ty)
665
- . map ( |ty| ty. negate ( self . db ) ) ,
727
+ ast:: CmpOp :: NotIn => self . evaluate_expr_not_in ( lhs_ty, rhs_ty) ,
666
728
_ => None ,
667
729
}
668
730
}
0 commit comments