Skip to content

Commit 4d3a1b7

Browse files
committed
Normalisation for RGB imshow
1 parent 502a988 commit 4d3a1b7

File tree

5 files changed

+42
-3
lines changed

5 files changed

+42
-3
lines changed

doc/plotting.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius:
305305
The Celsius data contain 0, so a diverging color map was used. The
306306
Kelvins do not have 0, so the default color map was used.
307307

308+
.. _robust-plotting:
309+
308310
Robust
309311
~~~~~~
310312

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Enhancements
4141
- Support for using `Zarr`_ as storage layer for xarray.
4242
By `Ryan Abernathey <https://github.com/rabernat>`_.
4343
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
44+
Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``.
4445
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
4546
- Experimental support for parsing ENVI metadata to coordinates and attributes
4647
in :py:func:`xarray.open_rasterio`.

xarray/plot/plot.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import pandas as pd
1616
from datetime import datetime
1717

18-
from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis,
19-
import_matplotlib_pyplot)
18+
from .utils import (ROBUST_PERCENTILE, _determine_cmap_params,
19+
_infer_xy_labels, get_axis, import_matplotlib_pyplot)
2020
from .facetgrid import FacetGrid
2121
from xarray.core.pycompat import basestring
2222

@@ -449,10 +449,29 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
449449
if imshow_rgb:
450450
# Don't add a colorbar when showing an image with explicit colors
451451
add_colorbar = False
452+
# Calculate vmin and vmax automatically for `robust=True`
453+
if robust:
454+
if vmin is None:
455+
vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
456+
if vmax is None:
457+
vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
458+
robust = False
459+
# Scale interval [vmin .. vmax] to [0 .. 1] and clip to bounds
460+
if vmin is not None or vmax is not None:
461+
vmin = vmin if vmin is not None else darray.min()
462+
vmax = vmax if vmax is not None else darray.max()
463+
darray = darray.astype('f4' if vmax - vmin < 2 ** 32 else 'f8')
464+
darray = ((darray - vmin) / (vmax - vmin)).astype('f4')
465+
vmin, vmax = None, None
466+
# There's a cyclic dependency via DataArray, so we can't
467+
# import xarray.ufuncs in global or outer scope.
468+
import xarray.ufuncs as xu
469+
darray = xu.minimum(xu.maximum(darray, 0), 1)
452470

453471
# Handle facetgrids first
454472
if row or col:
455473
allargs = locals().copy()
474+
allargs.pop('xu', None)
456475
allargs.pop('imshow_rgb')
457476
allargs.update(allargs.pop('kwargs'))
458477

@@ -625,6 +644,11 @@ def imshow(x, y, z, ax, **kwargs):
625644
dimension can be interpreted as RGB or RGBA color channels and
626645
allows this dimension to be specified via the kwarg ``rgb=``.
627646
647+
Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA
648+
data, by applying a single scaling factor and offset to all bands.
649+
Passing ``robust=True`` infers ``vmin`` and ``vmax``
650+
:ref:`in the usual way <robust-plotting>`.
651+
628652
.. note::
629653
This function needs uniformly spaced coordinates to
630654
properly label the axes. Call DataArray.plot() to check.

xarray/plot/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from ..core.utils import is_scalar
1212

1313

14+
ROBUST_PERCENTILE = 2.0
15+
16+
1417
def _load_default_cmap(fname='default_colormap.csv'):
1518
"""
1619
Returns viridis color map
@@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
165168
cmap_params : dict
166169
Use depends on the type of the plotting function
167170
"""
168-
ROBUST_PERCENTILE = 2.0
169171
import matplotlib as mpl
170172

171173
calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])

xarray/tests/test_plot.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,16 @@ def test_rgb_errors_bad_dim_sizes(self):
11151115
with pytest.raises(ValueError):
11161116
arr.plot.imshow(rgb='band')
11171117

1118+
def test_normalize_rgb_imshow(self):
1119+
for kwds in (
1120+
dict(vmin=-1), dict(vmax=2),
1121+
dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0),
1122+
dict(vmin=0, robust=True), dict(vmax=-1, robust=True),
1123+
):
1124+
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
1125+
arr = da.plot.imshow(**kwds).get_array()
1126+
assert 0 <= arr.min() <= arr.max() <= 1, kwds
1127+
11181128

11191129
class TestFacetGrid(PlotTestCase):
11201130

0 commit comments

Comments
 (0)