Skip to content

Commit e2c29f6

Browse files
authored
Fix utils.get_axis with kwargs (#7080)
1 parent 226c23b commit e2c29f6

File tree

4 files changed

+83
-19
lines changed

4 files changed

+83
-19
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ Bug fixes
8787
- Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler
8888
(:issue:`7013`, :pull:`7040`).
8989
By `Francesco Nattino <https://github.com/fnattino>`_.
90+
- Fix bug where subplot_kwargs were not working when plotting with figsize, size or aspect (:issue:`7078`, :pull:`7080`)
91+
By `Michael Niklas <https://github.com/headtr1ck>`_.
9092

9193
Documentation
9294
~~~~~~~~~~~~~

xarray/plot/plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ def plot(
280280
col_wrap : int, optional
281281
Use together with ``col`` to wrap faceted plots.
282282
ax : matplotlib axes object, optional
283-
If ``None``, use the current axes. Not applicable when using facets.
283+
Axes on which to plot. By default, use the current axes.
284+
Mutually exclusive with ``size``, ``figsize`` and facets.
284285
rtol : float, optional
285286
Relative tolerance used to determine if the indexes
286287
are uniformly spaced. Usually a small positive number.

xarray/plot/utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131

3232
if TYPE_CHECKING:
33+
from matplotlib.axes import Axes
34+
3335
from ..core.dataarray import DataArray
3436

3537

@@ -423,7 +425,13 @@ def _assert_valid_xy(darray: DataArray, xy: None | Hashable, name: str) -> None:
423425
)
424426

425427

426-
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
428+
def get_axis(
429+
figsize: Iterable[float] | None = None,
430+
size: float | None = None,
431+
aspect: float | None = None,
432+
ax: Axes | None = None,
433+
**subplot_kws: Any,
434+
) -> Axes:
427435
try:
428436
import matplotlib as mpl
429437
import matplotlib.pyplot as plt
@@ -435,28 +443,32 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
435443
raise ValueError("cannot provide both `figsize` and `ax` arguments")
436444
if size is not None:
437445
raise ValueError("cannot provide both `figsize` and `size` arguments")
438-
_, ax = plt.subplots(figsize=figsize)
439-
elif size is not None:
446+
_, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
447+
return ax
448+
449+
if size is not None:
440450
if ax is not None:
441451
raise ValueError("cannot provide both `size` and `ax` arguments")
442452
if aspect is None:
443453
width, height = mpl.rcParams["figure.figsize"]
444454
aspect = width / height
445455
figsize = (size * aspect, size)
446-
_, ax = plt.subplots(figsize=figsize)
447-
elif aspect is not None:
456+
_, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
457+
return ax
458+
459+
if aspect is not None:
448460
raise ValueError("cannot provide `aspect` argument without `size`")
449461

450-
if kwargs and ax is not None:
462+
if subplot_kws and ax is not None:
451463
raise ValueError("cannot use subplot_kws with existing ax")
452464

453465
if ax is None:
454-
ax = _maybe_gca(**kwargs)
466+
ax = _maybe_gca(**subplot_kws)
455467

456468
return ax
457469

458470

459-
def _maybe_gca(**kwargs):
471+
def _maybe_gca(**subplot_kws: Any) -> Axes:
460472

461473
import matplotlib.pyplot as plt
462474

@@ -468,7 +480,7 @@ def _maybe_gca(**kwargs):
468480
# can not pass kwargs to active axes
469481
return plt.gca()
470482

471-
return plt.axes(**kwargs)
483+
return plt.axes(**subplot_kws)
472484

473485

474486
def _get_units_from_attrs(da) -> str:

xarray/tests/test_plot.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2955,9 +2955,8 @@ def test_facetgrid_single_contour():
29552955

29562956

29572957
@requires_matplotlib
2958-
def test_get_axis():
2959-
# test get_axis works with different args combinations
2960-
# and return the right type
2958+
def test_get_axis_raises():
2959+
# test get_axis raises an error if trying to do invalid things
29612960

29622961
# cannot provide both ax and figsize
29632962
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
@@ -2975,18 +2974,68 @@ def test_get_axis():
29752974
with pytest.raises(ValueError, match="`aspect` argument without `size`"):
29762975
get_axis(figsize=None, size=None, aspect=4 / 3, ax=None)
29772976

2977+
# cannot provide axis and subplot_kws
2978+
with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"):
2979+
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5)
2980+
2981+
2982+
@requires_matplotlib
2983+
@pytest.mark.parametrize(
2984+
["figsize", "size", "aspect", "ax", "kwargs"],
2985+
[
2986+
pytest.param((3, 2), None, None, False, {}, id="figsize"),
2987+
pytest.param(
2988+
(3.5, 2.5), None, None, False, {"label": "test"}, id="figsize_kwargs"
2989+
),
2990+
pytest.param(None, 5, None, False, {}, id="size"),
2991+
pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"),
2992+
pytest.param(None, 5, 1, False, {}, id="size+aspect"),
2993+
pytest.param(None, None, None, True, {}, id="ax"),
2994+
pytest.param(None, None, None, False, {}, id="default"),
2995+
pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"),
2996+
],
2997+
)
2998+
def test_get_axis(
2999+
figsize: tuple[float, float] | None,
3000+
size: float | None,
3001+
aspect: float | None,
3002+
ax: bool,
3003+
kwargs: dict[str, Any],
3004+
) -> None:
29783005
with figure_context():
2979-
ax = get_axis()
2980-
assert isinstance(ax, mpl.axes.Axes)
3006+
inp_ax = plt.axes() if ax else None
3007+
out_ax = get_axis(
3008+
figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs
3009+
)
3010+
assert isinstance(out_ax, mpl.axes.Axes)
29813011

29823012

3013+
@requires_matplotlib
29833014
@requires_cartopy
2984-
def test_get_axis_cartopy():
2985-
3015+
@pytest.mark.parametrize(
3016+
["figsize", "size", "aspect"],
3017+
[
3018+
pytest.param((3, 2), None, None, id="figsize"),
3019+
pytest.param(None, 5, None, id="size"),
3020+
pytest.param(None, 5, 1, id="size+aspect"),
3021+
pytest.param(None, None, None, id="default"),
3022+
],
3023+
)
3024+
def test_get_axis_cartopy(
3025+
figsize: tuple[float, float] | None, size: float | None, aspect: float | None
3026+
) -> None:
29863027
kwargs = {"projection": cartopy.crs.PlateCarree()}
29873028
with figure_context():
2988-
ax = get_axis(**kwargs)
2989-
assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot)
3029+
out_ax = get_axis(figsize=figsize, size=size, aspect=aspect, **kwargs)
3030+
assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot)
3031+
3032+
3033+
@requires_matplotlib
3034+
def test_get_axis_current() -> None:
3035+
with figure_context():
3036+
_, ax = plt.subplots()
3037+
out_ax = get_axis()
3038+
assert ax is out_ax
29903039

29913040

29923041
@requires_matplotlib

0 commit comments

Comments
 (0)