Skip to content

Commit d4452bc

Browse files
caenrigenIllviljan
authored andcommitted
Display coords' units for slice plots (pydata#5847)
Co-authored-by: Illviljan <[email protected]>
1 parent c8a8ef6 commit d4452bc

File tree

5 files changed

+89
-23
lines changed

5 files changed

+89
-23
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ New Features
3838
`Nathan Lis <https://github.com/wxman22>`_.
3939
- Histogram plots are set with a title displaying the scalar coords if any, similarly to the other plots (:issue:`5791`, :pull:`5792`).
4040
By `Maxime Liquet <https://github.com/maximlt>`_.
41+
- Slice plots display the coords units in the same way as x/y/colorbar labels (:pull:`5847`).
42+
By `Victor Negîrneac <https://github.com/caenrigen>`_.
4143
- Added a new :py:attr:`Dataset.chunksizes`, :py:attr:`DataArray.chunksizes`, and :py:attr:`Variable.chunksizes`
4244
property, which will always return a mapping from dimension names to chunking pattern along that dimension,
4345
regardless of whether the object is a Dataset, DataArray, or Variable. (:issue:`5846`, :pull:`5900`)

xarray/core/dataarray.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pandas as pd
2323

2424
from ..plot.plot import _PlotMethods
25+
from ..plot.utils import _get_units_from_attrs
2526
from . import (
2627
computation,
2728
dtypes,
@@ -3134,7 +3135,11 @@ def _title_for_slice(self, truncate: int = 50) -> str:
31343135
for dim, coord in self.coords.items():
31353136
if coord.size == 1:
31363137
one_dims.append(
3137-
"{dim} = {v}".format(dim=dim, v=format_item(coord.values))
3138+
"{dim} = {v}{unit}".format(
3139+
dim=dim,
3140+
v=format_item(coord.values),
3141+
unit=_get_units_from_attrs(coord),
3142+
)
31383143
)
31393144

31403145
title = ", ".join(one_dims)

xarray/core/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""Internal utilties; not for external use
2-
"""
1+
"""Internal utilities; not for external use"""
32
import contextlib
43
import functools
54
import io

xarray/plot/utils.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,21 @@ def _maybe_gca(**kwargs):
467467
return plt.axes(**kwargs)
468468

469469

470+
def _get_units_from_attrs(da):
471+
"""Extracts and formats the unit/units from a attributes."""
472+
pint_array_type = DuckArrayModule("pint").type
473+
units = " [{}]"
474+
if isinstance(da.data, pint_array_type):
475+
units = units.format(str(da.data.units))
476+
elif da.attrs.get("units"):
477+
units = units.format(da.attrs["units"])
478+
elif da.attrs.get("unit"):
479+
units = units.format(da.attrs["unit"])
480+
else:
481+
units = ""
482+
return units
483+
484+
470485
def label_from_attrs(da, extra=""):
471486
"""Makes informative labels if variable metadata (attrs) follows
472487
CF conventions."""
@@ -480,20 +495,7 @@ def label_from_attrs(da, extra=""):
480495
else:
481496
name = ""
482497

483-
def _get_units_from_attrs(da):
484-
if da.attrs.get("units"):
485-
units = " [{}]".format(da.attrs["units"])
486-
elif da.attrs.get("unit"):
487-
units = " [{}]".format(da.attrs["unit"])
488-
else:
489-
units = ""
490-
return units
491-
492-
pint_array_type = DuckArrayModule("pint").type
493-
if isinstance(da.data, pint_array_type):
494-
units = " [{}]".format(str(da.data.units))
495-
else:
496-
units = _get_units_from_attrs(da)
498+
units = _get_units_from_attrs(da)
497499

498500
# Treat `name` differently if it's a latex sequence
499501
if name.startswith("$") and (name.count("$") % 2 == 0):

xarray/tests/test_units.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5600,19 +5600,77 @@ def test_duck_array_ops(self):
56005600

56015601
@requires_matplotlib
56025602
class TestPlots(PlotTestCase):
5603-
def test_units_in_line_plot_labels(self):
5603+
@pytest.mark.parametrize(
5604+
"coord_unit, coord_attrs",
5605+
[
5606+
(1, {"units": "meter"}),
5607+
pytest.param(
5608+
unit_registry.m,
5609+
{},
5610+
marks=pytest.mark.xfail(reason="indexes don't support units"),
5611+
),
5612+
],
5613+
)
5614+
def test_units_in_line_plot_labels(self, coord_unit, coord_attrs):
56045615
arr = np.linspace(1, 10, 3) * unit_registry.Pa
5605-
# TODO make coord a Quantity once unit-aware indexes supported
5606-
x_coord = xr.DataArray(
5607-
np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
5608-
)
5616+
coord_arr = np.linspace(1, 3, 3) * coord_unit
5617+
x_coord = xr.DataArray(coord_arr, dims="x", attrs=coord_attrs)
56095618
da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")
56105619

56115620
da.plot.line()
56125621

56135622
ax = plt.gca()
56145623
assert ax.get_ylabel() == "pressure [pascal]"
5615-
assert ax.get_xlabel() == "x [meters]"
5624+
assert ax.get_xlabel() == "x [meter]"
5625+
5626+
@pytest.mark.parametrize(
5627+
"coord_unit, coord_attrs",
5628+
[
5629+
(1, {"units": "meter"}),
5630+
pytest.param(
5631+
unit_registry.m,
5632+
{},
5633+
marks=pytest.mark.xfail(reason="indexes don't support units"),
5634+
),
5635+
],
5636+
)
5637+
def test_units_in_slice_line_plot_labels_sel(self, coord_unit, coord_attrs):
5638+
arr = xr.DataArray(
5639+
name="var_a",
5640+
data=np.array([[1, 2], [3, 4]]),
5641+
coords=dict(
5642+
a=("a", np.array([5, 6]) * coord_unit, coord_attrs),
5643+
b=("b", np.array([7, 8]) * coord_unit, coord_attrs),
5644+
),
5645+
dims=("a", "b"),
5646+
)
5647+
arr.sel(a=5).plot(marker="o")
5648+
5649+
assert plt.gca().get_title() == "a = 5 [meter]"
5650+
5651+
@pytest.mark.parametrize(
5652+
"coord_unit, coord_attrs",
5653+
[
5654+
(1, {"units": "meter"}),
5655+
pytest.param(
5656+
unit_registry.m,
5657+
{},
5658+
marks=pytest.mark.xfail(reason="pint.errors.UnitStrippedWarning"),
5659+
),
5660+
],
5661+
)
5662+
def test_units_in_slice_line_plot_labels_isel(self, coord_unit, coord_attrs):
5663+
arr = xr.DataArray(
5664+
name="var_a",
5665+
data=np.array([[1, 2], [3, 4]]),
5666+
coords=dict(
5667+
a=("x", np.array([5, 6]) * coord_unit, coord_attrs),
5668+
b=("y", np.array([7, 8])),
5669+
),
5670+
dims=("x", "y"),
5671+
)
5672+
arr.isel(x=0).plot(marker="o")
5673+
assert plt.gca().get_title() == "a = 5 [meter]"
56165674

56175675
def test_units_in_2d_plot_colorbar_label(self):
56185676
arr = np.ones((2, 3)) * unit_registry.Pa

0 commit comments

Comments
 (0)