Skip to content

Commit 85ded91

Browse files
max-sixty authored and shoyer committed
fill_value in shift (#2470)
* enable fill_value in shift
* whatsnew
* docstrings
* should we make some dataarray methods avoid rtp-ing to dataset?
* revert joining doc start
* code comments
* WIP
* pad use dict rather than kwargs
* handle 'missing' values in a more consistent way in shift
* whatsnew move
1 parent 2667deb commit 85ded91

File tree

8 files changed

+76
-39
lines changed

8 files changed

+76
-39
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ Enhancements
7373
- 0d slices of ndarrays are now obtained directly through indexing, rather than
7474
extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel
7575
Wennberg <https://github.com/danielwe>`_.
76+
- Added support for ``fill_value`` with
77+
:py:meth:`~xarray.DataArray.shift` and :py:meth:`~xarray.Dataset.shift`.
78+
By `Maximilian Roos <https://github.com/max-sixty>`_.
7679

7780
Bug fixes
7881
~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import numpy as np
77
import pandas as pd
88

9-
from . import computation, groupby, indexing, ops, resample, rolling, utils
9+
from . import (
10+
computation, dtypes, groupby, indexing, ops, resample, rolling, utils)
1011
from ..plot.plot import _PlotMethods
1112
from .accessors import DatetimeAccessor
1213
from .alignment import align, reindex_like_indexers
@@ -2085,7 +2086,7 @@ def diff(self, dim, n=1, label='upper'):
20852086
ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label)
20862087
return self._from_temp_dataset(ds)
20872088

2088-
def shift(self, shifts=None, **shifts_kwargs):
2089+
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
20892090
"""Shift this array by an offset along one or more dimensions.
20902091
20912092
Only the data is moved; coordinates stay in place. Values shifted from
@@ -2098,6 +2099,8 @@ def shift(self, shifts=None, **shifts_kwargs):
20982099
Integer offset to shift along each of the given dimensions.
20992100
Positive offsets shift to the right; negative offsets shift to the
21002101
left.
2102+
fill_value: scalar, optional
2103+
Value to use for newly missing values
21012104
**shifts_kwargs:
21022105
The keyword arguments form of ``shifts``.
21032106
One of shifts or shifts_kwarg must be provided.
@@ -2122,8 +2125,9 @@ def shift(self, shifts=None, **shifts_kwargs):
21222125
Coordinates:
21232126
* x (x) int64 0 1 2
21242127
"""
2125-
ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs)
2126-
return self._from_temp_dataset(ds)
2128+
variable = self.variable.shift(
2129+
shifts=shifts, fill_value=fill_value, **shifts_kwargs)
2130+
return self._replace(variable=variable)
21272131

21282132
def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
21292133
"""Roll this array by an offset along one or more dimensions.

xarray/core/dataset.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import xarray as xr
1414

1515
from . import (
16-
alignment, duck_array_ops, formatting, groupby, indexing, ops, pdcompat,
17-
resample, rolling, utils)
16+
alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops,
17+
pdcompat, resample, rolling, utils)
1818
from ..coding.cftimeindex import _parse_array_of_cftime_strings
1919
from .alignment import align
2020
from .common import (
@@ -3476,7 +3476,7 @@ def diff(self, dim, n=1, label='upper'):
34763476
else:
34773477
return difference
34783478

3479-
def shift(self, shifts=None, **shifts_kwargs):
3479+
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
34803480
"""Shift this dataset by an offset along one or more dimensions.
34813481
34823482
Only data variables are moved; coordinates stay in place. This is
@@ -3488,6 +3488,8 @@ def shift(self, shifts=None, **shifts_kwargs):
34883488
Integer offset to shift along each of the given dimensions.
34893489
Positive offsets shift to the right; negative offsets shift to the
34903490
left.
3491+
fill_value: scalar, optional
3492+
Value to use for newly missing values
34913493
**shifts_kwargs:
34923494
The keyword arguments form of ``shifts``.
34933495
One of shifts or shifts_kwarg must be provided.
@@ -3522,9 +3524,10 @@ def shift(self, shifts=None, **shifts_kwargs):
35223524
variables = OrderedDict()
35233525
for name, var in iteritems(self.variables):
35243526
if name in self.data_vars:
3525-
var_shifts = dict((k, v) for k, v in shifts.items()
3526-
if k in var.dims)
3527-
variables[name] = var.shift(**var_shifts)
3527+
var_shifts = {k: v for k, v in shifts.items()
3528+
if k in var.dims}
3529+
variables[name] = var.shift(
3530+
fill_value=fill_value, shifts=var_shifts)
35283531
else:
35293532
variables[name] = var
35303533

xarray/core/rolling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def wrapped_func(self, **kwargs):
301301
else:
302302
shift = (-self.window // 2) + 1
303303
valid = (slice(None), ) * axis + (slice(-shift, None), )
304-
padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)})
304+
padded = padded.pad_with_fill_value({self.dim: (0, -shift)})
305305

306306
if isinstance(padded.data, dask_array_type):
307307
values = dask_rolling_wrapper(func, padded,

xarray/core/variable.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ def squeeze(self, dim=None):
933933
dims = common.get_squeeze_dims(self, dim)
934934
return self.isel({d: 0 for d in dims})
935935

936-
def _shift_one_dim(self, dim, count):
936+
def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
937937
axis = self.get_axis_num(dim)
938938

939939
if count > 0:
@@ -944,7 +944,11 @@ def _shift_one_dim(self, dim, count):
944944
keep = slice(None)
945945

946946
trimmed_data = self[(slice(None),) * axis + (keep,)].data
947-
dtype, fill_value = dtypes.maybe_promote(self.dtype)
947+
948+
if fill_value is dtypes.NA:
949+
dtype, fill_value = dtypes.maybe_promote(self.dtype)
950+
else:
951+
dtype = self.dtype
948952

949953
shape = list(self.shape)
950954
shape[axis] = min(abs(count), shape[axis])
@@ -956,12 +960,12 @@ def _shift_one_dim(self, dim, count):
956960
else:
957961
full = np.full
958962

959-
nans = full(shape, fill_value, dtype=dtype)
963+
filler = full(shape, fill_value, dtype=dtype)
960964

961965
if count > 0:
962-
arrays = [nans, trimmed_data]
966+
arrays = [filler, trimmed_data]
963967
else:
964-
arrays = [trimmed_data, nans]
968+
arrays = [trimmed_data, filler]
965969

966970
data = duck_array_ops.concatenate(arrays, axis)
967971

@@ -973,7 +977,7 @@ def _shift_one_dim(self, dim, count):
973977

974978
return type(self)(self.dims, data, self._attrs, fastpath=True)
975979

976-
def shift(self, shifts=None, **shifts_kwargs):
980+
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
977981
"""
978982
Return a new Variable with shifted data.
979983
@@ -983,6 +987,8 @@ def shift(self, shifts=None, **shifts_kwargs):
983987
Integer offset to shift along each of the given dimensions.
984988
Positive offsets shift to the right; negative offsets shift to the
985989
left.
990+
fill_value: scalar, optional
991+
Value to use for newly missing values
986992
**shifts_kwargs:
987993
The keyword arguments form of ``shifts``.
988994
One of shifts or shifts_kwarg must be provided.
@@ -995,7 +1001,7 @@ def shift(self, shifts=None, **shifts_kwargs):
9951001
shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift')
9961002
result = self
9971003
for dim, count in shifts.items():
998-
result = result._shift_one_dim(dim, count)
1004+
result = result._shift_one_dim(dim, count, fill_value=fill_value)
9991005
return result
10001006

10011007
def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA,

xarray/tests/test_dataarray.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
DataArray, Dataset, IndexVariable, Variable, align, broadcast)
1515
from xarray.coding.times import CFDatetimeCoder, _import_cftime
1616
from xarray.convert import from_cdms2
17+
from xarray.core import dtypes
1718
from xarray.core.common import ALL_DIMS, full_like
1819
from xarray.core.pycompat import OrderedDict, iteritems
1920
from xarray.tests import (
@@ -3128,12 +3129,19 @@ def test_coordinate_diff(self):
31283129
actual = lon.diff('lon')
31293130
assert_equal(expected, actual)
31303131

3131-
@pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5])
3132-
def test_shift(self, offset):
3132+
@pytest.mark.parametrize('offset', [-5, 0, 1, 2])
3133+
@pytest.mark.parametrize('fill_value, dtype',
3134+
[(2, int), (dtypes.NA, float)])
3135+
def test_shift(self, offset, fill_value, dtype):
31333136
arr = DataArray([1, 2, 3], dims='x')
3134-
actual = arr.shift(x=1)
3135-
expected = DataArray([np.nan, 1, 2], dims='x')
3136-
assert_identical(expected, actual)
3137+
actual = arr.shift(x=1, fill_value=fill_value)
3138+
if fill_value == dtypes.NA:
3139+
# if we supply the default, we expect the missing value for a
3140+
# float array
3141+
fill_value = np.nan
3142+
expected = DataArray([fill_value, 1, 2], dims='x')
3143+
assert_identical(expected, actual)
3144+
assert actual.dtype == dtype
31373145

31383146
arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])])
31393147
expected = DataArray(arr.to_pandas().shift(offset))

xarray/tests/test_dataset.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from xarray import (
1616
ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align,
1717
backends, broadcast, open_dataset, set_options)
18-
from xarray.core import indexing, npcompat, utils
18+
from xarray.core import dtypes, indexing, npcompat, utils
1919
from xarray.core.common import full_like
2020
from xarray.core.pycompat import (
2121
OrderedDict, integer_types, iteritems, unicode_type)
@@ -3917,12 +3917,17 @@ def test_dataset_diff_exception_label_str(self):
39173917
with raises_regex(ValueError, '\'label\' argument has to'):
39183918
ds.diff('dim2', label='raise_me')
39193919

3920-
def test_shift(self):
3920+
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
3921+
def test_shift(self, fill_value):
39213922
coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]}
39223923
attrs = {'meta': 'data'}
39233924
ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs)
3924-
actual = ds.shift(x=1)
3925-
expected = Dataset({'foo': ('x', [np.nan, 1, 2])}, coords, attrs)
3925+
actual = ds.shift(x=1, fill_value=fill_value)
3926+
if fill_value == dtypes.NA:
3927+
# if we supply the default, we expect the missing value for a
3928+
# float array
3929+
fill_value = np.nan
3930+
expected = Dataset({'foo': ('x', [fill_value, 1, 2])}, coords, attrs)
39263931
assert_identical(expected, actual)
39273932

39283933
with raises_regex(ValueError, 'dimensions'):

xarray/tests/test_variable.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytz
1414

1515
from xarray import Coordinate, Dataset, IndexVariable, Variable, set_options
16-
from xarray.core import indexing
16+
from xarray.core import dtypes, indexing
1717
from xarray.core.common import full_like, ones_like, zeros_like
1818
from xarray.core.indexing import (
1919
BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter,
@@ -1179,33 +1179,41 @@ def test_indexing_0d_unicode(self):
11791179
expected = Variable((), u'tmax')
11801180
assert_identical(actual, expected)
11811181

1182-
def test_shift(self):
1182+
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
1183+
def test_shift(self, fill_value):
11831184
v = Variable('x', [1, 2, 3, 4, 5])
11841185

11851186
assert_identical(v, v.shift(x=0))
11861187
assert v is not v.shift(x=0)
11871188

1188-
expected = Variable('x', [np.nan, 1, 2, 3, 4])
1189-
assert_identical(expected, v.shift(x=1))
1190-
11911189
expected = Variable('x', [np.nan, np.nan, 1, 2, 3])
11921190
assert_identical(expected, v.shift(x=2))
11931191

1194-
expected = Variable('x', [2, 3, 4, 5, np.nan])
1195-
assert_identical(expected, v.shift(x=-1))
1192+
if fill_value == dtypes.NA:
1193+
# if we supply the default, we expect the missing value for a
1194+
# float array
1195+
fill_value_exp = np.nan
1196+
else:
1197+
fill_value_exp = fill_value
1198+
1199+
expected = Variable('x', [fill_value_exp, 1, 2, 3, 4])
1200+
assert_identical(expected, v.shift(x=1, fill_value=fill_value))
1201+
1202+
expected = Variable('x', [2, 3, 4, 5, fill_value_exp])
1203+
assert_identical(expected, v.shift(x=-1, fill_value=fill_value))
11961204

1197-
expected = Variable('x', [np.nan] * 5)
1198-
assert_identical(expected, v.shift(x=5))
1199-
assert_identical(expected, v.shift(x=6))
1205+
expected = Variable('x', [fill_value_exp] * 5)
1206+
assert_identical(expected, v.shift(x=5, fill_value=fill_value))
1207+
assert_identical(expected, v.shift(x=6, fill_value=fill_value))
12001208

12011209
with raises_regex(ValueError, 'dimension'):
12021210
v.shift(z=0)
12031211

12041212
v = Variable('x', [1, 2, 3, 4, 5], {'foo': 'bar'})
12051213
assert_identical(v, v.shift(x=0))
12061214

1207-
expected = Variable('x', [np.nan, 1, 2, 3, 4], {'foo': 'bar'})
1208-
assert_identical(expected, v.shift(x=1))
1215+
expected = Variable('x', [fill_value_exp, 1, 2, 3, 4], {'foo': 'bar'})
1216+
assert_identical(expected, v.shift(x=1, fill_value=fill_value))
12091217

12101218
def test_shift2d(self):
12111219
v = Variable(('x', 'y'), [[1, 2], [3, 4]])

0 commit comments

Comments
 (0)