Skip to content

Commit 0f72418

Browse files
authored
seaborn: complete and fix axisgrid module (#11096)
1 parent fe48f37 commit 0f72418

File tree

4 files changed

+244
-49
lines changed

4 files changed

+244
-49
lines changed

stubs/seaborn/@tests/stubtest_allowlist.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class
33
seaborn.external.docscrape.NumpyDocString.__str__ # weird signature
44

55
seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime
6+
7+
seaborn.axisgrid.Grid.tight_layout # the method doesn't really take pos args but runtime has *args

stubs/seaborn/seaborn/_core/typing.pyi

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
from _typeshed import Incomplete
2-
from collections.abc import Iterable
2+
from collections.abc import Iterable, Mapping
33
from datetime import date, datetime, timedelta
4-
from typing import Any
4+
from typing import Any, Protocol
55
from typing_extensions import TypeAlias
66

77
from matplotlib.colors import Colormap, Normalize
88
from numpy import ndarray
9-
from pandas import Index, Series, Timedelta, Timestamp
9+
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
10+
11+
class _SupportsDataFrame(Protocol):
12+
# `__dataframe__` should return pandas.core.interchange.dataframe_protocol.DataFrame
13+
# but this class needs to be defined as a Protocol, not as an ABC.
14+
def __dataframe__(self, nan_as_null: bool = ..., allow_copy: bool = ...) -> Incomplete: ...
1015

1116
ColumnName: TypeAlias = str | bytes | date | datetime | timedelta | bool | complex | Timestamp | Timedelta
1217
Vector: TypeAlias = Series[Any] | Index[Any] | ndarray[Any, Any]
1318
VariableSpec: TypeAlias = ColumnName | Vector | None
1419
VariableSpecList: TypeAlias = list[VariableSpec] | Index[Any] | None
15-
DataSource: TypeAlias = Incomplete
20+
DataSource: TypeAlias = DataFrame | _SupportsDataFrame | Mapping[ColumnName, Incomplete] | None
1621
OrderSpec: TypeAlias = Iterable[str] | None
1722
NormSpec: TypeAlias = tuple[float | None, float | None] | Normalize | None
1823
PaletteSpec: TypeAlias = str | list[Incomplete] | dict[Incomplete, Incomplete] | Colormap | None

stubs/seaborn/seaborn/axisgrid.pyi

Lines changed: 201 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,207 @@
1+
import os
12
from _typeshed import Incomplete
23
from collections.abc import Callable, Generator, Iterable, Mapping
3-
from typing import Any, TypeVar
4-
from typing_extensions import Concatenate, Literal, ParamSpec, Self
4+
from typing import IO, Any, TypeVar
5+
from typing_extensions import Concatenate, Literal, ParamSpec, Self, TypeAlias
56

7+
import numpy as np
68
from matplotlib.artist import Artist
79
from matplotlib.axes import Axes
8-
from matplotlib.colors import Colormap
10+
from matplotlib.backend_bases import MouseEvent, RendererBase
11+
from matplotlib.colors import Colormap, Normalize
912
from matplotlib.figure import Figure
13+
from matplotlib.font_manager import FontProperties
14+
from matplotlib.gridspec import SubplotSpec
1015
from matplotlib.legend import Legend
16+
from matplotlib.patches import Patch
17+
from matplotlib.path import Path as mpl_Path
18+
from matplotlib.patheffects import AbstractPathEffect
19+
from matplotlib.scale import ScaleBase
1120
from matplotlib.text import Text
12-
from matplotlib.typing import ColorType
13-
from numpy.typing import NDArray
21+
from matplotlib.transforms import Bbox, BboxBase, Transform, TransformedPath
22+
from matplotlib.typing import ColorType, LineStyleType, MarkerType
23+
from numpy.typing import ArrayLike, NDArray
1424
from pandas import DataFrame, Series
1525

26+
from ._core.typing import ColumnName, DataSource, _SupportsDataFrame
1627
from .palettes import _RGBColorPalette
17-
from .utils import _Palette
28+
from .utils import _DataSourceWideForm, _Palette, _Vector
1829

1930
__all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]
2031

2132
_P = ParamSpec("_P")
2233
_R = TypeVar("_R")
2334

35+
_LiteralFont: TypeAlias = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
36+
2437
class _BaseGrid:
25-
def set(self, **kwargs: Incomplete) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.set`
38+
def set(
39+
self,
40+
*,
41+
# Keywords follow `matplotlib.axes.Axes.set`. Each keyword <KW> corresponds to a `set_<KW>` method
42+
adjustable: Literal["box", "datalim"] = ...,
43+
agg_filter: Callable[[ArrayLike, float], tuple[NDArray[np.floating[Any]], float, float]] | None = ...,
44+
alpha: float | None = ...,
45+
anchor: str | tuple[float, float] = ...,
46+
animated: bool = ...,
47+
aspect: float | Literal["auto", "equal"] = ...,
48+
autoscale_on: bool = ...,
49+
autoscalex_on: bool = ...,
50+
autoscaley_on: bool = ...,
51+
axes_locator: Callable[[Axes, RendererBase], Bbox] = ...,
52+
axisbelow: bool | Literal["line"] = ...,
53+
box_aspect: float | None = ...,
54+
clip_box: BboxBase | None = ...,
55+
clip_on: bool = ...,
56+
clip_path: Patch | mpl_Path | TransformedPath | None = ...,
57+
facecolor: ColorType | None = ...,
58+
frame_on: bool = ...,
59+
gid: str | None = ...,
60+
in_layout: bool = ...,
61+
label: object = ...,
62+
mouseover: bool = ...,
63+
navigate: bool = ...,
64+
path_effects: list[AbstractPathEffect] = ...,
65+
picker: bool | float | Callable[[Artist, MouseEvent], tuple[bool, dict[Any, Any]]] | None = ...,
66+
position: Bbox | tuple[float, float, float, float] = ...,
67+
prop_cycle: Incomplete = ..., # TODO: use cycler.Cycler when cycler gets typed
68+
rasterization_zorder: float | None = ...,
69+
rasterized: bool = ...,
70+
sketch_params: float | None = ...,
71+
snap: bool | None = ...,
72+
subplotspec: SubplotSpec = ...,
73+
title: str = ...,
74+
transform: Transform | None = ...,
75+
url: str | None = ...,
76+
visible: bool = ...,
77+
xbound: float | None | tuple[float | None, float | None] = ...,
78+
xlabel: str = ...,
79+
xlim: float | None | tuple[float | None, float | None] = ...,
80+
xmargin: float = ...,
81+
xscale: str | ScaleBase = ...,
82+
xticklabels: Iterable[str | Text] = ...,
83+
xticks: ArrayLike = ...,
84+
ybound: float | None | tuple[float | None, float | None] = ...,
85+
ylabel: str = ...,
86+
ylim: float | None | tuple[float | None, float | None] = ...,
87+
ymargin: float = ...,
88+
yscale: str | ScaleBase = ...,
89+
yticklabels: Iterable[str | Text] = ...,
90+
yticks: ArrayLike = ...,
91+
zorder: float = ...,
92+
**kwargs: Any,
93+
) -> Self: ...
2694
@property
2795
def fig(self) -> Figure: ...
2896
@property
2997
def figure(self) -> Figure: ...
3098
def apply(self, func: Callable[Concatenate[Self, _P], object], *args: _P.args, **kwargs: _P.kwargs) -> Self: ...
3199
def pipe(self, func: Callable[Concatenate[Self, _P], _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
32100
def savefig(
33-
self, *args: Incomplete, **kwargs: Incomplete
34-
) -> None: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.savefig`
101+
self,
102+
# Signature follows `matplotlib.figure.Figure.savefig`
103+
fname: str | os.PathLike[Any] | IO[Any],
104+
*,
105+
transparent: bool | None = None,
106+
dpi: float | Literal["figure"] | None = 96,
107+
facecolor: ColorType | Literal["auto"] | None = "auto",
108+
edgecolor: ColorType | Literal["auto"] | None = "auto",
109+
orientation: Literal["landscape", "portrait"] = "portrait",
110+
format: str | None = None,
111+
bbox_inches: Literal["tight"] | Bbox | None = "tight",
112+
pad_inches: float | Literal["layout"] | None = None,
113+
backend: str | None = None,
114+
**kwargs: Any,
115+
) -> None: ...
35116

36117
class Grid(_BaseGrid):
37118
def __init__(self) -> None: ...
38119
def tight_layout(
39-
self, *args: Incomplete, **kwargs: Incomplete
40-
) -> Self: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.tight_layout`
120+
self,
121+
*,
122+
# Keywords follow `matplotlib.figure.Figure.tight_layout`
123+
pad: float = 1.08,
124+
h_pad: float | None = None,
125+
w_pad: float | None = None,
126+
rect: tuple[float, float, float, float] | None = None,
127+
) -> Self: ...
41128
def add_legend(
42129
self,
43-
legend_data: Mapping[Any, Artist] | None = None, # cannot use precise key type because of invariant Mapping keys
130+
# Cannot use precise key type with union for legend_data because of invariant Mapping keys
131+
legend_data: Mapping[Any, Artist] | None = None,
44132
title: str | None = None,
45133
label_order: list[str] | None = None,
46134
adjust_subtitles: bool = False,
47-
**kwargs: Incomplete, # **kwargs are passed to `matplotlib.figure.Figure.legend`
135+
*,
136+
# Keywords follow `matplotlib.legend.Legend`
137+
loc: str | int | tuple[float, float] | None = None,
138+
numpoints: int | None = None,
139+
markerscale: float | None = None,
140+
markerfirst: bool = True,
141+
reverse: bool = False,
142+
scatterpoints: int | None = None,
143+
scatteryoffsets: Iterable[float] | None = None,
144+
prop: FontProperties | dict[str, Any] | None = None,
145+
fontsize: int | _LiteralFont | None = None,
146+
labelcolor: str | Iterable[str] | None = None,
147+
borderpad: float | None = None,
148+
labelspacing: float | None = None,
149+
handlelength: float | None = None,
150+
handleheight: float | None = None,
151+
handletextpad: float | None = None,
152+
borderaxespad: float | None = None,
153+
columnspacing: float | None = None,
154+
ncols: int = 1,
155+
mode: Literal["expand"] | None = None,
156+
fancybox: bool | None = None,
157+
shadow: bool | dict[str, float] | None = None,
158+
title_fontsize: int | _LiteralFont | None = None,
159+
framealpha: float | None = None,
160+
edgecolor: ColorType | None = None,
161+
facecolor: ColorType | None = None,
162+
bbox_to_anchor: BboxBase | tuple[float, float] | tuple[float, float, float, float] | None = None,
163+
bbox_transform: Transform | None = None,
164+
frameon: bool | None = None,
165+
handler_map: None = None,
166+
title_fontproperties: FontProperties | None = None,
167+
alignment: Literal["center", "left", "right"] = "center",
168+
ncol: int = 1,
169+
draggable: bool = False,
48170
) -> Self: ...
49171
@property
50172
def legend(self) -> Legend | None: ...
51173
def tick_params(
52-
self, axis: Literal["x", "y", "both"] = "both", **kwargs: Incomplete
53-
) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.tick_params`
174+
self,
175+
axis: Literal["x", "y", "both"] = "both",
176+
*,
177+
# Keywords follow `matplotlib.axes.Axes.tick_params`
178+
which: Literal["major", "minor", "both"] = "major",
179+
reset: bool = False,
180+
direction: Literal["in", "out", "inout"] = ...,
181+
length: float = ...,
182+
width: float = ...,
183+
color: ColorType = ...,
184+
pad: float = ...,
185+
labelsize: float | str = ...,
186+
labelcolor: ColorType = ...,
187+
labelfontfamily: str = ...,
188+
colors: ColorType = ...,
189+
zorder: float = ...,
190+
bottom: bool = ...,
191+
top: bool = ...,
192+
left: bool = ...,
193+
right: bool = ...,
194+
labelbottom: bool = ...,
195+
labeltop: bool = ...,
196+
labelleft: bool = ...,
197+
labelright: bool = ...,
198+
labelrotation: float = ...,
199+
grid_color: ColorType = ...,
200+
grid_alpha: float = ...,
201+
grid_linewidth: float = ...,
202+
grid_linestyle: str = ...,
203+
**kwargs: Any,
204+
) -> Self: ...
54205

55206
class FacetGrid(Grid):
56207
data: DataFrame
@@ -60,7 +211,7 @@ class FacetGrid(Grid):
60211
hue_kws: dict[str, Any]
61212
def __init__(
62213
self,
63-
data: DataFrame,
214+
data: DataFrame | _SupportsDataFrame,
64215
*,
65216
row: str | None = None,
66217
col: str | None = None,
@@ -88,10 +239,10 @@ class FacetGrid(Grid):
88239
def map(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
89240
def map_dataframe(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
90241
def facet_axis(self, row_i: int, col_j: int, modify_state: bool = True) -> Axes: ...
242+
# `despine` should be kept roughly in line with `seaborn.utils.despine`
91243
def despine(
92244
self,
93245
*,
94-
fig: Figure | None = None,
95246
ax: Axes | None = None,
96247
top: bool = True,
97248
right: bool = True,
@@ -111,7 +262,13 @@ class FacetGrid(Grid):
111262
self, template: str | None = None, row_template: str | None = None, col_template: str | None = None, **kwargs: Any
112263
) -> Self: ...
113264
def refline(
114-
self, *, x: float | None = None, y: float | None = None, color: ColorType = ".5", linestyle: str = "--", **line_kws: Any
265+
self,
266+
*,
267+
x: float | None = None,
268+
y: float | None = None,
269+
color: ColorType = ".5",
270+
linestyle: LineStyleType = "--",
271+
**line_kws: Any,
115272
) -> Self: ...
116273
@property
117274
def axes(self) -> NDArray[Incomplete]: ... # array of `Axes`
@@ -127,15 +284,15 @@ class PairGrid(Grid):
127284
axes: NDArray[Incomplete] # two-dimensional array of `Axes`
128285
data: DataFrame
129286
diag_sharey: bool
130-
diag_vars: NDArray[Incomplete] | None # array of `str`
131-
diag_axes: NDArray[Incomplete] | None # array of `Axes`
287+
diag_vars: list[str] | None
288+
diag_axes: list[Axes] | None
132289
hue_names: list[str]
133-
hue_vals: Series[Incomplete]
290+
hue_vals: Series[Any]
134291
hue_kws: dict[str, Any]
135292
palette: _RGBColorPalette
136293
def __init__(
137294
self,
138-
data: DataFrame,
295+
data: DataFrame | _SupportsDataFrame,
139296
*,
140297
hue: str | None = None,
141298
vars: Iterable[str] | None = None,
@@ -162,25 +319,25 @@ class JointGrid(_BaseGrid):
162319
ax_joint: Axes
163320
ax_marg_x: Axes
164321
ax_marg_y: Axes
165-
x: Series[Incomplete]
166-
y: Series[Incomplete]
167-
hue: Series[Incomplete]
322+
x: Series[Any]
323+
y: Series[Any]
324+
hue: Series[Any]
168325
def __init__(
169326
self,
170-
data: Incomplete | None = None,
327+
data: DataSource | _DataSourceWideForm | None = None,
171328
*,
172-
x: Incomplete | None = None,
173-
y: Incomplete | None = None,
174-
hue: Incomplete | None = None,
329+
x: ColumnName | _Vector | None = None,
330+
y: ColumnName | _Vector | None = None,
331+
hue: ColumnName | _Vector | None = None,
175332
height: float = 6,
176333
ratio: float = 5,
177334
space: float = 0.2,
178335
palette: _Palette | Colormap | None = None,
179-
hue_order: Iterable[str] | None = None,
180-
hue_norm: Incomplete | None = None,
336+
hue_order: Iterable[ColumnName] | None = None,
337+
hue_norm: tuple[float, float] | Normalize | None = None,
181338
dropna: bool = False,
182-
xlim: Incomplete | None = None,
183-
ylim: Incomplete | None = None,
339+
xlim: float | tuple[float, float] | None = None,
340+
ylim: float | tuple[float, float] | None = None,
184341
marginal_ticks: bool = False,
185342
) -> None: ...
186343
def plot(self, joint_func: Callable[..., object], marginal_func: Callable[..., object], **kwargs: Any) -> Self: ...
@@ -194,7 +351,7 @@ class JointGrid(_BaseGrid):
194351
joint: bool = True,
195352
marginal: bool = True,
196353
color: ColorType = ".5",
197-
linestyle: str = "--",
354+
linestyle: LineStyleType = "--",
198355
**line_kws: Any,
199356
) -> Self: ...
200357
def set_axis_labels(self, xlabel: str = "", ylabel: str = "", **kwargs: Any) -> Self: ...
@@ -210,7 +367,7 @@ def pairplot(
210367
y_vars: Iterable[str] | str | None = None,
211368
kind: Literal["scatter", "kde", "hist", "reg"] = "scatter",
212369
diag_kind: Literal["auto", "hist", "kde"] | None = "auto",
213-
markers: Incomplete | None = None,
370+
markers: MarkerType | list[MarkerType] | None = None,
214371
height: float = 2.5,
215372
aspect: float = 1,
216373
corner: bool = False,
@@ -221,22 +378,22 @@ def pairplot(
221378
size: float | None = None, # deprecated
222379
) -> PairGrid: ...
223380
def jointplot(
224-
data: Incomplete | None = None,
381+
data: DataSource | _DataSourceWideForm | None = None,
225382
*,
226-
x: Incomplete | None = None,
227-
y: Incomplete | None = None,
228-
hue: Incomplete | None = None,
229-
kind: str = "scatter", # ideally Literal["scatter", "kde", "hist", "hex", "reg", "resid"] but it is checked with startswith
383+
x: ColumnName | _Vector | None = None,
384+
y: ColumnName | _Vector | None = None,
385+
hue: ColumnName | _Vector | None = None,
386+
kind: Literal["scatter", "kde", "hist", "hex", "reg", "resid"] = "scatter",
230387
height: float = 6,
231388
ratio: float = 5,
232389
space: float = 0.2,
233390
dropna: bool = False,
234-
xlim: Incomplete | None = None,
235-
ylim: Incomplete | None = None,
391+
xlim: float | tuple[float, float] | None = None,
392+
ylim: float | tuple[float, float] | None = None,
236393
color: ColorType | None = None,
237394
palette: _Palette | Colormap | None = None,
238-
hue_order: Iterable[str] | None = None,
239-
hue_norm: Incomplete | None = None,
395+
hue_order: Iterable[ColumnName] | None = None,
396+
hue_norm: tuple[float, float] | Normalize | None = None,
240397
marginal_ticks: bool = False,
241398
joint_kws: dict[str, Any] | None = None,
242399
marginal_kws: dict[str, Any] | None = None,

0 commit comments

Comments
 (0)