Skip to content

Commit ddc352f

Browse files
authored
Basic curvefit implementation (#4849)
1 parent 57a4479 commit ddc352f

File tree

6 files changed

+452
-1
lines changed

6 files changed

+452
-1
lines changed

doc/api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ Computation
179179
Dataset.integrate
180180
Dataset.map_blocks
181181
Dataset.polyfit
182+
Dataset.curvefit
182183

183184
**Aggregation**:
184185
:py:attr:`~Dataset.all`
@@ -375,7 +376,7 @@ Computation
375376
DataArray.integrate
376377
DataArray.polyfit
377378
DataArray.map_blocks
378-
379+
DataArray.curvefit
379380

380381
**Aggregation**:
381382
:py:attr:`~DataArray.all`

doc/user-guide/computation.rst

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,89 @@ The inverse operation is done with :py:meth:`~xarray.polyval`,
444444
.. note::
445445
These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`.
446446

447+
448+
.. _compute.curvefit:
449+
450+
Fitting arbitrary functions
451+
===========================
452+
453+
Xarray objects also provide an interface for fitting more complex functions using
454+
:py:meth:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts
455+
user-defined functions and can fit along multiple coordinates.
456+
457+
For example, we can fit a relationship between two ``DataArray`` objects, maintaining
458+
a unique fit at each spatial coordinate but aggregating over the time dimension:
459+
460+
.. ipython:: python
461+
462+
def exponential(x, a, xc):
463+
return np.exp((x - xc) / a)
464+
465+
466+
x = np.arange(-5, 5, 0.1)
467+
t = np.arange(-5, 5, 0.1)
468+
X, T = np.meshgrid(x, t)
469+
Z1 = np.random.uniform(low=-5, high=5, size=X.shape)
470+
Z2 = exponential(Z1, 3, X)
471+
Z3 = exponential(Z1, 1, -X)
472+
473+
ds = xr.Dataset(
474+
data_vars=dict(
475+
var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3)
476+
),
477+
coords={"t": t, "x": x},
478+
)
479+
ds[["var2", "var3"]].curvefit(
480+
coords=ds.var1,
481+
func=exponential,
482+
reduce_dims="t",
483+
bounds={"a": (0.5, 5), "xc": (-5, 5)},
484+
)
485+
486+
We can also fit multi-dimensional functions, and even use a wrapper function to
487+
simultaneously fit a summation of several functions, such as this field containing
488+
two gaussian peaks:
489+
490+
.. ipython:: python
491+
492+
def gaussian_2d(coords, a, xc, yc, xalpha, yalpha):
493+
x, y = coords
494+
z = a * np.exp(
495+
-np.square(x - xc) / 2 / np.square(xalpha)
496+
- np.square(y - yc) / 2 / np.square(yalpha)
497+
)
498+
return z
499+
500+
501+
def multi_peak(coords, *args):
502+
z = np.zeros(coords[0].shape)
503+
for i in range(len(args) // 5):
504+
z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5])
505+
return z
506+
507+
508+
x = np.arange(-5, 5, 0.1)
509+
y = np.arange(-5, 5, 0.1)
510+
X, Y = np.meshgrid(x, y)
511+
512+
n_peaks = 2
513+
names = ["a", "xc", "yc", "xalpha", "yalpha"]
514+
names = [f"{name}{i}" for i in range(n_peaks) for name in names]
515+
Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1)
516+
Z += np.random.normal(scale=0.1, size=Z.shape)
517+
518+
da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x})
519+
da.curvefit(
520+
coords=["x", "y"],
521+
func=multi_peak,
522+
param_names=names,
523+
kwargs={"maxfev": 10000},
524+
)
525+
526+
.. note::
527+
This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`.
528+
529+
447530
.. _compute.broadcasting:
448531

449532
Broadcasting by dimension name

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ New Features
6565
:py:meth:`~pandas.core.groupby.GroupBy.get_group`.
6666
By `Deepak Cherian <https://github.com/dcherian>`_.
6767
- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). By `Baudouin Raoult <https://github.com/b8raoult>`_.
68+
- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`)
69+
By `Sam Levang <https://github.com/slevang>`_.
6870

6971
Breaking changes
7072
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4418,6 +4418,84 @@ def query(
44184418
)
44194419
return ds[self.name]
44204420

4421+
def curvefit(
4422+
self,
4423+
coords: Union[Union[str, "DataArray"], Iterable[Union[str, "DataArray"]]],
4424+
func: Callable[..., Any],
4425+
reduce_dims: Union[Hashable, Iterable[Hashable]] = None,
4426+
skipna: bool = True,
4427+
p0: Dict[str, Any] = None,
4428+
bounds: Dict[str, Any] = None,
4429+
param_names: Sequence[str] = None,
4430+
kwargs: Dict[str, Any] = None,
4431+
):
4432+
"""
4433+
Curve fitting optimization for arbitrary functions.
4434+
4435+
Wraps `scipy.optimize.curve_fit` with `apply_ufunc`.
4436+
4437+
Parameters
4438+
----------
4439+
coords : DataArray, str or sequence of DataArray, str
4440+
Independent coordinate(s) over which to perform the curve fitting. Must share
4441+
at least one dimension with the calling object. When fitting multi-dimensional
4442+
functions, supply `coords` as a sequence in the same order as arguments in
4443+
`func`. To fit along existing dimensions of the calling object, `coords` can
4444+
also be specified as a str or sequence of strs.
4445+
func : callable
4446+
User specified function in the form `f(x, *params)` which returns a numpy
4447+
array of length `len(x)`. `params` are the fittable parameters which are optimized
4448+
by scipy curve_fit. `x` can also be specified as a sequence containing multiple
4449+
coordinates, e.g. `f((x0, x1), *params)`.
4450+
reduce_dims : str or sequence of str
4451+
Additional dimension(s) over which to aggregate while fitting. For example,
4452+
calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
4453+
aggregate all lat and lon points and fit the specified function along the
4454+
time dimension.
4455+
skipna : bool, optional
4456+
Whether to skip missing values when fitting. Default is True.
4457+
p0 : dictionary, optional
4458+
Optional dictionary of parameter names to initial guesses passed to the
4459+
`curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
4460+
be assigned initial values following the default scipy behavior.
4461+
bounds : dictionary, optional
4462+
Optional dictionary of parameter names to bounding values passed to the
4463+
`curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
4464+
will be unbounded following the default scipy behavior.
4465+
param_names: seq, optional
4466+
Sequence of names for the fittable parameters of `func`. If not supplied,
4467+
this will be automatically determined by arguments of `func`. `param_names`
4468+
should be manually supplied when fitting a function that takes a variable
4469+
number of parameters.
4470+
kwargs : dictionary
4471+
Additional keyword arguments to passed to scipy curve_fit.
4472+
4473+
Returns
4474+
-------
4475+
curvefit_results : Dataset
4476+
A single dataset which contains:
4477+
4478+
[var]_curvefit_coefficients
4479+
The coefficients of the best fit.
4480+
[var]_curvefit_covariance
4481+
The covariance matrix of the coefficient estimates.
4482+
4483+
See also
4484+
--------
4485+
DataArray.polyfit
4486+
scipy.optimize.curve_fit
4487+
"""
4488+
return self._to_temp_dataset().curvefit(
4489+
coords,
4490+
func,
4491+
reduce_dims=reduce_dims,
4492+
skipna=skipna,
4493+
p0=p0,
4494+
bounds=bounds,
4495+
param_names=param_names,
4496+
kwargs=kwargs,
4497+
)
4498+
44214499
# this needs to be at the end, or mypy will confuse with `str`
44224500
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
44234501
str = utils.UncachedAccessor(StringAccessor)

0 commit comments

Comments
 (0)