Skip to content

Commit 2c92f69

Browse files
authored
keep attrs in interpolate_na (#3970)
* interp_na now preserves attrs * added test * Added keep_attrs kwarg to interpolate_na * updated what's new * changed default to true
1 parent cb3326e commit 2c92f69

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ Bug fixes
7777
By `Tom Nicholas <https://github.com/TomNicholas>`_.
7878
- Fix ``RasterioDeprecationWarning`` when using a ``vrt`` in ``open_rasterio``. (:issue:`3964`)
7979
By `Taher Chegini <https://github.com/cheginit>`_.
80+
- Fix bug causing :py:meth:`DataArray.interpolate_na` to always drop attributes,
81+
and added `keep_attrs` argument. (:issue:`3968`)
82+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
83+
8084

8185
Documentation
8286
~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,6 +2098,7 @@ def interpolate_na(
20982098
max_gap: Union[
20992099
int, float, str, pd.Timedelta, np.timedelta64, datetime.timedelta
21002100
] = None,
2101+
keep_attrs: bool = None,
21012102
**kwargs: Any,
21022103
) -> "DataArray":
21032104
"""Fill in NaNs by interpolating according to different methods.
@@ -2152,6 +2153,10 @@ def interpolate_na(
21522153
* x (x) int64 0 1 2 3 4 5 6 7 8
21532154
21542155
The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively
2156+
keep_attrs : bool, default True
2157+
If True, the dataarray's attributes (`attrs`) will be copied from
2158+
the original object to the new one. If False, the new
2159+
object will be returned without attributes.
21552160
kwargs : dict, optional
21562161
parameters passed verbatim to the underlying interpolation function
21572162
@@ -2174,6 +2179,7 @@ def interpolate_na(
21742179
limit=limit,
21752180
use_coordinate=use_coordinate,
21762181
max_gap=max_gap,
2182+
keep_attrs=keep_attrs,
21772183
**kwargs,
21782184
)
21792185

xarray/core/missing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .common import _contains_datetime_like_objects, ones_like
1212
from .computation import apply_ufunc
1313
from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric
14+
from .options import _get_keep_attrs
1415
from .utils import OrderedSet, is_scalar
1516
from .variable import Variable, broadcast_variables
1617

@@ -294,6 +295,7 @@ def interp_na(
294295
method: str = "linear",
295296
limit: int = None,
296297
max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None,
298+
keep_attrs: bool = None,
297299
**kwargs,
298300
):
299301
"""Interpolate values according to different methods.
@@ -330,19 +332,22 @@ def interp_na(
330332
interp_class, kwargs = _get_interpolator(method, **kwargs)
331333
interpolator = partial(func_interpolate_na, interp_class, **kwargs)
332334

335+
if keep_attrs is None:
336+
keep_attrs = _get_keep_attrs(default=True)
337+
333338
with warnings.catch_warnings():
334339
warnings.filterwarnings("ignore", "overflow", RuntimeWarning)
335340
warnings.filterwarnings("ignore", "invalid value", RuntimeWarning)
336341
arr = apply_ufunc(
337342
interpolator,
338-
index,
339343
self,
344+
index,
340345
input_core_dims=[[dim], [dim]],
341346
output_core_dims=[[dim]],
342347
output_dtypes=[self.dtype],
343348
dask="parallelized",
344349
vectorize=True,
345-
keep_attrs=True,
350+
keep_attrs=keep_attrs,
346351
).transpose(*self.dims)
347352

348353
if limit is not None:
@@ -359,8 +364,9 @@ def interp_na(
359364
return arr
360365

361366

362-
def func_interpolate_na(interpolator, x, y, **kwargs):
367+
def func_interpolate_na(interpolator, y, x, **kwargs):
363368
"""helper function to apply interpolation along 1 dimension"""
369+
# reversed arguments are so that attrs are preserved from da, not index
364370
# it would be nice if this wasn't necessary, works around:
365371
# "ValueError: assignment destination is read-only" in assignment below
366372
out = y.copy()

xarray/tests/test_missing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,17 @@ def test_interpolate_kwargs():
231231
assert_equal(actual, expected)
232232

233233

234+
def test_interpolate_keep_attrs():
235+
vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)
236+
mvals = vals.copy()
237+
mvals[2] = np.nan
238+
missing = xr.DataArray(mvals, dims="x")
239+
missing.attrs = {"test": "value"}
240+
241+
actual = missing.interpolate_na(dim="x", keep_attrs=True)
242+
assert actual.attrs == {"test": "value"}
243+
244+
234245
def test_interpolate():
235246

236247
vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)

0 commit comments

Comments
 (0)