-
-
Notifications
You must be signed in to change notification settings - Fork 19.5k
ENH: .equals for Extension Arrays #30652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
36c8b88
786963c
6800315
a3e7b7f
860013f
fc3d2c2
3da5726
c5027dd
375664c
b6ad2fb
365362a
aae2f94
8d052ad
9ee034e
38501e6
dccec7f
0b1255f
b8be858
4c7273f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,7 @@ class ExtensionArray: | |
| dropna | ||
| factorize | ||
| fillna | ||
| equals | ||
| isna | ||
| ravel | ||
| repeat | ||
|
|
@@ -104,6 +105,7 @@ class ExtensionArray: | |
| * _from_sequence | ||
| * _from_factorized | ||
| * __getitem__ | ||
| * __eq__ | ||
| * __len__ | ||
| * dtype | ||
| * nbytes | ||
|
|
@@ -350,6 +352,38 @@ def __iter__(self): | |
| for i in range(len(self)): | ||
| yield self[i] | ||
|
|
||
| def __eq__(self, other: ABCExtensionArray) -> bool: | ||
|
||
| """ | ||
| Whether the two arrays are equivalent. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| other: ExtensionArray | ||
| The array to compare to this array. | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| """ | ||
|
|
||
| raise AbstractMethodError(self) | ||
|
||
|
|
||
| def __ne__(self, other: ABCExtensionArray) -> bool: | ||
|
||
| """ | ||
| Whether the two arrays are not equivalent. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| other: ExtensionArray | ||
| The array to compare to this array. | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| """ | ||
|
|
||
| return ~(self == other) | ||
|
|
||
| # ------------------------------------------------------------------------ | ||
| # Required attributes | ||
| # ------------------------------------------------------------------------ | ||
|
|
@@ -657,6 +691,25 @@ def searchsorted(self, value, side="left", sorter=None): | |
| arr = self.astype(object) | ||
| return arr.searchsorted(value, side=side, sorter=sorter) | ||
|
|
||
| def equals(self, other: ABCExtensionArray) -> bool: | ||
| """ | ||
| Return whether another array is equivalent to this array. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| other: ExtensionArray | ||
| Array to compare to this Array. | ||
|
|
||
| Returns | ||
| ------- | ||
| boolean | ||
| Whether the arrays are equivalent. | ||
|
|
||
| """ | ||
| return isinstance(other, self.__class__) and ( | ||
|
||
| ((self == other) | (self.isna() == other.isna())).all() | ||
| ) | ||
|
|
||
| def _values_for_factorize(self) -> Tuple[np.ndarray, Any]: | ||
| """ | ||
| Return an array and missing value suitable for factorization. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -376,6 +376,15 @@ def __getitem__(self, item): | |
|
|
||
| return type(self)(self._data[item], self._mask[item]) | ||
|
|
||
| def __eq__(self, other): | ||
| return ( | ||
| isinstance(other, IntegerArray) | ||
| and hasattr(other, "_data") | ||
|
||
| and self._data == other._data | ||
| and hasattr(other, "_mask") | ||
| and self._mask == other._mask | ||
| ) | ||
|
|
||
| def _coerce_to_ndarray(self, dtype=None, na_value=lib._no_default): | ||
| """ | ||
| coerce to an ndarray of object dtype | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -358,3 +358,18 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy): | |
| np.repeat(data, repeats, **kwargs) | ||
| else: | ||
| data.repeat(repeats, **kwargs) | ||
|
|
||
| def test_equals(self, data, na_value): | ||
| cls = type(data) | ||
| ser = pd.Series(cls._from_sequence(data, dtype=data.dtype)) | ||
| na_ser = pd.Series(cls._from_sequence([na_value], dtype=data.dtype)) | ||
|
|
||
| assert data.equals(data) | ||
|
||
| assert ser.equals(ser) | ||
| assert na_ser.equals(na_ser) | ||
|
|
||
| assert not data.equals(na_value) | ||
| assert not na_ser.equals(ser) | ||
| assert not ser.equals(na_ser) | ||
| assert not ser.equals(0) | ||
| assert not na_ser.equals(0) | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,10 +132,8 @@ class BaseComparisonOpsTests(BaseOpsUtil): | |
| def _compare_other(self, s, data, op_name, other): | ||
| op = self.get_op_from_name(op_name) | ||
| if op_name == "__eq__": | ||
| assert getattr(data, op_name)(other) is NotImplemented | ||
TomAugspurger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert not op(s, other).all() | ||
| elif op_name == "__ne__": | ||
| assert getattr(data, op_name)(other) is NotImplemented | ||
| assert op(s, other).all() | ||
|
|
||
| else: | ||
|
|
@@ -158,13 +156,3 @@ def test_compare_array(self, data, all_compare_operators): | |
| s = pd.Series(data) | ||
| other = pd.Series([data[0]] * len(data)) | ||
| self._compare_other(s, data, op_name, other) | ||
|
|
||
| def test_direct_arith_with_series_returns_not_implemented(self, data): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Undo this change.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as the other comment. The whole PR is intended to break this test since we implement
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The reason this needs to be kept is that, although |
||
| # EAs should return NotImplemented for ops with Series. | ||
| # Pandas takes care of unboxing the series and calling the EA's op. | ||
| other = pd.Series(data) | ||
| if hasattr(data, "__eq__"): | ||
| result = data.__eq__(other) | ||
| assert result is NotImplemented | ||
| else: | ||
| raise pytest.skip(f"{type(data).__name__} does not implement __eq__") | ||
Uh oh!
There was an error while loading. Please reload this page.