diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 1dac72335d2..d55d737aa1c 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1214,6 +1214,31 @@ def fillna(self, value):
         out = ops.fillna(self, value)
         return out
 
+    def ffill(self, dim, limit=None):
+        """Fill NaN values by propagating the last valid value forward.
+
+        Requires the optional bottleneck dependency.
+
+        Parameters
+        ----------
+        dim : str
+            Name of the dimension along which values are propagated.
+        limit : int, optional
+            Maximum number of consecutive NaN values to fill after each
+            valid value. By default there is no limit.
+
+        Returns
+        -------
+        DataArray
+            A filled copy; this object is left unmodified.
+
+        Raises
+        ------
+        ImportError
+            If bottleneck is not installed.
+        """
+        return ops.ffill(self, dim, limit)
+
     def combine_first(self, other):
         """Combine two DataArray objects, with union of coordinates.
 
diff --git a/xarray/core/ops.py b/xarray/core/ops.py
index 2ed3f81d185..a32eaf1b0e2 100644
--- a/xarray/core/ops.py
+++ b/xarray/core/ops.py
@@ -153,6 +153,44 @@ def fillna(data, other, join="left", dataset_join="left"):
                        dataset_fill_value=np.nan,
                        keep_attrs=True)
 
+
+def ffill(data, dim, limit=None):
+    """Forward fill NaN values along a dimension using bottleneck.
+
+    Parameters
+    ----------
+    data : DataArray
+        Object to fill. It is copied, not modified in place.
+    dim : str
+        Name of the dimension along which values are propagated.
+    limit : int, optional
+        Maximum number of consecutive NaN values to fill after each valid
+        value. By default there is no limit.
+
+    Returns
+    -------
+    DataArray
+        Copy of ``data`` with NaN values replaced by the last valid value.
+
+    Raises
+    ------
+    ImportError
+        If bottleneck is not installed.
+    """
+    # Report the missing optional dependency before doing any other work.
+    if not has_bottleneck:
+        raise ImportError('ffill requires bottleneck to be installed')
+
+    axis_num = data.get_axis_num(dim)
+
+    # bottleneck raises an error if `n=None` is passed explicitly, so only
+    # forward the keyword when a limit was actually requested.
+    push_kwargs = {} if limit is None else {'n': limit}
+
+    filled = data.copy()
+    filled.values = bn.push(filled.values, axis=axis_num, **push_kwargs)
+    return filled
+
 
 def where_method(self, cond, other=dtypes.NA):
     """Return elements from `self` or `other` depending on `cond`.
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index d39232c04c8..5c7c7cc7104 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -3146,3 +3146,27 @@ def test_raise_no_warning_for_nan_in_binary_ops():
     with pytest.warns(None) as record:
         xr.DataArray([1, 2, np.NaN]) > 0
     assert len(record) == 0
+
+
+@pytest.mark.parametrize('da', (1, 2), indirect=True)
+def test_ffill_functions(da):
+    # After forward filling along 'time' no missing values should remain.
+    result = da.ffill('time')
+    assert result.isnull().sum() == 0
+
+
+def test_ffill_limit():
+    da = DataArray(
+        [0, np.nan, np.nan, np.nan, np.nan, 3, 4, 5, np.nan, 6, 7],
+        dims='time')
+
+    # Without a limit every gap is filled completely.
+    result = da.ffill('time')
+    expected = DataArray([0, 0, 0, 0, 0, 3, 4, 5, 5, 6, 7], dims='time')
+    assert_array_equal(result, expected)
+
+    # With limit=1 only the first NaN after each valid value is filled.
+    result = da.ffill('time', limit=1)
+    expected = DataArray(
+        [0, 0, np.nan, np.nan, np.nan, 3, 4, 5, 5, 6, 7], dims='time')
+    assert_array_equal(result, expected)