Skip to content

Commit 4f0b873

Browse files
committed
Add tests + raise nicer error when asked to plot unsupported types
1 parent 3242534 commit 4f0b873

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

xarray/plot/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def newplotfunc(
689689
xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
690690
yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__)
691691

692-
_ensure_plottable(xplt, yplt)
692+
_ensure_plottable(xplt, yplt, zval)
693693

694694
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
695695
plotfunc, zval.data, **locals()

xarray/plot/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def _ensure_plottable(*args):
510510
Raise exception if there is anything in args that can't be plotted on an
511511
axis by matplotlib.
512512
"""
513-
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool8]
513+
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_]
514514
other_types = [datetime]
515515
try:
516516
import cftime
@@ -525,10 +525,10 @@ def _ensure_plottable(*args):
525525
or _valid_other_type(np.array(x), other_types)
526526
):
527527
raise TypeError(
528-
"Plotting requires coordinates to be numeric "
529-
"or dates of type np.datetime64, "
528+
"Plotting requires coordinates to be numeric, boolean, "
529+
"or dates of type numpy.datetime64, "
530530
"datetime.datetime, cftime.datetime or "
531-
"pd.Interval."
531+
f"pandas.Interval. Received data of type {np.array(x).dtype} instead."
532532
)
533533
if (
534534
_valid_other_type(np.array(x), cftime_datetime)

xarray/tests/test_plot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ def test1d(self):
138138
with raises_regex(ValueError, "None"):
139139
self.darray[:, 0, 0].plot(x="dim_1")
140140

141+
with raises_regex(TypeError, "complex128"):
142+
(self.darray[:, 0, 0] + 1j).plot()
143+
144+
def test_1d_bool(self):
145+
xr.ones_like(self.darray[:, 0, 0], dtype=np.bool).plot()
146+
141147
def test_1d_x_y_kw(self):
142148
z = np.arange(10)
143149
da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f")
@@ -919,6 +925,13 @@ def test_1d_raises_valueerror(self):
919925
with raises_regex(ValueError, r"DataArray must be 2d"):
920926
self.plotfunc(self.darray[0, :])
921927

928+
def test_bool(self):
929+
xr.ones_like(self.darray, dtype=np.bool).plot()
930+
931+
def test_complex_raises_typeerror(self):
932+
with raises_regex(TypeError, "complex128"):
933+
(self.darray + 1j).plot()
934+
922935
def test_3d_raises_valueerror(self):
923936
a = DataArray(easy_array((2, 3, 4)))
924937
if self.plotfunc.__name__ == "imshow":

0 commit comments

Comments
 (0)