Skip to content

Commit 7f29b73

Browse files
authored
Refactor reversible operators (#5475)
This pull request refactors and reworks how we handle reversible operators like `__add__`. Previously, our code assumed that, given the expression `A() + B()`, we would always try calling `A().__add__(B())` first, followed by `B().__radd__(A())` second (if the `__radd__` method exists). Unfortunately, this model turned out to be too naive, which caused several mismatches and confusing errors when I was working on refining how we handle overlaps and TypeVars in a subsequent PR. What actually happens at runtime is the following:

1. When doing `A() + A()`, we only ever try calling `A.__add__`, never `A.__radd__`. This is the case even if `__add__` is undefined.
2. If `B` is a subclass of `A`, and if `B` defines an `__radd__` method, and we do `A() + B()`, Python will actually try checking `B.__radd__` *first*, then `A.__add__` second. This lets a subclass effectively "refine" the desired return type. Note that if `B` only *inherits* an `__radd__` method, Python calls `A.__add__` first as usual. Basically, `B` must provide a genuine refinement over whatever `A` returns.
3. In all other cases, we call `__add__` then `__radd__` as usual.

This pull request modifies both checker.py and checkexpr.py to match this behavior, and adds logic so that we check the calls in the correct order. This ended up slightly changing a few error messages in certain edge cases.
1 parent 5c73e6a commit 7f29b73

13 files changed

+623
-169
lines changed

mypy/checker.py

Lines changed: 101 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
778778
self.msg, context=fdef)
779779

780780
if name: # Special method names
781-
if defn.info and name in nodes.reverse_op_method_set:
781+
if defn.info and self.is_reverse_op_method(name):
782782
self.check_reverse_op_method(item, typ, name, defn)
783783
elif name in ('__getattr__', '__getattribute__'):
784784
self.check_getattr_method(typ, defn, name)
@@ -923,6 +923,18 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
923923

924924
self.binder = old_binder
925925

926+
def is_forward_op_method(self, method_name: str) -> bool:
927+
if self.options.python_version[0] == 2 and method_name == '__div__':
928+
return True
929+
else:
930+
return method_name in nodes.reverse_op_methods
931+
932+
def is_reverse_op_method(self, method_name: str) -> bool:
933+
if self.options.python_version[0] == 2 and method_name == '__rdiv__':
934+
return True
935+
else:
936+
return method_name in nodes.reverse_op_method_set
937+
926938
def check_for_missing_annotations(self, fdef: FuncItem) -> None:
927939
# Check for functions with unspecified/not fully specified types.
928940
def is_unannotated_any(t: Type) -> bool:
@@ -1010,7 +1022,10 @@ def check_reverse_op_method(self, defn: FuncItem,
10101022
arg_names=[reverse_type.arg_names[0], "_"])
10111023
assert len(reverse_type.arg_types) >= 2
10121024

1013-
forward_name = nodes.normal_from_reverse_op[reverse_name]
1025+
if self.options.python_version[0] == 2 and reverse_name == '__rdiv__':
1026+
forward_name = '__div__'
1027+
else:
1028+
forward_name = nodes.normal_from_reverse_op[reverse_name]
10141029
forward_inst = reverse_type.arg_types[1]
10151030
if isinstance(forward_inst, TypeVarType):
10161031
forward_inst = forward_inst.upper_bound
@@ -1042,73 +1057,105 @@ def check_overlapping_op_methods(self,
10421057
context: Context) -> None:
10431058
"""Check for overlapping method and reverse method signatures.
10441059
1045-
Assume reverse method has valid argument count and kinds.
1060+
This function assumes that:
1061+
1062+
- The reverse method has valid argument count and kinds.
1063+
- If the reverse operator method accepts some argument of type
1064+
X, the forward operator method also belongs to class X.
1065+
1066+
For example, if we have the reverse operator `A.__radd__(B)`, then the
1067+
corresponding forward operator must have the type `B.__add__(...)`.
10461068
"""
10471069

1048-
# Reverse operator method that overlaps unsafely with the
1049-
# forward operator method can result in type unsafety. This is
1050-
# similar to overlapping overload variants.
1070+
# Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and
1071+
# "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping
1072+
# by using the following algorithm:
1073+
#
1074+
# 1. Rewrite "B.__OP__(C) -> R2" to "temp1(B, C) -> R2"
1075+
#
1076+
# 2. Rewrite "A.__rOP__(B) -> R1" to "temp2(B, A) -> R1"
1077+
#
1078+
# 3. Treat temp1 and temp2 as if they were both variants in the same
1079+
# overloaded function. (This mirrors how the Python runtime calls
1080+
# operator methods: we first try __OP__, then __rOP__.)
1081+
#
1082+
# If the first signature is unsafely overlapping with the second,
1083+
# report an error.
10511084
#
1052-
# This example illustrates the issue:
1085+
# 4. However, if temp1 shadows temp2 (i.e. the __rOP__ method can never
1086+
# be called), do NOT report an error.
10531087
#
1054-
# class X: pass
1055-
# class A:
1056-
# def __add__(self, x: X) -> int:
1057-
# if isinstance(x, X):
1058-
# return 1
1059-
# return NotImplemented
1060-
# class B:
1061-
# def __radd__(self, x: A) -> str: return 'x'
1062-
# class C(X, B): pass
1063-
# def f(b: B) -> None:
1064-
# A() + b # Result is 1, even though static type seems to be str!
1065-
# f(C())
1088+
# This behavior deviates from how we handle overloads -- many of the
1089+
# modules in typeshed seem to define __OP__ methods that shadow the
1090+
# corresponding __rOP__ method.
10661091
#
1067-
# The reason for the problem is that B and X are overlapping
1068-
# types, and the return types are different. Also, if the type
1069-
# of x in __radd__ would not be A, the methods could be
1070-
# non-overlapping.
1092+
# Note: we do not attempt to handle unsafe overlaps related to multiple
1093+
# inheritance. (This is consistent with how we handle overloads: we also
1094+
# do not try checking unsafe overlaps due to multiple inheritance there.)
10711095

10721096
for forward_item in union_items(forward_type):
10731097
if isinstance(forward_item, CallableType):
1074-
# TODO check argument kinds
1075-
if len(forward_item.arg_types) < 1:
1076-
# Not a valid operator method -- can't succeed anyway.
1077-
return
1078-
1079-
# Construct normalized function signatures corresponding to the
1080-
# operator methods. The first argument is the left operand and the
1081-
# second operand is the right argument -- we switch the order of
1082-
# the arguments of the reverse method.
1083-
forward_tweaked = CallableType(
1084-
[forward_base, forward_item.arg_types[0]],
1085-
[nodes.ARG_POS] * 2,
1086-
[None] * 2,
1087-
forward_item.ret_type,
1088-
forward_item.fallback,
1089-
name=forward_item.name)
1090-
reverse_args = reverse_type.arg_types
1091-
reverse_tweaked = CallableType(
1092-
[reverse_args[1], reverse_args[0]],
1093-
[nodes.ARG_POS] * 2,
1094-
[None] * 2,
1095-
reverse_type.ret_type,
1096-
fallback=self.named_type('builtins.function'),
1097-
name=reverse_type.name)
1098-
1099-
if is_unsafe_overlapping_operator_signatures(
1100-
forward_tweaked, reverse_tweaked):
1098+
if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type):
11011099
self.msg.operator_method_signatures_overlap(
11021100
reverse_class, reverse_name,
11031101
forward_base, forward_name, context)
11041102
elif isinstance(forward_item, Overloaded):
11051103
for item in forward_item.items():
1106-
self.check_overlapping_op_methods(
1107-
reverse_type, reverse_name, reverse_class,
1108-
item, forward_name, forward_base, context)
1104+
if self.is_unsafe_overlapping_op(item, forward_base, reverse_type):
1105+
self.msg.operator_method_signatures_overlap(
1106+
reverse_class, reverse_name,
1107+
forward_base, forward_name,
1108+
context)
11091109
elif not isinstance(forward_item, AnyType):
11101110
self.msg.forward_operator_not_callable(forward_name, context)
11111111

1112+
def is_unsafe_overlapping_op(self,
1113+
forward_item: CallableType,
1114+
forward_base: Type,
1115+
reverse_type: CallableType) -> bool:
1116+
# TODO: check argument kinds?
1117+
if len(forward_item.arg_types) < 1:
1118+
# Not a valid operator method -- can't succeed anyway.
1119+
return False
1120+
1121+
# Erase the type if necessary to make sure we don't have a single
1122+
# TypeVar in forward_tweaked. (Having a function signature containing
1123+
# just a single TypeVar can lead to unpredictable behavior.)
1124+
forward_base_erased = forward_base
1125+
if isinstance(forward_base, TypeVarType):
1126+
forward_base_erased = erase_to_bound(forward_base)
1127+
1128+
# Construct normalized function signatures corresponding to the
1129+
# operator methods. The first argument is the left operand and the
1130+
# second argument is the right operand -- we switch the order of
1131+
# the arguments of the reverse method.
1132+
1133+
forward_tweaked = forward_item.copy_modified(
1134+
arg_types=[forward_base_erased, forward_item.arg_types[0]],
1135+
arg_kinds=[nodes.ARG_POS] * 2,
1136+
arg_names=[None] * 2,
1137+
)
1138+
reverse_tweaked = reverse_type.copy_modified(
1139+
arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]],
1140+
arg_kinds=[nodes.ARG_POS] * 2,
1141+
arg_names=[None] * 2,
1142+
)
1143+
1144+
reverse_base_erased = reverse_type.arg_types[0]
1145+
if isinstance(reverse_base_erased, TypeVarType):
1146+
reverse_base_erased = erase_to_bound(reverse_base_erased)
1147+
1148+
if is_same_type(reverse_base_erased, forward_base_erased):
1149+
return False
1150+
elif is_subtype(reverse_base_erased, forward_base_erased):
1151+
first = reverse_tweaked
1152+
second = forward_tweaked
1153+
else:
1154+
first = forward_tweaked
1155+
second = reverse_tweaked
1156+
1157+
return is_unsafe_overlapping_overload_signatures(first, second)
1158+
11121159
def check_inplace_operator_method(self, defn: FuncBase) -> None:
11131160
"""Check an inplace operator method such as __iadd__.
11141161
@@ -1312,7 +1359,7 @@ def check_override(self, override: FunctionLike, original: FunctionLike,
13121359
fail = True
13131360
elif (not isinstance(original, Overloaded) and
13141361
isinstance(override, Overloaded) and
1315-
name in nodes.reverse_op_methods.keys()):
1362+
self.is_forward_op_method(name)):
13161363
# Operator method overrides cannot introduce overloading, as
13171364
# this could be unsafe with reverse operator methods.
13181365
fail = True

0 commit comments

Comments
 (0)