From 00625c66195972da0bb8a7f84d09204cf93c3aaa Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Wed, 11 Mar 2020 21:51:39 +0000 Subject: [PATCH] Backport PR #32633: TYP: Remove _ensure_type --- doc/source/whatsnew/v1.0.2.rst | 1 + pandas/core/base.py | 9 --------- pandas/core/frame.py | 20 +++++++++----------- pandas/core/generic.py | 12 ++++++------ pandas/core/reshape/pivot.py | 4 +++- pandas/io/formats/format.py | 4 +--- pandas/tests/frame/indexing/test_indexing.py | 11 +++++++++++ 7 files changed, 31 insertions(+), 30 deletions(-) diff --git a/doc/source/whatsnew/v1.0.2.rst b/doc/source/whatsnew/v1.0.2.rst index 0a1619750de28..d00a418af3ddb 100644 --- a/doc/source/whatsnew/v1.0.2.rst +++ b/doc/source/whatsnew/v1.0.2.rst @@ -28,6 +28,7 @@ Fixed regressions - Fixed regression in the repr of an object-dtype :class:`Index` with bools and missing values (:issue:`32146`) - Fixed regression in :meth:`read_csv` in which the ``encoding`` option was not recognized with certain file-like objects (:issue:`31819`) - Fixed regression in :meth:`DataFrame.reindex` and :meth:`Series.reindex` when reindexing with (tz-aware) index and ``method=nearest`` (:issue:`26683`) +- Fixed regression in :meth:`DataFrame.reindex_like` on a :class:`DataFrame` subclass raised an ``AssertionError`` (:issue:`31925`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/base.py b/pandas/core/base.py index 5acf1a28bb4b6..a46b3256a9d48 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -8,7 +8,6 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import T from pandas.compat import PYPY from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -85,14 +84,6 @@ def __sizeof__(self): # object's 'sizeof' return super().__sizeof__() - def _ensure_type(self: T, obj) -> T: - """Ensure that an object has same type as self. - - Used by type checkers. - """ - assert isinstance(obj, type(self)), type(obj) - return obj - class NoNewAttributesMixin: """Mixin which prevents adding new attributes. diff --git a/pandas/core/frame.py b/pandas/core/frame.py index b680234cb0afd..819d341d2d5b4 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3853,7 +3853,7 @@ def reindex(self, *args, **kwargs) -> "DataFrame": # Pop these, since the values are in `kwargs` under different names kwargs.pop("axis", None) kwargs.pop("labels", None) - return self._ensure_type(super().reindex(**kwargs)) + return super().reindex(**kwargs) def drop( self, @@ -4174,8 +4174,8 @@ def replace( @Appender(_shared_docs["shift"] % _shared_doc_kwargs) def shift(self, periods=1, freq=None, axis=0, fill_value=None) -> "DataFrame": - return self._ensure_type( - super().shift(periods=periods, freq=freq, axis=axis, fill_value=fill_value) + return super().shift( + periods=periods, freq=freq, axis=axis, fill_value=fill_value ) def set_index( @@ -8410,14 +8410,12 @@ def isin(self, values) -> "DataFrame": from pandas.core.reshape.concat import concat values = collections.defaultdict(list, values) - return self._ensure_type( - concat( - ( - self.iloc[:, [i]].isin(values[col]) - for i, col in enumerate(self.columns) - ), - axis=1, - ) + return concat( + ( + self.iloc[:, [i]].isin(values[col]) + for i, col in enumerate(self.columns) + ), + axis=1, ) elif isinstance(values, Series): if not values.index.is_unique: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index c8a6c7c760498..3e86e986eae80 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8536,9 +8536,9 @@ def _align_frame( ) if method is not None: - left = self._ensure_type( - left.fillna(method=method, axis=fill_axis, limit=limit) - ) + _left = left.fillna(method=method, axis=fill_axis, limit=limit) + assert _left is not None # needed for mypy + left = _left right = right.fillna(method=method, axis=fill_axis, limit=limit) # if DatetimeIndex have different tz, convert to UTC @@ -10079,9 +10079,9 @@ def pct_change( if fill_method is None: data = self else: - data = self._ensure_type( - self.fillna(method=fill_method, axis=axis, limit=limit) - ) + _data = self.fillna(method=fill_method, axis=axis, limit=limit) + assert _data is not None # needed for mypy + data = _data rs = data.div(data.shift(periods=periods, freq=freq, axis=axis, **kwargs)) - 1 if freq is not None: diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index b443ba142369c..d91bf3e250f4a 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -148,7 +148,9 @@ def pivot_table( table = table.sort_index(axis=1) if fill_value is not None: - table = table._ensure_type(table.fillna(fill_value, downcast="infer")) + _table = table.fillna(fill_value, downcast="infer") + assert _table is not None # needed for mypy + table = _table if margins: if dropna: diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 6adf69a922000..9578bb9f85fda 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -281,9 +281,7 @@ def _chk_truncate(self) -> None: series = series.iloc[:max_rows] else: row_num = max_rows // 2 - series = series._ensure_type( - concat((series.iloc[:row_num], series.iloc[-row_num:])) - ) + series = concat((series.iloc[:row_num], series.iloc[-row_num:])) self.tr_row_num = row_num else: self.tr_row_num = None diff --git a/pandas/tests/frame/indexing/test_indexing.py b/pandas/tests/frame/indexing/test_indexing.py index 4418dee66c092..db0c05040a22e 100644 --- a/pandas/tests/frame/indexing/test_indexing.py +++ b/pandas/tests/frame/indexing/test_indexing.py @@ -1596,6 +1596,17 @@ def test_reindex_methods(self, method, expected_values): actual = df[::-1].reindex(target, method=switched_method) tm.assert_frame_equal(expected, actual) + def test_reindex_subclass(self): + # https://github.com/pandas-dev/pandas/issues/31925 + class MyDataFrame(DataFrame): + pass + + expected = DataFrame() + df = MyDataFrame() + result = df.reindex_like(expected) + + tm.assert_frame_equal(result, expected) + def test_reindex_methods_nearest_special(self): df = pd.DataFrame({"x": list(range(5))}) target = np.array([-0.1, 0.9, 1.1, 1.5])