Skip to content

Commit 0243661

Browse files
committed
Refactor reversible operators
This pull request refactors and reworks how we handle reversible operators like __add__. Specifically, what our code was previously doing was assuming that given the expression `A() + B()`, we would always try calling `A().__add__(B())` first, followed by `B().__radd__(A())` second (if the `__radd__` method exists). Unfortunately, it seems like this model was a little too naive, which caused several mismatches/weird errors when I was working on refining how we handle overlaps and TypeVars in a subsequent PR. Specifically, what actually happens is that... 1. When doing `A() + A()`, we only ever try calling `A.__add__`, never `A.__radd__`. This is the case even if `__add__` is undefined. 2. If `B` is a subclass of `A`, and if `B` defines an `__radd__` method, and we do `A() + B()`, Python will actually try checking `B.__radd__` *first*, then `A.__add__` second. This lets a subclass effectively "refine" the desired return type. Note that if `B` only *inherits* an `__radd__` method, Python calls `A.__add__` first as usual. Basically, `B` must provide a genuine refinement over whatever `A` returns. 3. In all other cases, we call `__add__` then `__radd__` as usual. This pull request modifies both checker.py and checkexpr.py to match this behavior, and adds logic so that we check the calls in the correct order. This ended up slightly changing a few error messages in certain edge cases.
1 parent a2ffde9 commit 0243661

13 files changed

+466
-160
lines changed

mypy/checker.py

+80-50
Original file line numberDiff line numberDiff line change
@@ -1043,72 +1043,102 @@ def check_overlapping_op_methods(self,
10431043
"""Check for overlapping method and reverse method signatures.
10441044
10451045
Assume reverse method has valid argument count and kinds.
1046+
1047+
Precondition:
1048+
If the reverse operator method accepts some argument of type
1049+
X, the forward operator method must belong to class X.
1050+
1051+
For example, if we have the reverse operator `A.__radd__(B)`, then the
1052+
corresponding forward operator must have the type `B.__add__(...)`.
10461053
"""
10471054

1048-
# Reverse operator method that overlaps unsafely with the
1049-
# forward operator method can result in type unsafety. This is
1050-
# similar to overlapping overload variants.
1055+
# Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and
1056+
# "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping
1057+
# by using the following algorithm:
1058+
#
1059+
# 1. Rewrite "B.__OP__(C) -> R1" to "temp1(B, C) -> R1"
1060+
#
1061+
# 2. Rewrite "A.__rOP__(B) -> R2" to "temp2(B, A) -> R2"
1062+
#
1063+
# 3. Treat temp1 and temp2 as if they were both variants in the same
1064+
# overloaded function. (This mirrors how the Python runtime calls
1065+
# operator methods: we first try __OP__, then __rOP__.)
10511066
#
1052-
# This example illustrates the issue:
1067+
# If the first signature is unsafely overlapping with the second,
1068+
# report an error.
10531069
#
1054-
# class X: pass
1055-
# class A:
1056-
# def __add__(self, x: X) -> int:
1057-
# if isinstance(x, X):
1058-
# return 1
1059-
# return NotImplemented
1060-
# class B:
1061-
# def __radd__(self, x: A) -> str: return 'x'
1062-
# class C(X, B): pass
1063-
# def f(b: B) -> None:
1064-
# A() + b # Result is 1, even though static type seems to be str!
1065-
# f(C())
1070+
# 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never
1071+
# be called), do NOT report an error.
10661072
#
1067-
# The reason for the problem is that B and X are overlapping
1068-
# types, and the return types are different. Also, if the type
1069-
# of x in __radd__ would not be A, the methods could be
1070-
# non-overlapping.
1073+
# This behavior deviates from how we handle overloads -- many of the
1074+
# modules in typeshed seem to define __OP__ methods that shadow the
1075+
# corresponding __rOP__ method.
1076+
#
1077+
# Note: we do not attempt to handle unsafe overlaps related to multiple
1078+
# inheritance.
10711079

10721080
for forward_item in union_items(forward_type):
10731081
if isinstance(forward_item, CallableType):
1074-
# TODO check argument kinds
1075-
if len(forward_item.arg_types) < 1:
1076-
# Not a valid operator method -- can't succeed anyway.
1077-
return
1078-
1079-
# Construct normalized function signatures corresponding to the
1080-
# operator methods. The first argument is the left operand and the
1081-
# second operand is the right argument -- we switch the order of
1082-
# the arguments of the reverse method.
1083-
forward_tweaked = CallableType(
1084-
[forward_base, forward_item.arg_types[0]],
1085-
[nodes.ARG_POS] * 2,
1086-
[None] * 2,
1087-
forward_item.ret_type,
1088-
forward_item.fallback,
1089-
name=forward_item.name)
1090-
reverse_args = reverse_type.arg_types
1091-
reverse_tweaked = CallableType(
1092-
[reverse_args[1], reverse_args[0]],
1093-
[nodes.ARG_POS] * 2,
1094-
[None] * 2,
1095-
reverse_type.ret_type,
1096-
fallback=self.named_type('builtins.function'),
1097-
name=reverse_type.name)
1098-
1099-
if is_unsafe_overlapping_operator_signatures(
1100-
forward_tweaked, reverse_tweaked):
1082+
if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type):
11011083
self.msg.operator_method_signatures_overlap(
11021084
reverse_class, reverse_name,
11031085
forward_base, forward_name, context)
11041086
elif isinstance(forward_item, Overloaded):
11051087
for item in forward_item.items():
1106-
self.check_overlapping_op_methods(
1107-
reverse_type, reverse_name, reverse_class,
1108-
item, forward_name, forward_base, context)
1088+
if self.is_unsafe_overlapping_op(item, forward_base, reverse_type):
1089+
self.msg.operator_method_signatures_overlap(
1090+
reverse_class, reverse_name,
1091+
forward_base, forward_name,
1092+
context)
11091093
elif not isinstance(forward_item, AnyType):
11101094
self.msg.forward_operator_not_callable(forward_name, context)
11111095

1096+
def is_unsafe_overlapping_op(self,
1097+
forward_item: CallableType,
1098+
forward_base: Type,
1099+
reverse_type: CallableType) -> bool:
1100+
# TODO check argument kinds
1101+
if len(forward_item.arg_types) < 1:
1102+
# Not a valid operator method -- can't succeed anyway.
1103+
return False
1104+
1105+
# Erase the type if necessary to make sure we don't have a dangling
1106+
# TypeVar in forward_tweaked
1107+
forward_base_erased = forward_base
1108+
if isinstance(forward_base, TypeVarType):
1109+
forward_base_erased = erase_to_bound(forward_base)
1110+
1111+
# Construct normalized function signatures corresponding to the
1112+
# operator methods. The first argument is the left operand and the
1113+
# second operand is the right argument -- we switch the order of
1114+
# the arguments of the reverse method.
1115+
1116+
forward_tweaked = forward_item.copy_modified(
1117+
arg_types=[forward_base_erased, forward_item.arg_types[0]],
1118+
arg_kinds=[nodes.ARG_POS] * 2,
1119+
arg_names=[None] * 2,
1120+
)
1121+
reverse_tweaked = reverse_type.copy_modified(
1122+
arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]],
1123+
arg_kinds=[nodes.ARG_POS] * 2,
1124+
arg_names=[None] * 2,
1125+
)
1126+
1127+
reverse_base_erased = reverse_type.arg_types[0]
1128+
if isinstance(reverse_base_erased, TypeVarType):
1129+
reverse_base_erased = erase_to_bound(reverse_base_erased)
1130+
1131+
if is_same_type(reverse_base_erased, forward_base_erased):
1132+
return False
1133+
elif is_subtype(reverse_base_erased, forward_base_erased):
1134+
first = reverse_tweaked
1135+
second = forward_tweaked
1136+
else:
1137+
first = forward_tweaked
1138+
second = reverse_tweaked
1139+
1140+
return is_unsafe_overlapping_overload_signatures(first, second)
1141+
11121142
def check_inplace_operator_method(self, defn: FuncBase) -> None:
11131143
"""Check an inplace operator method such as __iadd__.
11141144

0 commit comments

Comments
 (0)