Skip to content

Support RGB[A] arrays in plot.imshow() #1796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions doc/gallery/plot_rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@
da.coords['lon'] = (('y', 'x'), lon)
da.coords['lat'] = (('y', 'x'), lat)

# Compute a greyscale out of the rgb image
greyscale = da.mean(dim='band')

# Plot on a map
ax = plt.subplot(projection=ccrs.PlateCarree())
greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(),
cmap='Greys_r', add_colorbar=False)
da.plot.imshow(ax=ax, x='lon', y='lat', rgb='band',
transform=ccrs.PlateCarree())
ax.coastlines('10m', color='r')
plt.show()

Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Enhancements
By `Joe Hamman <https://github.com/jhamman>`_.
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
By `Matti Eskelinen <https://github.com/maaleske>`_.
Expand Down
5 changes: 3 additions & 2 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs):
func_kwargs.update({'add_colorbar': False, 'add_labels': False})

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y)
x, y = _infer_xy_labels(
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y,
imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None))

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
Expand Down
47 changes: 43 additions & 4 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
# Decide on a default for the colorbar before facetgrids
if add_colorbar is None:
add_colorbar = plotfunc.__name__ != 'contour'
imshow_rgb = (
plotfunc.__name__ == 'imshow' and
darray.ndim == (3 + (row is not None) + (col is not None)))
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False

# Handle facetgrids first
if row or col:
allargs = locals().copy()
allargs.pop('imshow_rgb')
allargs.update(allargs.pop('kwargs'))

# Need the decorated plotting function
Expand All @@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
"Use colors keyword instead.",
DeprecationWarning, stacklevel=3)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add rgb as an actual argument inline, after x and y?

Copy link
Contributor Author

@Zac-HD Zac-HD Dec 24, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could, but this would be a breaking change for anyone using positional arguments. I'd therefore prefer not to, as it's just as easy to do the breaking bit later.

Please confirm if you'd like the breaking change; otherwise I'll leave it as-is.

xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y)
rgb = kwargs.pop('rgb', None)
xlab, ylab = _infer_xy_labels(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)

if rgb is not None and plotfunc.__name__ != 'imshow':
raise ValueError('The "rgb" keyword is only valid for imshow()')
elif rgb is not None and not imshow_rgb:
raise ValueError('The "rgb" keyword is only valid for imshow()'
'with a three-dimensional array (per facet)')

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values
zval = darray.to_masked_array(copy=False)

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
Expand All @@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if darray[xlab].dims[-1] == darray.dims[0]:
zval = zval.T
if imshow_rgb:
# For RGB[A] images, matplotlib requires the color dimension
# to be last. In Xarray the order should be unimportant, so
# we transpose to (y, x, color) to make this work.
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose()

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)

_ensure_plottable(xval, yval)

Expand Down Expand Up @@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs):

Wraps :func:`matplotlib:matplotlib.pyplot.imshow`

While other plot methods require the DataArray to be strictly
two-dimensional, ``imshow`` also accepts a 3D array where some
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.

.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down Expand Up @@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs):
# Allow user to override these defaults
defaults.update(kwargs)

if z.ndim == 3:
# matplotlib imshow uses black for missing data, but Xarray makes
# missing data transparent. We therefore add an alpha channel if
# there isn't one, and set it to transparent where data is masked.
if z.shape[-1] == 3:
z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2)
z = z.copy()
z[np.any(z.mask, axis=-1), -1] = 0

primitive = ax.imshow(z, **defaults)

return primitive
Expand Down
57 changes: 55 additions & 2 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,65 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
levels=levels, norm=norm)


def _infer_xy_labels(darray, x, y):
def _infer_xy_labels_3d(darray, x, y, rgb):
"""
Determine x and y labels for showing RGB images.

Attempts to infer which dimension is RGB/RGBA by size and order of dims.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than all this complex logic with warnings in ambiguous cases, why not always treat the last and/or remaining (after explicit x/y labels) dimension as RGB? I think that solves the convenience use cases, without hard to understand/predict inference logic.

Copy link
Contributor Author

@Zac-HD Zac-HD Dec 24, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


"""
assert rgb is None or rgb != x
assert rgb is None or rgb != y
# Start by detecting and reporting invalid combinations of arguments
assert darray.ndim == 3
not_none = [a for a in (x, y, rgb) if a is not None]
if len(set(not_none)) < len(not_none):
raise ValueError(
'Dimension names must be None or unique strings, but imshow was '
'passed x=%r, y=%r, and rgb=%r.' % (x, y, rgb))
for label in not_none:
if label not in darray.dims:
raise ValueError('%r is not a dimension' % (label,))

# Then calculate rgb dimension if certain and check validity
could_be_color = [label for label in darray.dims
if darray[label].size in (3, 4) and label not in (x, y)]
if rgb is None and not could_be_color:
raise ValueError(
'A 3-dimensional array was passed to imshow(), but there is no '
'dimension that could be color. At least one dimension must be '
'of size 3 (RGB) or 4 (RGBA), and not given as x or y.')
if rgb is None and len(could_be_color) == 1:
rgb = could_be_color[0]
if rgb is not None and darray[rgb].size not in (3, 4):
raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.'
% (rgb, darray[rgb].size))

# If rgb dimension is still unknown, there must be two or three dimensions
# in could_be_color. We therefore warn, and use a heuristic to break ties.
if rgb is None:
assert len(could_be_color) in (2, 3)
rgb = could_be_color[-1]
warnings.warn(
'Several dimensions of this array could be colors. Xarray '
'will use the last possible dimension (%r) to match '
'matplotlib.pyplot.imshow. You can pass names of x, y, '
'and/or rgb dimensions to override this guess.' % rgb)
assert rgb is not None

# Finally, we pick out the red slice and delegate to the 2D version:
return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y)


def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
"""
Determine x and y labels. For use in _plot2d

darray must be a 2 dimensional data array.
darray must be a 2 dimensional data array, or 3d for imshow only.
"""
assert x is None or x != y
if imshow and darray.ndim == 3:
return _infer_xy_labels_3d(darray, x, y, rgb)

if x is None and y is None:
if darray.ndim != 2:
Expand Down
53 changes: 53 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ def test_1d_raises_valueerror(self):

def test_3d_raises_valueerror(self):
a = DataArray(easy_array((2, 3, 4)))
if self.plotfunc.__name__ == 'imshow':
pytest.skip()
with raises_regex(ValueError, r'DataArray must be 2d'):
self.plotfunc(a)

Expand Down Expand Up @@ -670,6 +672,11 @@ def test_can_plot_axis_size_one(self):
if self.plotfunc.__name__ not in ('contour', 'contourf'):
self.plotfunc(DataArray(np.ones((1, 1))))

def test_disallows_rgb_arg(self):
with pytest.raises(ValueError):
# Always invalid for most plots. Invalid for imshow with 2D data.
self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None')

def test_viridis_cmap(self):
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
self.assertEqual('viridis', cmap_name)
Expand Down Expand Up @@ -1062,6 +1069,52 @@ def test_2d_coord_names(self):
with raises_regex(ValueError, 'requires 1D coordinates'):
self.plotmethod(x='x2d', y='y2d')

def test_plot_rgb_image(self):
DataArray(
easy_array((10, 15, 3), start=0),
dims=['y', 'x', 'band'],
).plot.imshow()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to check that the colors have actually rendered correctly by introspecting deeper into the figures / axes that are generated by imshow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it's almost certainly possible - but I have absolutely no idea where to start!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay relying on matplotlib to plot colors properly. One sane option would be to mock plt.imshow to ensure it gets called properly, but we don't bother with that for our current tests. But that shouldn't get in the way of integration tests that verify that we can actually call matplotlib without any errors being raised.

self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_image_explicit(self):
DataArray(
easy_array((10, 15, 3), start=0),
dims=['y', 'x', 'band'],
).plot.imshow(y='y', x='x', rgb='band')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_faceted(self):
DataArray(
easy_array((2, 2, 10, 15, 3), start=0),
dims=['a', 'b', 'y', 'x', 'band'],
).plot.imshow(row='a', col='b')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgba_image_transposed(self):
# We can handle the color axis being in any position
DataArray(
easy_array((4, 10, 15), start=0),
dims=['band', 'y', 'x'],
).plot.imshow()

def test_warns_ambigious_dim(self):
arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band'])
with pytest.warns(UserWarning):
arr.plot.imshow()
# but doesn't warn if dimensions specified
arr.plot.imshow(rgb='band')
arr.plot.imshow(x='x', y='y')

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests for errors related to rgb.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

def test_rgb_errors_too_many_dims(self):
arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_rgb_errors_bad_dim_sizes(self):
arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')


class TestFacetGrid(PlotTestCase):

Expand Down