From 4e1a84ec58fb0db84b6f2d6b5faa7a220def43fd Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Thu, 13 Jul 2023 22:14:51 +0200 Subject: [PATCH 01/23] init --- doc/source/whatsnew/v2.1.0.rst | 1 + pandas/core/frame.py | 22 ++++++++++++- pandas/core/generic.py | 27 ++++++++++++++-- pandas/tests/frame/methods/test_shift.py | 41 ++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index dd99d5031e724..b08f9c8be002e 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -128,6 +128,7 @@ Other enhancements - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`) - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`) +- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3a2ad225ae495..8cd1d634854e4 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5501,13 +5501,33 @@ def _replace_columnwise( @doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int = 1, + periods: int | Iterable[int] = 1, freq: Frequency | None = None, axis: Axis = 0, fill_value: Hashable = lib.no_default, + suffix: str | None = None, ) -> DataFrame: axis = self._get_axis_number(axis) + if is_list_like(periods): + if axis == 1: + raise ValueError('If `periods` contains multiple shifts, `axis` cannot be 1.') + if len(periods) == 0: + raise ValueError('If `periods` is an iterable, it cannot be empty.') + from pandas.core.reshape.concat import concat + result = [] + for period in periods: + if not isinstance(period, int): + raise TypeError( + f"Value {period} in periods must be integer, but is type {type(period)}." + ) + result.append( + super() + .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) + .add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") + ) + return concat(result, axis=1) if result else self + if freq is not None and fill_value is not lib.no_default: # GH#53832 raise ValueError( diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9084395871675..1986d5b3e5fcc 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -15,6 +15,7 @@ Any, Callable, ClassVar, + Iterable, Literal, NoReturn, cast, @@ -10521,10 +10522,11 @@ def mask( @doc(klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int = 1, + periods: int | Iterable = 1, freq=None, axis: Axis = 0, fill_value: Hashable = lib.no_default, + suffix: str | None = None, ) -> Self: """ Shift index by desired number of periods with an optional time `freq`. @@ -10538,8 +10540,13 @@ def shift( Parameters ---------- - periods : int + periods : int or Iterable Number of periods to shift. Can be positive or negative. + If an iterable of ints, the data will be shifted once by each int. + This is equivalent to shifting by one value at a time and + concatenating all resulting frames. The resulting columns will have + the shift suffixed to their column names. For multiple periods, + axis must not be 1. freq : DateOffset, tseries.offsets, timedelta, or str, optional Offset to use from the tseries module or time rule (e.g. 'EOM'). If `freq` is specified then the index values are shifted but the @@ -10556,6 +10563,9 @@ def shift( For numeric data, ``np.nan`` is used. For datetime, timedelta, or period data, etc. :attr:`NaT` is used. For extension dtypes, ``self.dtype.na_value`` is used. + suffix: str, optional + If str and periods is an iterable, this is added after the column + name and before the shift value for each shifted column name. Returns ------- @@ -10621,6 +10631,14 @@ def shift( 2020-01-06 15 18 22 2020-01-07 30 33 37 2020-01-08 45 48 52 + + >>> df['Col1'].shift(periods=[1, 2]) + Col1_0 Col1_1 Col1_2 + 2020-01-01 10 NaN NaN + 2020-01-02 20 10.0 NaN + 2020-01-03 15 20.0 10.0 + 2020-01-04 30 15.0 20.0 + 2020-01-05 45 30.0 15.0 """ axis = self._get_axis_number(axis) @@ -10634,6 +10652,11 @@ def shift( if periods == 0: return self.copy(deep=None) + if is_list_like(periods) and len(self.shape) == 1: + return self.to_frame().shift( + periods=periods, freq=freq, axis=axis, fill_value=fill_value + ) + if freq is None: # when freq is None, data is shifted, index is not axis = self._get_axis_number(axis) diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index ebbb7ca13646f..899f435f93906 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -656,3 +656,44 @@ def test_shift_axis1_many_periods(self): shifted2 = df.shift(-6, axis=1, fill_value=None) tm.assert_frame_equal(shifted2, expected) + + def test_shift_with_iterable(self): + # GH#44424 + data = {"a": [1, 2, 3], "b": [4, 5, 6]} + shifts = [0, 1, 2] + + df = DataFrame(data) + shifted = df.shift(shifts) + + expected = DataFrame( + { + "a_0": [1, 2, 3], + "b_0": [4, 5, 6], + "a_1": [np.NaN, 1.0, 2.0], + "b_1": [np.NaN, 4.0, 5.0], + "a_2": [np.NaN, np.NaN, 1.0], + "b_2": [np.NaN, np.NaN, 4.0], + } + ) + tm.assert_frame_equal(expected, shifted) + + # test pd.Series + s: pd.Series = df['a'] + tm.assert_frame_equal(s.shift(shifts), df[['a']].shift(shifts)) + + # test suffix + columns = df[['a']].shift(shifts, suffix='_suffix').columns + assert columns.tolist() == ['a_suffix_0', 'a_suffix_1', 'a_suffix_2'] + + # check bad inputs when doing multiple shifts + msg = "If `periods` contains multiple shifts, `axis` cannot be 1." + with pytest.raises(ValueError, match=msg): + df.shift([1, 2], axis=1) + + msg = f"Value s in periods must be integer, but is type ." + with pytest.raises(TypeError, match=msg): + df.shift(['s']) + + msg = f"If `periods` is an iterable, it cannot be empty." + with pytest.raises(ValueError, match=msg): + df.shift([]) From db9cd03c1d0a4c804e3e0b38e799f4e31d499fcd Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Thu, 13 Jul 2023 22:27:29 +0200 Subject: [PATCH 02/23] precommit --- doc/source/whatsnew/v2.1.0.rst | 2 +- pandas/core/frame.py | 9 ++++++--- pandas/core/generic.py | 4 ++-- pandas/tests/frame/methods/test_shift.py | 16 ++++++++-------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index b08f9c8be002e..8ea8bdc867925 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -123,12 +123,12 @@ Other enhancements - Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`) - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`) - Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`). +- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`) - Groupby aggregations (such as :meth:`DataFrameGroupby.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`) - Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`) - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`) - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`) -- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 8cd1d634854e4..e6769114e3acc 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5511,15 +5511,18 @@ def shift( if is_list_like(periods): if axis == 1: - raise ValueError('If `periods` contains multiple shifts, `axis` cannot be 1.') + raise ValueError( + "If `periods` contains multiple shifts, `axis` cannot be 1." + ) if len(periods) == 0: - raise ValueError('If `periods` is an iterable, it cannot be empty.') + raise ValueError("If `periods` is an iterable, it cannot be empty.") from pandas.core.reshape.concat import concat + result = [] for period in periods: if not isinstance(period, int): raise TypeError( - f"Value {period} in periods must be integer, but is type {type(period)}." + f"Periods must be integer, but {period} is {type(period)}." ) result.append( super() diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 1986d5b3e5fcc..b49b6fce615cf 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -15,7 +15,6 @@ Any, Callable, ClassVar, - Iterable, Literal, NoReturn, cast, @@ -198,6 +197,7 @@ if TYPE_CHECKING: from collections.abc import ( Hashable, + Iterable, Iterator, Mapping, Sequence, @@ -10655,7 +10655,7 @@ def shift( if is_list_like(periods) and len(self.shape) == 1: return self.to_frame().shift( periods=periods, freq=freq, axis=axis, fill_value=fill_value - ) + ) if freq is None: # when freq is None, data is shifted, index is not diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index 899f435f93906..b82f8b5bb87fc 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -678,22 +678,22 @@ def test_shift_with_iterable(self): tm.assert_frame_equal(expected, shifted) # test pd.Series - s: pd.Series = df['a'] - tm.assert_frame_equal(s.shift(shifts), df[['a']].shift(shifts)) + s: Series = df["a"] + tm.assert_frame_equal(s.shift(shifts), df[["a"]].shift(shifts)) # test suffix - columns = df[['a']].shift(shifts, suffix='_suffix').columns - assert columns.tolist() == ['a_suffix_0', 'a_suffix_1', 'a_suffix_2'] + columns = df[["a"]].shift(shifts, suffix="_suffix").columns + assert columns.tolist() == ["a_suffix_0", "a_suffix_1", "a_suffix_2"] # check bad inputs when doing multiple shifts msg = "If `periods` contains multiple shifts, `axis` cannot be 1." with pytest.raises(ValueError, match=msg): df.shift([1, 2], axis=1) - - msg = f"Value s in periods must be integer, but is type ." + + msg = "Periods must be integer, but s is ." with pytest.raises(TypeError, match=msg): - df.shift(['s']) + df.shift(["s"]) - msg = f"If `periods` is an iterable, it cannot be empty." + msg = "If `periods` is an iterable, it cannot be empty." with pytest.raises(ValueError, match=msg): df.shift([]) From 9def0458fe95c9286fc7b8a8ca598c4d6a8206a4 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Thu, 13 Jul 2023 22:41:30 +0200 Subject: [PATCH 03/23] slightly update test --- pandas/tests/frame/methods/test_shift.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index b82f8b5bb87fc..a8e6bc37fe00c 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -679,7 +679,8 @@ def test_shift_with_iterable(self): # test pd.Series s: Series = df["a"] - tm.assert_frame_equal(s.shift(shifts), df[["a"]].shift(shifts)) + df_one_column: DataFrame = df[["a"]] + tm.assert_frame_equal(s.shift(shifts), df_one_column.shift(shifts)) # test suffix columns = df[["a"]].shift(shifts, suffix="_suffix").columns From b7ea29719a484657fd517859206bdeec2c87685d Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Thu, 13 Jul 2023 23:09:54 +0200 Subject: [PATCH 04/23] Fix groupby API tests --- pandas/core/groupby/groupby.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 85ec8c1b86374..2eb4b01628511 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4866,6 +4866,7 @@ def shift( freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=None, + suffix: str | None = None, ): """ Shift each group by periods observations. @@ -4887,6 +4888,8 @@ def shift( fill_value : optional The scalar value to use for newly introduced missing values. + suffix : str, optional + An optional suffix to append when there are multiple periods. Returns ------- From 299927a488934e4a8bcfb914bc84fac4fb3474fa Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Fri, 14 Jul 2023 13:39:00 +0200 Subject: [PATCH 05/23] mostly types, also exclude groupby --- pandas/core/dtypes/dtypes.py | 2 +- pandas/core/frame.py | 12 ++++++++---- pandas/core/generic.py | 3 ++- pandas/core/groupby/groupby.py | 13 ++++++++----- pandas/tests/groupby/test_api.py | 4 ++++ pandas/tests/groupby/test_groupby_shift_diff.py | 9 +++++++++ 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 04e2b00744156..0f97766d07931 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -951,7 +951,7 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype): # "Dict[int, PandasExtensionDtype]", base class "PandasExtensionDtype" # defined the type as "Dict[str, PandasExtensionDtype]") [assignment] _cache_dtypes: dict[BaseOffset, PeriodDtype] = {} # type: ignore[assignment] # noqa: E501 - __hash__ = PeriodDtypeBase.__hash__ + __hash__ = PeriodDtypeBase.__hash__ # type: ignore[assignment] _freq: BaseOffset def __new__(cls, freq): diff --git a/pandas/core/frame.py b/pandas/core/frame.py index e6769114e3acc..2890d85da2e6b 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5510,6 +5510,8 @@ def shift( axis = self._get_axis_number(axis) if is_list_like(periods): + # periods is not necessarily a list, but otherwise mypy complains. + periods = cast(list, periods) if axis == 1: raise ValueError( "If `periods` contains multiple shifts, `axis` cannot be 1." @@ -5518,18 +5520,20 @@ def shift( raise ValueError("If `periods` is an iterable, it cannot be empty.") from pandas.core.reshape.concat import concat - result = [] + shifted_dataframes = [] for period in periods: - if not isinstance(period, int): + if not isinstance(int, period): raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) - result.append( + period = cast(int, period) + shifted_dataframes.append( super() .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) .add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") ) - return concat(result, axis=1) if result else self + return concat(shifted_dataframes, axis=1) if shifted_dataframes else self + periods = cast(int, periods) if freq is not None and fill_value is not lib.no_default: # GH#53832 diff --git a/pandas/core/generic.py b/pandas/core/generic.py index b49b6fce615cf..9a0be2d6770fa 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10522,7 +10522,7 @@ def mask( @doc(klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int | Iterable = 1, + periods: int | Iterable[int] = 1, freq=None, axis: Axis = 0, fill_value: Hashable = lib.no_default, @@ -10656,6 +10656,7 @@ def shift( return self.to_frame().shift( periods=periods, freq=freq, axis=axis, fill_value=fill_value ) + periods = cast(int, periods) if freq is None: # when freq is None, data is shifted, index is not diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 2eb4b01628511..04c09391e494c 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4866,7 +4866,6 @@ def shift( freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=None, - suffix: str | None = None, ): """ Shift each group by periods observations. @@ -4875,8 +4874,9 @@ def shift( Parameters ---------- - periods : int, default 1 - Number of periods to shift. + periods : int | Iterable[int], default 1 + Number of periods to shift. If a list of values, shift each group by + each period. freq : str, optional Frequency string. axis : axis to shift, default 0 @@ -4888,8 +4888,6 @@ def shift( fill_value : optional The scalar value to use for newly introduced missing values. - suffix : str, optional - An optional suffix to append when there are multiple periods. Returns ------- @@ -4938,6 +4936,11 @@ def shift( catfish NaN NaN goldfish 5.0 8.0 """ + if is_list_like(periods): + raise NotImplementedError( + "shift with multiple periods is not implemented yet for groupby." + ) + if axis is not lib.no_default: axis = self.obj._get_axis_number(axis) self._deprecate_axis(axis, "shift") diff --git a/pandas/tests/groupby/test_api.py b/pandas/tests/groupby/test_api.py index 1122403be877f..488782ad31075 100644 --- a/pandas/tests/groupby/test_api.py +++ b/pandas/tests/groupby/test_api.py @@ -192,6 +192,8 @@ def test_frame_consistency(groupby_func): exclude_expected = {"numeric_only"} elif groupby_func in ("quantile",): exclude_expected = {"method", "axis"} + elif groupby_func in ("shift",): + exclude_expected = {"suffix"} # Ensure excluded arguments are actually in the signatures assert result & exclude_result == exclude_result @@ -252,6 +254,8 @@ def test_series_consistency(request, groupby_func): exclude_expected = {"args", "kwargs"} elif groupby_func in ("quantile",): exclude_result = {"numeric_only"} + elif groupby_func in ("shift",): + exclude_expected = {"suffix"} # Ensure excluded arguments are actually in the signatures assert result & exclude_result == exclude_result diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 656471b2f6eb0..7b950e5982215 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -154,3 +154,12 @@ def test_multindex_empty_shift_with_fill(): shifted_with_fill = df.groupby(["a", "b"]).shift(1, fill_value=0) tm.assert_frame_equal(shifted, shifted_with_fill) tm.assert_index_equal(shifted.index, shifted_with_fill.index) + + +def test_group_shift_with_multiple_periods(): + df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) + with pytest.raises( + NotImplementedError, + match=r"shift with multiple periods is not implemented yet for groupby.", + ): + df.groupby("a")["b"].shift([1, 2]) From 6cf632e0e168cd287b5f80151dc9efa392595cfc Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Fri, 14 Jul 2023 14:11:48 +0200 Subject: [PATCH 06/23] fix test --- pandas/core/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 2890d85da2e6b..b3c0af3840f68 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5522,7 +5522,7 @@ def shift( shifted_dataframes = [] for period in periods: - if not isinstance(int, period): + if not isinstance(period, int): raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) From 23d9142f8f37cc0d817b02883fff3596780272e4 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Fri, 14 Jul 2023 15:30:55 +0200 Subject: [PATCH 07/23] mypy --- pandas/core/dtypes/dtypes.py | 2 +- pandas/core/frame.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 0f97766d07931..04e2b00744156 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -951,7 +951,7 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype): # "Dict[int, PandasExtensionDtype]", base class "PandasExtensionDtype" # defined the type as "Dict[str, PandasExtensionDtype]") [assignment] _cache_dtypes: dict[BaseOffset, PeriodDtype] = {} # type: ignore[assignment] # noqa: E501 - __hash__ = PeriodDtypeBase.__hash__ # type: ignore[assignment] + __hash__ = PeriodDtypeBase.__hash__ _freq: BaseOffset def __new__(cls, freq): diff --git a/pandas/core/frame.py b/pandas/core/frame.py index b3c0af3840f68..3ed5a21cebd8b 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5526,7 +5526,6 @@ def shift( raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) - period = cast(int, period) shifted_dataframes.append( super() .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) From b78d7d64e1e41cbe6cf7d3887c0525b85ee97ceb Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Fri, 14 Jul 2023 16:13:13 +0200 Subject: [PATCH 08/23] fix docstring --- pandas/core/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9a0be2d6770fa..80a28b28b4692 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10563,7 +10563,7 @@ def shift( For numeric data, ``np.nan`` is used. For datetime, timedelta, or period data, etc. :attr:`NaT` is used. For extension dtypes, ``self.dtype.na_value`` is used. - suffix: str, optional + suffix : str, optional If str and periods is an iterable, this is added after the column name and before the shift value for each shifted column name. @@ -10632,7 +10632,7 @@ def shift( 2020-01-07 30 33 37 2020-01-08 45 48 52 - >>> df['Col1'].shift(periods=[1, 2]) + >>> df['Col1'].shift(periods=[0, 1, 2]) Col1_0 Col1_1 Col1_2 2020-01-01 10 NaN NaN 2020-01-02 20 10.0 NaN From cabd28ce8b7a4c65530d880f4cf626daa7475ab4 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Sat, 15 Jul 2023 14:49:28 +0200 Subject: [PATCH 09/23] change how futurewarning is handled in the test --- pandas/core/frame.py | 1 + pandas/core/groupby/groupby.py | 32 ++++++++++++++++--- pandas/tests/groupby/test_api.py | 4 --- .../tests/groupby/test_groupby_shift_diff.py | 16 +++++++--- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3ed5a21cebd8b..c2d627fdaec16 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5526,6 +5526,7 @@ def shift( raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) + print(super()) shifted_dataframes.append( super() .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 04c09391e494c..d9268c552e292 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4866,6 +4866,7 @@ def shift( freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=None, + suffix: str | None = None, ): """ Shift each group by periods observations. @@ -4936,17 +4937,38 @@ def shift( catfish NaN NaN goldfish 5.0 8.0 """ - if is_list_like(periods): - raise NotImplementedError( - "shift with multiple periods is not implemented yet for groupby." - ) - if axis is not lib.no_default: axis = self.obj._get_axis_number(axis) self._deprecate_axis(axis, "shift") else: axis = 0 + if is_list_like(periods): + # periods is not necessarily a list, but otherwise mypy complains. + periods = cast(list, periods) + if axis == 1: + raise ValueError( + "If `periods` contains multiple shifts, `axis` cannot be 1." + ) + if len(periods) == 0: + raise ValueError("If `periods` is an iterable, it cannot be empty.") + from pandas.core.reshape.concat import concat + + shifted_dataframes = [] + for period in periods: + if not isinstance(period, int): + raise TypeError( + f"Periods must be integer, but {period} is {type(period)}." + ) + shifted_dataframes.append( + DataFrame( + self.shift( + periods=period, freq=freq, axis=axis, fill_value=fill_value + ) + ).add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") + ) + return concat(shifted_dataframes, axis=1) if shifted_dataframes else self + if freq is not None or axis != 0: f = lambda x: x.shift(periods, freq, axis, fill_value) return self._python_apply_general(f, self._selected_obj, is_transform=True) diff --git a/pandas/tests/groupby/test_api.py b/pandas/tests/groupby/test_api.py index 488782ad31075..1122403be877f 100644 --- a/pandas/tests/groupby/test_api.py +++ b/pandas/tests/groupby/test_api.py @@ -192,8 +192,6 @@ def test_frame_consistency(groupby_func): exclude_expected = {"numeric_only"} elif groupby_func in ("quantile",): exclude_expected = {"method", "axis"} - elif groupby_func in ("shift",): - exclude_expected = {"suffix"} # Ensure excluded arguments are actually in the signatures assert result & exclude_result == exclude_result @@ -254,8 +252,6 @@ def test_series_consistency(request, groupby_func): exclude_expected = {"args", "kwargs"} elif groupby_func in ("quantile",): exclude_result = {"numeric_only"} - elif groupby_func in ("shift",): - exclude_expected = {"suffix"} # Ensure excluded arguments are actually in the signatures assert result & exclude_result == exclude_result diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 7b950e5982215..39685cd2393b6 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -156,10 +156,16 @@ def test_multindex_empty_shift_with_fill(): tm.assert_index_equal(shifted.index, shifted_with_fill.index) +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_group_shift_with_multiple_periods(): df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) - with pytest.raises( - NotImplementedError, - match=r"shift with multiple periods is not implemented yet for groupby.", - ): - df.groupby("a")["b"].shift([1, 2]) + + shifted_df = df.groupby("b")[["a"]].shift([0, 1]) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 3, 2], "a_1": [np.nan, 1.0, np.nan, 3.0, 2.0]} + ) + tm.assert_frame_equal(shifted_df, expected_df) + + # series + shifted_series = df.groupby("b")["a"].shift([0, 1]) + tm.assert_frame_equal(shifted_series, expected_df) From 5017721cd8a5949c34e22b195779997826e17cba Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Sat, 15 Jul 2023 16:52:26 +0200 Subject: [PATCH 10/23] fix docstring --- pandas/core/groupby/groupby.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d9268c552e292..960f3ebd2facb 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -10,6 +10,7 @@ class providing the base-class of operations. from collections.abc import ( Hashable, + Iterable, Iterator, Mapping, Sequence, @@ -4862,7 +4863,7 @@ def cummax( @Substitution(name="groupby") def shift( self, - periods: int = 1, + periods: int | Iterable[int] = 1, freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=None, @@ -4890,6 +4891,10 @@ def shift( fill_value : optional The scalar value to use for newly introduced missing values. + suffix : str, optional + A string to add to each shifted column if there are multiple periods. + Ignored otherwise. + Returns ------- Series or DataFrame @@ -4968,6 +4973,7 @@ def shift( ).add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") ) return concat(shifted_dataframes, axis=1) if shifted_dataframes else self + periods = cast(int, periods) if freq is not None or axis != 0: f = lambda x: x.shift(periods, freq, axis, fill_value) From 6f3ec9be6d180d5c9fc87a4caff59f09f1e5d421 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Sat, 15 Jul 2023 17:37:28 +0200 Subject: [PATCH 11/23] remove debug statement --- pandas/core/frame.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index c2d627fdaec16..3ed5a21cebd8b 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5526,7 +5526,6 @@ def shift( raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) - print(super()) shifted_dataframes.append( super() .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) From 8d085f3a649181c1762c69143675a592f70f5121 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Sun, 16 Jul 2023 22:41:50 +0200 Subject: [PATCH 12/23] address comments --- pandas/core/frame.py | 14 ++-- pandas/core/groupby/groupby.py | 65 +++++++++++-------- pandas/tests/frame/methods/test_shift.py | 40 ++++++++++-- .../tests/groupby/test_groupby_shift_diff.py | 2 +- 4 files changed, 79 insertions(+), 42 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3ed5a21cebd8b..9fbdd462405f8 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5507,6 +5507,13 @@ def shift( fill_value: Hashable = lib.no_default, suffix: str | None = None, ) -> DataFrame: + if freq is not None and fill_value is not lib.no_default: + # GH#53832 + raise ValueError( + "Cannot pass both 'freq' and 'fill_value' to " + f"{type(self).__name__}.shift" + ) + axis = self._get_axis_number(axis) if is_list_like(periods): @@ -5534,13 +5541,6 @@ def shift( return concat(shifted_dataframes, axis=1) if shifted_dataframes else self periods = cast(int, periods) - if freq is not None and fill_value is not lib.no_default: - # GH#53832 - raise ValueError( - "Cannot pass both 'freq' and 'fill_value' to " - f"{type(self).__name__}.shift" - ) - ncols = len(self.columns) arrays = self._mgr.arrays if axis == 1 and periods != 0 and ncols > 0 and freq is None: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 960f3ebd2facb..8f8a8b96d44bf 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4949,8 +4949,6 @@ def shift( axis = 0 if is_list_like(periods): - # periods is not necessarily a list, but otherwise mypy complains. - periods = cast(list, periods) if axis == 1: raise ValueError( "If `periods` contains multiple shifts, `axis` cannot be 1." @@ -4959,39 +4957,50 @@ def shift( raise ValueError("If `periods` is an iterable, it cannot be empty.") from pandas.core.reshape.concat import concat - shifted_dataframes = [] - for period in periods: - if not isinstance(period, int): - raise TypeError( - f"Periods must be integer, but {period} is {type(period)}." - ) - shifted_dataframes.append( - DataFrame( - self.shift( - periods=period, freq=freq, axis=axis, fill_value=fill_value - ) - ).add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") + add_suffix = True + else: + periods = [periods] + add_suffix = False + + shifted_dataframes = [] + for period in periods: + if not isinstance(period, int): + raise TypeError( + f"Periods must be integer, but {period} is {type(period)}." + ) + if freq is not None or axis != 0: + f = lambda x: x.shift(period, freq, axis, fill_value) + shifted = self._python_apply_general( + f, self._selected_obj, is_transform=True ) - return concat(shifted_dataframes, axis=1) if shifted_dataframes else self - periods = cast(int, periods) - if freq is not None or axis != 0: - f = lambda x: x.shift(periods, freq, axis, fill_value) - return self._python_apply_general(f, self._selected_obj, is_transform=True) + else: + ids, _, ngroups = self.grouper.group_info + res_indexer = np.zeros(len(ids), dtype=np.int64) - ids, _, ngroups = self.grouper.group_info - res_indexer = np.zeros(len(ids), dtype=np.int64) + libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period) - libgroupby.group_shift_indexer(res_indexer, ids, ngroups, periods) + obj = self._obj_with_exclusions - obj = self._obj_with_exclusions + shifted = obj._reindex_with_indexers( + {self.axis: (obj.axes[self.axis], res_indexer)}, + fill_value=fill_value, + allow_dups=True, + ) - res = obj._reindex_with_indexers( - {self.axis: (obj.axes[self.axis], res_indexer)}, - fill_value=fill_value, - allow_dups=True, + if add_suffix: + if len(shifted.shape) == 1: + shifted = shifted.to_frame() + shifted = shifted.add_suffix( + f"{suffix}_{period}" if suffix else f"_{period}" + ) + shifted_dataframes.append(shifted) + + return ( + shifted_dataframes[0] + if len(shifted_dataframes) == 1 + else concat(shifted_dataframes, axis=1) ) - return res @final @Substitution(name="groupby") diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index a8e6bc37fe00c..5ff5cc05c841c 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -657,7 +657,7 @@ def test_shift_axis1_many_periods(self): shifted2 = df.shift(-6, axis=1, fill_value=None) tm.assert_frame_equal(shifted2, expected) - def test_shift_with_iterable(self): + def test_shift_with_iterable_basic_functionality(self): # GH#44424 data = {"a": [1, 2, 3], "b": [4, 5, 6]} shifts = [0, 1, 2] @@ -677,19 +677,47 @@ def test_shift_with_iterable(self): ) tm.assert_frame_equal(expected, shifted) - # test pd.Series + def test_shift_with_iterable_series(self): + data = {"a": [1, 2, 3]} + shifts = [0, 1, 2] + + df = DataFrame(data) s: Series = df["a"] - df_one_column: DataFrame = df[["a"]] - tm.assert_frame_equal(s.shift(shifts), df_one_column.shift(shifts)) + tm.assert_frame_equal(s.shift(shifts), df.shift(shifts)) + + def test_shift_with_iterable_freq_and_fill_value(self): + df = DataFrame( + np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H") + ) + + tm.assert_frame_equal( + # rename because shift with an iterable leads to str column names + df.shift([1], fill_value=1).rename(columns=lambda x: int(x[0])), + df.shift(1, fill_value=1), + ) + + tm.assert_frame_equal( + df.shift([1], freq="H").rename(columns=lambda x: int(x[0])), + df.shift(1, freq="H"), + ) + + msg = r"Cannot pass both 'freq' and 'fill_value' to.*" + with pytest.raises(ValueError, match=msg): + df.shift([1, 2], fill_value=1, freq="H") + + def test_shift_with_iterable_check_other_arguments(self): + data = {"a": [1, 2], "b": [4, 5]} + shifts = [0, 1] + df = DataFrame(data) # test suffix columns = df[["a"]].shift(shifts, suffix="_suffix").columns - assert columns.tolist() == ["a_suffix_0", "a_suffix_1", "a_suffix_2"] + assert columns.tolist() == ["a_suffix_0", "a_suffix_1"] # check bad inputs when doing multiple shifts msg = "If `periods` contains multiple shifts, `axis` cannot be 1." with pytest.raises(ValueError, match=msg): - df.shift([1, 2], axis=1) + df.shift(shifts, axis=1) msg = "Periods must be integer, but s is ." with pytest.raises(TypeError, match=msg): diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 39685cd2393b6..a4837b239073a 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -156,7 +156,7 @@ def test_multindex_empty_shift_with_fill(): tm.assert_index_equal(shifted.index, shifted_with_fill.index) -@pytest.mark.filterwarnings("ignore::FutureWarning") +@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods(): df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) From 6e66a1984536c91eeaad78d5fa65d555a8c98bdc Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Mon, 17 Jul 2023 01:11:34 +0200 Subject: [PATCH 13/23] refactor --- pandas/core/groupby/groupby.py | 7 ++- .../tests/groupby/test_groupby_shift_diff.py | 46 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8f8a8b96d44bf..ac061f2cc99d8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4949,6 +4949,7 @@ def shift( axis = 0 if is_list_like(periods): + periods = cast(list, periods) if axis == 1: raise ValueError( "If `periods` contains multiple shifts, `axis` cannot be 1." @@ -4959,6 +4960,10 @@ def shift( add_suffix = True else: + if not isinstance(periods, int): + raise TypeError( + f"Periods must be integer, but {periods} is {type(periods)}." + ) periods = [periods] add_suffix = False @@ -4994,7 +4999,7 @@ def shift( shifted = shifted.add_suffix( f"{suffix}_{period}" if suffix else f"_{period}" ) - shifted_dataframes.append(shifted) + shifted_dataframes.append(cast(Union[Series, DataFrame], shifted)) return ( shifted_dataframes[0] diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index a4837b239073a..4fef25d82083b 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -7,6 +7,7 @@ Series, Timedelta, Timestamp, + date_range, ) import pandas._testing as tm @@ -169,3 +170,48 @@ def test_group_shift_with_multiple_periods(): # series shifted_series = df.groupby("b")["a"].shift([0, 1]) tm.assert_frame_equal(shifted_series, expected_df) + + +@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") +def test_group_shift_with_multiple_periods_fill_and_freq(): + from pandas._libs import lib + + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="H"), + ) + + msg = r"Cannot pass both 'freq' and 'fill_value' to.*" + with pytest.raises(ValueError, match=msg): + df.shift([1, 2], fill_value=1, freq="H") + + shifted_df = df.groupby("b")[["a"]].shift( + [0, 1], + freq="H", + fill_value=lib.no_default, + ) + expected_df = DataFrame( + { + "a_0": [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], + "a_1": [ + np.nan, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + ], + }, + index=date_range("1/1/2000", periods=6, freq="H"), + ) + tm.assert_frame_equal(shifted_df, expected_df) + + # series + shifted_df = df.groupby("b")[["a"]].shift( + [0, 1], freq=lib.no_default, fill_value=-1 + ) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 4, 5], "a_1": [-1, 1, -1, 3, 2]}, + index=date_range("1/1/2000", periods=5, freq="H"), + ) + tm.assert_frame_equal(shifted_df, expected_df) From f7ea7a35e663431a540da161e6c17309ca7be0dc Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Mon, 17 Jul 2023 16:08:41 +0200 Subject: [PATCH 14/23] handle default --- pandas/core/frame.py | 2 +- pandas/core/generic.py | 2 +- .../tests/groupby/test_groupby_shift_diff.py | 19 ++++++++----------- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 9fbdd462405f8..51ededaedb623 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5507,7 +5507,7 @@ def shift( fill_value: Hashable = lib.no_default, suffix: str | None = None, ) -> DataFrame: - if freq is not None and fill_value is not lib.no_default: + if freq is not None and fill_value not in (lib.no_default, None): # GH#53832 raise ValueError( "Cannot pass both 'freq' and 'fill_value' to " diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 80a28b28b4692..0d4f03beae65a 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10642,7 +10642,7 @@ def shift( """ axis = self._get_axis_number(axis) - if freq is not None and fill_value is not lib.no_default: + if freq is not None and fill_value not in (None, lib.no_default): # GH#53832 raise ValueError( "Cannot pass both 'freq' and 'fill_value' to " diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 4fef25d82083b..959a39230390b 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -174,21 +174,15 @@ def test_group_shift_with_multiple_periods(): @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_fill_and_freq(): - from pandas._libs import lib - df = DataFrame( {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, index=date_range("1/1/2000", periods=5, freq="H"), ) - msg = r"Cannot pass both 'freq' and 'fill_value' to.*" - with pytest.raises(ValueError, match=msg): - df.shift([1, 2], fill_value=1, freq="H") - + # only freq shifted_df = df.groupby("b")[["a"]].shift( [0, 1], freq="H", - fill_value=lib.no_default, ) expected_df = DataFrame( { @@ -206,12 +200,15 @@ def test_group_shift_with_multiple_periods_fill_and_freq(): ) tm.assert_frame_equal(shifted_df, expected_df) - # series - shifted_df = df.groupby("b")[["a"]].shift( - [0, 1], freq=lib.no_default, fill_value=-1 - ) + # only fill + shifted_df = df.groupby("b")[["a"]].shift([0, 1], fill_value=-1) expected_df = DataFrame( {"a_0": [1, 2, 3, 4, 5], "a_1": [-1, 1, -1, 3, 2]}, index=date_range("1/1/2000", periods=5, freq="H"), ) tm.assert_frame_equal(shifted_df, expected_df) + + # both + msg = r"Cannot pass both 'freq' and 'fill_value' to.*" + with pytest.raises(ValueError, match=msg): + df.shift([1, 2], fill_value=1, freq="H") From f8e29d92076ec64937ab1640801602f85c67588f Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Mon, 17 Jul 2023 21:04:30 +0200 Subject: [PATCH 15/23] pylint --- pandas/core/groupby/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ac061f2cc99d8..a8f7ccdee426a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4974,7 +4974,9 @@ def shift( f"Periods must be integer, but {period} is {type(period)}." ) if freq is not None or axis != 0: - f = lambda x: x.shift(period, freq, axis, fill_value) + f = lambda x: x.shift( + period, freq, axis, fill_value + ) # pylint: disable=cell-var-from-loop shifted = self._python_apply_general( f, self._selected_obj, is_transform=True ) From 21e8e70b95ebe8bbdf340ebbb95eed3b4937f3ed Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Tue, 18 Jul 2023 21:18:59 +0200 Subject: [PATCH 16/23] merge conflicts and mypy --- pandas/core/groupby/groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 092b679d566ae..cb4591e0f905b 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -5006,6 +5006,7 @@ def shift( if len(periods) == 0: raise ValueError("If `periods` is an iterable, it cannot be empty.") from pandas.core.reshape.concat import concat + add_suffix = True else: if not isinstance(periods, int): @@ -5023,8 +5024,8 @@ def shift( ) if freq is not None or axis != 0: f = lambda x: x.shift( - period, freq, axis, fill_value - ) # pylint: disable=cell-var-from-loop + period, freq, axis, fill_value # pylint: disable=cell-var-from-loop + ) shifted = self._python_apply_general( f, self._selected_obj, is_transform=True ) From 0e78963cea383c31c5cffe5cb8aa7102fdf9c16e Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Tue, 18 Jul 2023 22:12:15 +0200 Subject: [PATCH 17/23] split tests, remove checking for None default --- pandas/core/frame.py | 2 +- pandas/core/generic.py | 2 +- .../tests/groupby/test_groupby_shift_diff.py | 20 +++++++++++++------ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 21fcc227add98..7453ee7694b04 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5529,7 +5529,7 @@ def shift( fill_value: Hashable = lib.no_default, suffix: str | None = None, ) -> DataFrame: - if freq is not None and fill_value not in (lib.no_default, None): + if freq is not None and fill_value is not lib.no_default: # GH#53832 raise ValueError( "Cannot pass both 'freq' and 'fill_value' to " diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f191dd7e89556..18b823f7f0ee8 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10665,7 +10665,7 @@ def shift( """ axis = self._get_axis_number(axis) - if freq is not None and fill_value not in (None, lib.no_default): + if freq is not None and fill_value is not lib.no_default: # GH#53832 raise ValueError( "Cannot pass both 'freq' and 'fill_value' to " diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index e0955851ea434..df816f81e4ec0 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -191,13 +191,11 @@ def test_group_shift_with_multiple_periods(): @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") -def test_group_shift_with_multiple_periods_fill_and_freq(): +def test_group_shift_with_multiple_periods_and_freq(): df = DataFrame( {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, index=date_range("1/1/2000", periods=5, freq="H"), ) - - # only freq shifted_df = df.groupby("b")[["a"]].shift( [0, 1], freq="H", @@ -218,15 +216,25 @@ def test_group_shift_with_multiple_periods_fill_and_freq(): ) tm.assert_frame_equal(shifted_df, expected_df) - # only fill + +@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") +def test_group_shift_with_multiple_periods_and_fill_value(): + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + ) shifted_df = df.groupby("b")[["a"]].shift([0, 1], fill_value=-1) expected_df = DataFrame( {"a_0": [1, 2, 3, 4, 5], "a_1": [-1, 1, -1, 3, 2]}, - index=date_range("1/1/2000", periods=5, freq="H"), ) tm.assert_frame_equal(shifted_df, expected_df) - # both + +@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") +def test_group_shift_with_multiple_periods_and_both_fill_and_freq_fails(): + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="H"), + ) msg = r"Cannot pass both 'freq' and 'fill_value' to.*" with pytest.raises(ValueError, match=msg): df.shift([1, 2], fill_value=1, freq="H") From d9bf54fac6d39a7ba666d73ba726d0e4da4f0cf8 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Wed, 19 Jul 2023 11:04:52 +0200 Subject: [PATCH 18/23] address comments --- pandas/core/frame.py | 8 +++++--- pandas/core/generic.py | 7 +++---- pandas/core/groupby/groupby.py | 11 ++++++----- pandas/tests/frame/methods/test_shift.py | 12 ++++++++++-- pandas/tests/groupby/test_groupby_shift_diff.py | 17 +++++++++++++++-- 5 files changed, 39 insertions(+), 16 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 7453ee7694b04..4811a30c5fbb8 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5523,7 +5523,7 @@ def _replace_columnwise( @doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int | Iterable[int] = 1, + periods: int | Sequence[int] = 1, freq: Frequency | None = None, axis: Axis = 0, fill_value: Hashable = lib.no_default, @@ -5540,7 +5540,7 @@ def shift( if is_list_like(periods): # periods is not necessarily a list, but otherwise mypy complains. - periods = cast(list, periods) + periods = cast(Sequence, periods) if axis == 1: raise ValueError( "If `periods` contains multiple shifts, `axis` cannot be 1." @@ -5560,7 +5560,9 @@ def shift( .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) .add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") ) - return concat(shifted_dataframes, axis=1) if shifted_dataframes else self + return concat(shifted_dataframes, axis=1) + elif suffix: + raise ValueError("Cannot specify `suffix` if `periods` is an int.") periods = cast(int, periods) ncols = len(self.columns) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 18b823f7f0ee8..901e1f573fa27 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -197,7 +197,6 @@ if TYPE_CHECKING: from collections.abc import ( Hashable, - Iterable, Iterator, Mapping, Sequence, @@ -10545,7 +10544,7 @@ def mask( @doc(klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int | Iterable[int] = 1, + periods: int | Sequence[int] = 1, freq=None, axis: Axis = 0, fill_value: Hashable = lib.no_default, @@ -10563,7 +10562,7 @@ def shift( Parameters ---------- - periods : int or Iterable + periods : int or Sequence Number of periods to shift. Can be positive or negative. If an iterable of ints, the data will be shifted once by each int. This is equivalent to shifting by one value at a time and @@ -10675,7 +10674,7 @@ def shift( if periods == 0: return self.copy(deep=None) - if is_list_like(periods) and len(self.shape) == 1: + if is_list_like(periods) and isinstance(self, ABCSeries): return self.to_frame().shift( periods=periods, freq=freq, axis=axis, fill_value=fill_value ) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index cb4591e0f905b..84786218a206f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -10,7 +10,6 @@ class providing the base-class of operations. from collections.abc import ( Hashable, - Iterable, Iterator, Mapping, Sequence, @@ -4909,7 +4908,7 @@ def cummax( @Substitution(name="groupby") def shift( self, - periods: int | Iterable[int] = 1, + periods: int | Sequence[int] = 1, freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=lib.no_default, @@ -4922,7 +4921,7 @@ def shift( Parameters ---------- - periods : int | Iterable[int], default 1 + periods : int | Sequence[int], default 1 Number of periods to shift. If a list of values, shift each group by each period. freq : str, optional @@ -4998,11 +4997,11 @@ def shift( axis = 0 if is_list_like(periods): - periods = cast(list, periods) if axis == 1: raise ValueError( "If `periods` contains multiple shifts, `axis` cannot be 1." ) + periods = cast(Sequence, periods) if len(periods) == 0: raise ValueError("If `periods` is an iterable, it cannot be empty.") from pandas.core.reshape.concat import concat @@ -5013,6 +5012,8 @@ def shift( raise TypeError( f"Periods must be integer, but {periods} is {type(periods)}." ) + if suffix: + raise ValueError("Cannot specify `suffix` if `periods` is an int.") periods = [periods] add_suffix = False @@ -5046,7 +5047,7 @@ def shift( ) if add_suffix: - if len(shifted.shape) == 1: + if isinstance(shifted, Series): shifted = shifted.to_frame() shifted = shifted.add_suffix( f"{suffix}_{period}" if suffix else f"_{period}" diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index 5ff5cc05c841c..de6a442c91858 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -678,6 +678,7 @@ def test_shift_with_iterable_basic_functionality(self): tm.assert_frame_equal(expected, shifted) def test_shift_with_iterable_series(self): + # GH#44424 data = {"a": [1, 2, 3]} shifts = [0, 1, 2] @@ -686,6 +687,7 @@ def test_shift_with_iterable_series(self): tm.assert_frame_equal(s.shift(shifts), df.shift(shifts)) def test_shift_with_iterable_freq_and_fill_value(self): + # GH#44424 df = DataFrame( np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H") ) @@ -706,13 +708,15 @@ def test_shift_with_iterable_freq_and_fill_value(self): df.shift([1, 2], fill_value=1, freq="H") def test_shift_with_iterable_check_other_arguments(self): + # GH#44424 data = {"a": [1, 2], "b": [4, 5]} shifts = [0, 1] df = DataFrame(data) # test suffix - columns = df[["a"]].shift(shifts, suffix="_suffix").columns - assert columns.tolist() == ["a_suffix_0", "a_suffix_1"] + shifted = df[["a"]].shift(shifts, suffix="_suffix") + expected = DataFrame({"a_suffix_0": [1, 2], "a_suffix_1": [np.nan, 1.0]}) + tm.assert_frame_equal(shifted, expected) # check bad inputs when doing multiple shifts msg = "If `periods` contains multiple shifts, `axis` cannot be 1." @@ -726,3 +730,7 @@ def test_shift_with_iterable_check_other_arguments(self): msg = "If `periods` is an iterable, it cannot be empty." with pytest.raises(ValueError, match=msg): df.shift([]) + + msg = "Cannot specify `suffix` if `periods` is an int." + with pytest.raises(ValueError, match=msg): + df.shift(1, suffix="fails") diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index df816f81e4ec0..1b559b2163b13 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -175,8 +175,18 @@ def test_shift_disallow_freq_and_fill_value(): df.groupby(df.index).shift(periods=-2, freq="D", fill_value="1") +def test_shift_disallow_suffix_if_periods_is_int(): + # GH#44424 + data = {"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]} + df = DataFrame(data) + msg = "Cannot specify `suffix` if `periods` is an int." + with pytest.raises(ValueError, match=msg): + df.groupby("b").shift(1, suffix="fails") + + @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods(): + # GH#44424 df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) shifted_df = df.groupby("b")[["a"]].shift([0, 1]) @@ -192,6 +202,7 @@ def test_group_shift_with_multiple_periods(): @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_freq(): + # GH#44424 df = DataFrame( {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, index=date_range("1/1/2000", periods=5, freq="H"), @@ -219,6 +230,7 @@ def test_group_shift_with_multiple_periods_and_freq(): @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_fill_value(): + # GH#44424 df = DataFrame( {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, ) @@ -229,12 +241,13 @@ def test_group_shift_with_multiple_periods_and_fill_value(): tm.assert_frame_equal(shifted_df, expected_df) -@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") +# @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_both_fill_and_freq_fails(): + # GH#44424 df = DataFrame( {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, index=date_range("1/1/2000", periods=5, freq="H"), ) msg = r"Cannot pass both 'freq' and 'fill_value' to.*" with pytest.raises(ValueError, match=msg): - df.shift([1, 2], fill_value=1, freq="H") + df.groupby("b")[["a"]].shift([1, 2], fill_value=1, freq="H") From 28469db411d9fc15371014b257e9528f27826474 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Wed, 19 Jul 2023 17:34:04 +0200 Subject: [PATCH 19/23] mypy --- pandas/core/groupby/groupby.py | 2 +- pandas/tests/groupby/test_groupby_shift_diff.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index fd286226529cd..ef99c611bd259 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -5065,7 +5065,7 @@ def shift( if add_suffix: if isinstance(shifted, Series): - shifted = shifted.to_frame() + shifted = cast(NDFrameT, shifted.to_frame()) shifted = shifted.add_suffix( f"{suffix}_{period}" if suffix else f"_{period}" ) diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 1b559b2163b13..2ea4ad7694dea 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -241,7 +241,7 @@ def test_group_shift_with_multiple_periods_and_fill_value(): tm.assert_frame_equal(shifted_df, expected_df) -# @pytest.mark.filterwarnings("ignore:The 'axis' keyword in") +@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_both_fill_and_freq_fails(): # GH#44424 df = DataFrame( From c8e5bad843f061ff15c9f84fd1845d3d18c5162d Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Wed, 19 Jul 2023 21:30:24 +0200 Subject: [PATCH 20/23] mypy again --- pandas/core/generic.py | 2 +- pandas/core/series.py | 2 +- pandas/tests/groupby/test_groupby_shift_diff.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9ef682bb6ad58..0c446ffdae53e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10550,7 +10550,7 @@ def shift( axis: Axis = 0, fill_value: Hashable = lib.no_default, suffix: str | None = None, - ) -> Self: + ) -> Self | DataFrame: """ Shift index by desired number of periods with an optional time `freq`. diff --git a/pandas/core/series.py b/pandas/core/series.py index b6e6d9b5d70c5..12fe0e7690bfa 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -3002,7 +3002,7 @@ def autocorr(self, lag: int = 1) -> float: >>> s.autocorr() nan """ - return self.corr(self.shift(lag)) + return self.corr(cast(Series, self.shift(lag))) def dot(self, other: AnyArrayLike) -> Series | np.ndarray: """ diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index 2ea4ad7694dea..495f3fcd359c7 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -184,7 +184,6 @@ def test_shift_disallow_suffix_if_periods_is_int(): df.groupby("b").shift(1, suffix="fails") -@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods(): # GH#44424 df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) @@ -200,7 +199,6 @@ def test_group_shift_with_multiple_periods(): tm.assert_frame_equal(shifted_series, expected_df) -@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_freq(): # GH#44424 df = DataFrame( @@ -228,7 +226,6 @@ def test_group_shift_with_multiple_periods_and_freq(): tm.assert_frame_equal(shifted_df, expected_df) -@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_fill_value(): # GH#44424 df = DataFrame( @@ -241,7 +238,6 @@ def test_group_shift_with_multiple_periods_and_fill_value(): tm.assert_frame_equal(shifted_df, expected_df) -@pytest.mark.filterwarnings("ignore:The 'axis' keyword in") def test_group_shift_with_multiple_periods_and_both_fill_and_freq_fails(): # GH#44424 df = DataFrame( From cb49cacd4bf0be3e32dd032692ee06801855ff77 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Fri, 21 Jul 2023 12:50:53 +0200 Subject: [PATCH 21/23] black --- pandas/tests/frame/methods/test_shift.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index 9581550c942bb..026550f928679 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -743,4 +743,3 @@ def test_shift_with_iterable_check_other_arguments(self): msg = "Cannot specify `suffix` if `periods` is an int." with pytest.raises(ValueError, match=msg): df.shift(1, suffix="fails") - From 44b08665a64c9e56c4dd9d7c563bc9ee1316500b Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Tue, 25 Jul 2023 13:42:58 +0200 Subject: [PATCH 22/23] address comments --- pandas/core/frame.py | 3 +-- pandas/core/groupby/groupby.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index cde10d54bd57a..60acbe801b9a5 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5566,7 +5566,6 @@ def shift( axis = self._get_axis_number(axis) if is_list_like(periods): - # periods is not necessarily a list, but otherwise mypy complains. periods = cast(Sequence, periods) if axis == 1: raise ValueError( @@ -5578,7 +5577,7 @@ def shift( shifted_dataframes = [] for period in periods: - if not isinstance(period, int): + if not is_integer(period): raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ef99c611bd259..3a41f62b02468 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -5025,7 +5025,7 @@ def shift( add_suffix = True else: - if not isinstance(periods, int): + if not is_integer(periods): raise TypeError( f"Periods must be integer, but {periods} is {type(periods)}." ) @@ -5036,7 +5036,7 @@ def shift( shifted_dataframes = [] for period in periods: - if not isinstance(period, int): + if not is_integer(period): raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) From eff4ed25472948760d623f9a5f9abac2a7dc26c0 Mon Sep 17 00:00:00 2001 From: "jona.sassenhagen" Date: Tue, 25 Jul 2023 16:56:31 +0200 Subject: [PATCH 23/23] mypy --- pandas/core/frame.py | 1 + pandas/core/groupby/groupby.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 60acbe801b9a5..0efdc85b228db 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5581,6 +5581,7 @@ def shift( raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) + period = cast(int, period) shifted_dataframes.append( super() .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3a41f62b02468..ae6be0d8463d3 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -5031,7 +5031,7 @@ def shift( ) if suffix: raise ValueError("Cannot specify `suffix` if `periods` is an int.") - periods = [periods] + periods = [cast(int, periods)] add_suffix = False shifted_dataframes = [] @@ -5040,6 +5040,7 @@ def shift( raise TypeError( f"Periods must be integer, but {period} is {type(period)}." ) + period = cast(int, period) if freq is not None or axis != 0: f = lambda x: x.shift( period, freq, axis, fill_value # pylint: disable=cell-var-from-loop