diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index bcebe3ab024ba..0556ed754af82 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -186,6 +186,38 @@ representation of :class:`DataFrame` objects (:issue:`4889`). df df.to_dict(orient='tight') +.. _whatsnew_140.enhancements.shift: + +DataFrame.shift and Series.shift now accept an iterable for parameter ``'period'`` and new parameter ``'suffix'`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`DataFrame.shift` and :meth:`Series.shift` functions can take in an iterable, such as a list, for the period parameter. When an iterable is passed +to either function it returns a :class:`DataFrame` object with all of the shifted rows or columns concatenated with one another. +The function applies a shift designated by each element in the iterable. The resulting :class:`DataFrame` object's columns will retain the +names from the :class:`DataFrame` object that called shift, but postfixed with _, where name is the original +column name and num correlates to the current element of the period iterable. The function also now takes in a ``'suffix'`` parameter to add a custom suffix +to the column names instead of adding the current element of the period iterable (:issue:`44424`). + +Usage within the :class:`DataFrame` class: + +.. ipython:: python + + df = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6] + }) + shifts = [0, 1, 2] + df.shift(shifts) + +Usage within the :class:`Series` class: + +.. ipython:: python + + ser = pd.Series([1, 2, 3]) + shifts = [0, 1, 2] + + ser.shift(shifts) + .. _whatsnew_140.enhancements.other: Other enhancements diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 8c85c4e961d99..ec1467f64e617 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5326,9 +5326,38 @@ def shift( freq: Frequency | None = None, axis: Axis = 0, fill_value=lib.no_default, + suffix=None, ) -> DataFrame: axis = self._get_axis_number(axis) + # GH#44424 Handle the case of multiple shifts + if is_list_like(periods): + + new_df = DataFrame() + + from pandas.core.reshape.concat import concat + + new_df_list = [] + + for i in periods: + if not isinstance(i, int): + raise TypeError( + f"Value {i} in periods is not an integer, expected an integer" + ) + + new_df_list.append( + super() + .shift(periods=i, freq=freq, axis=axis, fill_value=fill_value) + .add_suffix(f"_{i}" if suffix is None else suffix) + ) + + new_df = concat(new_df_list, axis=1) + + if new_df.empty: + return self + + return new_df + ncols = len(self.columns) if axis == 1 and periods != 0 and fill_value is lib.no_default and ncols > 0: # We will infer fill_value to match the closest column diff --git a/pandas/core/series.py b/pandas/core/series.py index ffa31b4f66211..309e7193afe6f 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Any, Callable, + Collection, Hashable, Iterable, Literal, @@ -4938,7 +4939,18 @@ def _replace_single(self, to_replace, method: str, inplace: bool, limit): # error: Cannot determine type of 'shift' @doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] - def shift(self, periods=1, freq=None, axis=0, fill_value=None) -> Series: + def shift( + self, periods: int | Collection[int] = 1, freq=None, axis=0, fill_value=None + ) -> Series | DataFrame: + # Handle the case of multiple shifts + if is_list_like(periods): + if len(periods) == 0: + return self + + df = self.to_frame() + + return df.shift(periods, freq=freq, axis=axis, fill_value=fill_value) + return super().shift( periods=periods, freq=freq, axis=axis, fill_value=fill_value ) diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index 9cd0b8bb5b315..0cb43ffefe5a8 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -429,3 +429,23 @@ def test_shift_axis1_categorical_columns(self): columns=ci, ) tm.assert_frame_equal(result, expected) + + def test_shift_with_iterable(self): + # GH#44424 + data = {"a": [1, 2, 3], "b": [4, 5, 6]} + shifts = [0, 1, 2] + + df = DataFrame(data) + shifted = df.shift(shifts) + + expected = DataFrame( + { + "a_0": [1, 2, 3], + "b_0": [4, 5, 6], + "a_1": [np.NaN, 1.0, 2.0], + "b_1": [np.NaN, 4.0, 5.0], + "a_2": [np.NaN, np.NaN, 1.0], + "b_2": [np.NaN, np.NaN, 4.0], + } + ) + tm.assert_frame_equal(expected, shifted) diff --git a/pandas/tests/series/methods/test_shift.py b/pandas/tests/series/methods/test_shift.py index 4fb378720d89d..4ceb2a569aa9e 100644 --- a/pandas/tests/series/methods/test_shift.py +++ b/pandas/tests/series/methods/test_shift.py @@ -14,6 +14,7 @@ offsets, ) import pandas._testing as tm +from pandas.core.frame import DataFrame from pandas.tseries.offsets import BDay @@ -376,3 +377,15 @@ def test_shift_non_writable_array(self, input_data, output_data): expected = Series(output_data, dtype="float64") tm.assert_series_equal(result, expected) + + def test_shift_with_iterable(self): + # GH#44424 + ser = Series([1, 2, 3]) + shifts = [0, 1, 2] + + shifted = ser.shift(shifts) + expected = DataFrame( + {"0_0": [1, 2, 3], "0_1": [np.NaN, 1.0, 2.0], "0_2": [np.NaN, np.NaN, 1.0]} + ) + + tm.assert_frame_equal(expected, shifted)