@@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
443443 # Decide on a default for the colorbar before facetgrids
444444 if add_colorbar is None :
445445 add_colorbar = plotfunc .__name__ != 'contour'
446+ imshow_rgb = (
447+ plotfunc .__name__ == 'imshow' and
448+ darray .ndim == (3 + (row is not None ) + (col is not None )))
449+ if imshow_rgb :
450+ # Don't add a colorbar when showing an image with explicit colors
451+ add_colorbar = False
446452
447453 # Handle facetgrids first
448454 if row or col :
449455 allargs = locals ().copy ()
456+ allargs .pop ('imshow_rgb' )
450457 allargs .update (allargs .pop ('kwargs' ))
451458
452459 # Need the decorated plotting function
@@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
470477 "Use colors keyword instead." ,
471478 DeprecationWarning , stacklevel = 3 )
472479
473- xlab , ylab = _infer_xy_labels (darray = darray , x = x , y = y )
480+ rgb = kwargs .pop ('rgb' , None )
481+ xlab , ylab = _infer_xy_labels (
482+ darray = darray , x = x , y = y , imshow = imshow_rgb , rgb = rgb )
483+
484+ if rgb is not None and plotfunc .__name__ != 'imshow' :
485+ raise ValueError ('The "rgb" keyword is only valid for imshow()' )
486+ elif rgb is not None and not imshow_rgb :
487+ raise ValueError ('The "rgb" keyword is only valid for imshow()'
488+ 'with a three-dimensional array (per facet)' )
474489
475490 # better to pass the ndarrays directly to plotting functions
476491 xval = darray [xlab ].values
477492 yval = darray [ylab ].values
478- zval = darray .to_masked_array (copy = False )
479493
480494 # check if we need to broadcast one dimension
481495 if xval .ndim < yval .ndim :
@@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
486500
487501 # May need to transpose for correct x, y labels
488502 # xlab may be the name of a coord, we have to check for dim names
489- if darray [xlab ].dims [- 1 ] == darray .dims [0 ]:
490- zval = zval .T
503+ if imshow_rgb :
504+ # For RGB[A] images, matplotlib requires the color dimension
505+ # to be last. In Xarray the order should be unimportant, so
506+ # we transpose to (y, x, color) to make this work.
507+ yx_dims = (ylab , xlab )
508+ dims = yx_dims + tuple (d for d in darray .dims if d not in yx_dims )
509+ if dims != darray .dims :
510+ darray = darray .transpose (* dims )
511+ elif darray [xlab ].dims [- 1 ] == darray .dims [0 ]:
512+ darray = darray .transpose ()
513+
514+ # Pass the data as a masked ndarray too
515+ zval = darray .to_masked_array (copy = False )
491516
492517 _ensure_plottable (xval , yval )
493518
@@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs):
595620
596621 Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
597622
623+ While other plot methods require the DataArray to be strictly
624+ two-dimensional, ``imshow`` also accepts a 3D array where some
625+ dimension can be interpreted as RGB or RGBA color channels and
626+ allows this dimension to be specified via the kwarg ``rgb=``.
627+
598628 .. note::
599629 This function needs uniformly spaced coordinates to
600630 properly label the axes. Call DataArray.plot() to check.
@@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs):
632662 # Allow user to override these defaults
633663 defaults .update (kwargs )
634664
665+ if z .ndim == 3 :
666+ # matplotlib imshow uses black for missing data, but Xarray makes
667+ # missing data transparent. We therefore add an alpha channel if
668+ # there isn't one, and set it to transparent where data is masked.
669+ if z .shape [- 1 ] == 3 :
670+ z = np .ma .concatenate ((z , np .ma .ones (z .shape [:2 ] + (1 ,))), 2 )
671+ z = z .copy ()
672+ z [np .any (z .mask , axis = - 1 ), - 1 ] = 0
673+
635674 primitive = ax .imshow (z , ** defaults )
636675
637676 return primitive
0 commit comments