Skip to content

Commit cdec18f

Browse files
aulemahaldcherian
andauthored
Use the same function to floatize coords in polyfit and polyval (#9691)
* use same function to floatize coords in polyfit and polyval * Add whats new - fit typing - avoid warnings in tests * requires cftime * Ignore mypy issues with rcond * Apply suggestions from code review Co-authored-by: Deepak Cherian <[email protected]> --------- Co-authored-by: Deepak Cherian <[email protected]>
1 parent 0c6cded commit cdec18f

File tree

3 files changed

+34
-13
lines changed

3 files changed

+34
-13
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Bug fixes
3737

3838
- Fix inadvertent deep-copying of child data in DataTree.
3939
By `Stephan Hoyer <https://github.com/shoyer>`_.
40+
- Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`).
41+
By `Pascal Bourgault <https://github.com/aulemahal>`_.
4042

4143
Documentation
4244
~~~~~~~~~~~~~

xarray/core/dataset.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
_contains_datetime_like_objects,
6161
get_chunksizes,
6262
)
63-
from xarray.core.computation import unify_chunks
63+
from xarray.core.computation import _ensure_numeric, unify_chunks
6464
from xarray.core.coordinates import (
6565
Coordinates,
6666
DatasetCoordinates,
@@ -87,7 +87,6 @@
8787
merge_coordinates_without_align,
8888
merge_core,
8989
)
90-
from xarray.core.missing import _floatize_x
9190
from xarray.core.options import OPTIONS, _get_keep_attrs
9291
from xarray.core.types import (
9392
Bins,
@@ -9066,22 +9065,14 @@ def polyfit(
90669065
variables = {}
90679066
skipna_da = skipna
90689067

9069-
x: Any = self.coords[dim].variable
9070-
x = _floatize_x((x,), (x,))[0][0]
9071-
9072-
try:
9073-
x = x.values.astype(np.float64)
9074-
except TypeError as e:
9075-
raise TypeError(
9076-
f"Dim {dim!r} must be castable to float64, got {type(x).__name__}."
9077-
) from e
9068+
x = np.asarray(_ensure_numeric(self.coords[dim]).astype(np.float64))
90789069

90799070
xname = f"{self[dim].name}_"
90809071
order = int(deg) + 1
90819072
lhs = np.vander(x, order)
90829073

90839074
if rcond is None:
9084-
rcond = x.shape[0] * np.finfo(x.dtype).eps
9075+
rcond = x.shape[0] * np.finfo(x.dtype).eps # type: ignore[assignment]
90859076

90869077
# Weights:
90879078
if w is not None:

xarray/tests/test_dataset.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -6720,6 +6720,34 @@ def test_polyfit_warnings(self) -> None:
67206720
ds.var1.polyfit("dim2", 10, full=True)
67216721
assert len(ws) == 1
67226722

6723+
def test_polyfit_polyval(self) -> None:
6724+
da = xr.DataArray(
6725+
np.arange(1, 10).astype(np.float64), dims=["x"], coords=dict(x=np.arange(9))
6726+
)
6727+
6728+
out = da.polyfit("x", 3, full=False)
6729+
da_fitval = xr.polyval(da.x, out.polyfit_coefficients)
6730+
# polyval introduces very small errors (1e-16 here)
6731+
xr.testing.assert_allclose(da_fitval, da)
6732+
6733+
da = da.assign_coords(x=xr.date_range("2001-01-01", periods=9, freq="YS"))
6734+
out = da.polyfit("x", 3, full=False)
6735+
da_fitval = xr.polyval(da.x, out.polyfit_coefficients)
6736+
xr.testing.assert_allclose(da_fitval, da, rtol=1e-3)
6737+
6738+
@requires_cftime
6739+
def test_polyfit_polyval_cftime(self) -> None:
6740+
da = xr.DataArray(
6741+
np.arange(1, 10).astype(np.float64),
6742+
dims=["x"],
6743+
coords=dict(
6744+
x=xr.date_range("2001-01-01", periods=9, freq="YS", calendar="noleap")
6745+
),
6746+
)
6747+
out = da.polyfit("x", 3, full=False)
6748+
da_fitval = xr.polyval(da.x, out.polyfit_coefficients)
6749+
np.testing.assert_allclose(da_fitval, da)
6750+
67236751
@staticmethod
67246752
def _test_data_var_interior(
67256753
original_data_var, padded_data_var, padded_dim_name, expected_pad_values
@@ -7230,7 +7258,7 @@ def test_differentiate_datetime(dask) -> None:
72307258
assert np.allclose(actual, 1.0)
72317259

72327260

7233-
@pytest.mark.skipif(not has_cftime, reason="Test requires cftime.")
7261+
@requires_cftime
72347262
@pytest.mark.parametrize("dask", [True, False])
72357263
def test_differentiate_cftime(dask) -> None:
72367264
rs = np.random.RandomState(42)

0 commit comments

Comments
 (0)