@@ -254,21 +254,72 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
254254 levels = levels , norm = norm )
255255
256256
257- def _infer_xy_labels (darray , x , y , imshow = False ):
257+ def _infer_xy_labels_3d (darray , x , y , rgb ):
258+ """
259+ Determine x and y labels for showing RGB images.
260+
261+ Attempts to infer which dimension is RGB/RGBA by size and order of dims.
262+
263+ """
264+ # Start by detecting and reporting invalid combinations of arguments
265+ assert darray .ndim == 3
266+ not_none = [a for a in (x , y , rgb ) if a is not None ]
267+ if len (set (not_none )) < len (not_none ):
268+ raise ValueError ('Dimensions passed as x, y, and rgb must be unique.' )
269+ for label in not_none :
270+ if label not in darray .dims :
271+ raise ValueError ('%r is not a dimension' % (label ,))
272+
273+ # Then calculate rgb dimension if certain and check validity
274+ could_be_color = [label for label in darray .dims
275+ if darray [label ].size in (3 , 4 ) and label not in (x , y )]
276+ if rgb is None and not could_be_color :
277+ raise ValueError (
278+ 'A 3-dimensional array was passed to imshow(), but there is no '
279+ 'dimension that could be color. At least one dimension must be '
280+ 'of size 3 (RGB) or 4 (RGBA), and not given as x or y.' )
281+ if rgb is None and len (could_be_color ) == 1 :
282+ rgb = could_be_color [0 ]
283+ if rgb is not None and darray [rgb ].size not in (3 , 4 ):
284+ raise ValueError ('Cannot interpret dim %r of size %s as RGB or RGBA.'
285+ % (rgb , darray [rgb ].size ))
286+
287+ # If rgb dimension is still unknown, there must be two or three dimensions
288+ # in could_be_color. We therefore warn, and use a heuristic to break ties.
289+ if rgb is None :
290+ assert len (could_be_color ) in (2 , 3 )
291+ if darray .dims [- 1 ] in could_be_color :
292+ rgb = darray .dims [- 1 ]
293+ warnings .warn (
294+ 'Several dimensions of this array could be colors. Xarray '
295+ 'will use the last dimension (%r) to match '
296+ 'matplotlib.pyplot.imshow. You can pass names of x, y, '
297+ 'and/or rgb dimensions to override this guess.' % rgb )
298+ else :
299+ rgb = darray .dims [0 ]
300+ warnings .warn (
301+ '%r has been selected as the color dimension, but %r would '
302+ 'also be valid. Pass names of x, y, and/or rgb dimensions to '
303+ 'override this guess.' % darray .dims [:2 ])
304+ assert rgb is not None
305+
306+ # Finally, we pick out the red slice and delegate to the 2D version:
307+ return _infer_xy_labels (darray .isel (** {rgb : 0 }).squeeze (), x , y )
308+
309+
310+ def _infer_xy_labels (darray , x , y , imshow = False , rgb = None ):
258311 """
259312 Determine x and y labels. For use in _plot2d
260313
261314 darray must be a 2 dimensional data array, or 3d for imshow only.
262315 """
316+ if imshow and darray .ndim == 3 :
317+ return _infer_xy_labels_3d (darray , x , y , rgb )
263318
264319 if x is None and y is None :
265320 if darray .ndim != 2 :
266- if not imshow :
267- raise ValueError ('DataArray must be 2d' )
268- elif darray .ndim != 3 or darray .shape [2 ] not in (3 , 4 ):
269- raise ValueError ('DataArray for imshow must be 2d, MxNx3 for '
270- 'RGB image, or MxNx4 for RGBA image.' )
271- y , x , * _ = darray .dims
321+ raise ValueError ('DataArray must be 2d' )
322+ y , x = darray .dims
272323 elif x is None :
273324 if y not in darray .dims :
274325 raise ValueError ('y must be a dimension name if x is not supplied' )
0 commit comments