From 100c659683024fa0e589c9341230246594381a45 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Sun, 28 Jul 2024 01:18:44 +0200 Subject: [PATCH 1/8] Improve match statement union narrowing/inference --- mypy/checkpattern.py | 16 ++++++++++++++++ test-data/unit/check-python310.test | 5 ++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index a23be464b825..94d9be9cd14b 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -244,6 +244,22 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # get inner types of original type # unpack_index = None + + if isinstance(current_type, UnionType): + captures = {} + items = [] + for t in current_type.items: + typ, _, capture = self.accept(o, t) + if not isinstance(typ, UninhabitedType): + items.append(typ) + captures.update(capture) + if len(items) == 0: + typ = UninhabitedType() + elif len(items) == 1: + typ = items[0] + else: + typ = UnionType(items=items) + return PatternType(type=typ, rest_type=current_type, captures=captures) if isinstance(current_type, TupleType): inner_types = current_type.items unpack_index = find_unpack_in_list(inner_types) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 5ecc69dc7c32..c4e777a79a52 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1777,12 +1777,11 @@ match foo: [case testMatchUnionTwoTuplesNoCrash] var: tuple[int, int] | tuple[str, str] -# TODO: we can infer better here. match var: case (42, a): - reveal_type(a) # N: Revealed type is "Union[builtins.int, builtins.str]" + reveal_type(a) # N: Revealed type is "builtins.int" case ("yes", b): - reveal_type(b) # N: Revealed type is "Union[builtins.int, builtins.str]" + reveal_type(b) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] [case testMatchNamedAndKeywordsAreTheSame] From c62587aef2b0362609e31919e1baed2f76372db6 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Sun, 28 Jul 2024 01:40:59 +0200 Subject: [PATCH 2/8] Fix lint errors: Use get_proper_type and rename captures to union_captures --- mypy/checkpattern.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 94d9be9cd14b..a2e116fd096a 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -246,20 +246,20 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: unpack_index = None if isinstance(current_type, UnionType): - captures = {} - items = [] + union_captures: dict[Expression, Type] = {} + union_items: list[Type] = [] for t in current_type.items: typ, _, capture = self.accept(o, t) - if not isinstance(typ, UninhabitedType): - items.append(typ) - captures.update(capture) - if len(items) == 0: + if not isinstance(get_proper_type(typ), UninhabitedType): + union_items.append(typ) + union_captures.update(capture) + if len(union_items) == 0: typ = UninhabitedType() - elif len(items) == 1: - typ = items[0] + elif len(union_items) == 1: + typ = union_items[0] else: - typ = UnionType(items=items) - return PatternType(type=typ, rest_type=current_type, captures=captures) + typ = UnionType(items=union_items) + return PatternType(type=typ, rest_type=current_type, captures=union_captures) if isinstance(current_type, TupleType): inner_types = current_type.items unpack_index = find_unpack_in_list(inner_types) From e418d98d2edfc5858b0503106292597829965555 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Thu, 8 Aug 2024 18:08:04 +0200 Subject: [PATCH 3/8] Save a function call by storing length of union items Co-authored-by: Hashem --- mypy/checkpattern.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index a2e116fd096a..cd3be0bdd9f4 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -253,9 +253,10 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if not isinstance(get_proper_type(typ), UninhabitedType): union_items.append(typ) union_captures.update(capture) - if len(union_items) == 0: + num_items = len(union_items) + if num_items == 0: typ = UninhabitedType() - elif len(union_items) == 1: + elif num_items == 1: typ = union_items[0] else: typ = UnionType(items=union_items) From e85e8d464ca2b805672db8ad772d3b6ed4ce6adc Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Sun, 18 Aug 2024 02:26:21 +0200 Subject: [PATCH 4/8] Refactor: Use UnionType.make_union(..) --- mypy/checkpattern.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index cd3be0bdd9f4..311dfd69c277 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -253,14 +253,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if not isinstance(get_proper_type(typ), UninhabitedType): union_items.append(typ) union_captures.update(capture) - num_items = len(union_items) - if num_items == 0: - typ = UninhabitedType() - elif num_items == 1: - typ = union_items[0] - else: - typ = UnionType(items=union_items) + typ = UnionType.make_union(items=union_items) return PatternType(type=typ, rest_type=current_type, captures=union_captures) + if isinstance(current_type, TupleType): inner_types = current_type.items unpack_index = find_unpack_in_list(inner_types) From 34585b6f7c6deb70f8ddfb2de48fd59d180dc014 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Sun, 18 Aug 2024 14:58:58 +0200 Subject: [PATCH 5/8] Use is_uninhabited(typ) instead of isinstance(typ, UninhabitedType) --- mypy/checkpattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 311dfd69c277..45d03ad86ca2 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -250,7 +250,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: union_items: list[Type] = [] for t in current_type.items: typ, _, capture = self.accept(o, t) - if not isinstance(get_proper_type(typ), UninhabitedType): + if not is_uninhabited(get_proper_type(typ)): union_items.append(typ) union_captures.update(capture) typ = UnionType.make_union(items=union_items) From cb6fa03b443f82fdec7a4d80ae1c02bda92ecdca Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Mon, 2 Sep 2024 01:17:11 +0200 Subject: [PATCH 6/8] Narrow rest_type based on match type and current_type --- mypy/checkpattern.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 45d03ad86ca2..f9c60d771138 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -253,8 +253,15 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if not is_uninhabited(get_proper_type(typ)): union_items.append(typ) union_captures.update(capture) + + rest_items: list[Type] = [] + for item in current_type.items: + if all(used_item != item for used_item in union_items): + rest_items.append(item) + typ = UnionType.make_union(items=union_items) - return PatternType(type=typ, rest_type=current_type, captures=union_captures) + rest_type = UnionType.make_union(items=rest_items) + return PatternType(type=typ, rest_type=rest_type, captures=union_captures) if isinstance(current_type, TupleType): inner_types = current_type.items From 4d28adf498e20460b75f7f8fc4f25531c34fea8f Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Mon, 2 Sep 2024 01:49:51 +0200 Subject: [PATCH 7/8] Fix no-redef lint for rest_type --- mypy/checkpattern.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index f9c60d771138..809ca36c0765 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -259,9 +259,11 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if all(used_item != item for used_item in union_items): rest_items.append(item) - typ = UnionType.make_union(items=union_items) - rest_type = UnionType.make_union(items=rest_items) - return PatternType(type=typ, rest_type=rest_type, captures=union_captures) + return PatternType( + type=UnionType.make_union(items=union_items), + rest_type=UnionType.make_union(items=rest_items), + captures=union_captures + ) if isinstance(current_type, TupleType): inner_types = current_type.items From 92d3c91aaf2757706f9ef923650eb342d19af2fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:50:21 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checkpattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 809ca36c0765..65bbbe2c824b 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -262,7 +262,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: return PatternType( type=UnionType.make_union(items=union_items), rest_type=UnionType.make_union(items=rest_items), - captures=union_captures + captures=union_captures, ) if isinstance(current_type, TupleType):