Skip to content

Commit 7f8ab82

Browse files
auguste-probablSharkyii
authored andcommitted
chore(skore/estimator-report): Clean some parts of the metrics accessor (probabl-ai#2331)
In preparation for a larger refactor of `_MetricsAccessor`
1 parent 6bcebdd commit 7f8ab82

16 files changed

+2161
-704
lines changed

reproduce_output.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
3+
Plotting coefficients...
4+
Done.

reproduce_warning.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from sklearn.datasets import load_iris
2+
from sklearn.pipeline import Pipeline
3+
from sklearn.preprocessing import PolynomialFeatures
4+
from sklearn.linear_model import LogisticRegression
5+
from skore import ComparisonReport, CrossValidationReport
6+
import matplotlib.pyplot as plt
7+
import warnings
8+
9+
# Ensure warnings are shown
10+
warnings.simplefilter("always")
11+
12+
X, y = load_iris(return_X_y=True)
13+
estimator_1 = LogisticRegression(max_iter=10000, random_state=42)
14+
estimator_2 = Pipeline(
15+
[
16+
("poly", PolynomialFeatures()),
17+
("predictor", LogisticRegression(max_iter=10000, random_state=0)),
18+
]
19+
)
20+
report = ComparisonReport(
21+
[CrossValidationReport(estimator_1, X, y), CrossValidationReport(estimator_2, X, y)]
22+
)
23+
display = report.inspection.coefficients()
24+
print("Plotting coefficients...")
25+
display.plot()
26+
plt.close('all')
27+
print("Done.")

skore/src/skore/_sklearn/_plot/inspection/coefficients.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def plot(
288288
subplot_by: Literal["auto", "estimator", "label", "output"] | None = "auto",
289289
select_k: int | None = None,
290290
sorting_order: Literal["descending", "ascending", None] = None,
291-
) -> None:
291+
) -> Any:
292292
"""Plot the coefficients for the different features.
293293
294294
Parameters
@@ -354,7 +354,7 @@ def _plot_matplotlib(
354354
subplot_by: Literal["estimator", "label", "output"] | None = None,
355355
select_k: int | None = None,
356356
sorting_order: Literal["descending", "ascending", None] = None,
357-
) -> None:
357+
) -> Any:
358358
"""Dispatch the plotting function for matplotlib backend."""
359359
frame = self.frame(
360360
include_intercept=include_intercept,
@@ -409,9 +409,9 @@ def _categorical_plot(
409409
barplot_kwargs: dict[str, Any] | None = None,
410410
boxplot_kwargs: dict[str, Any] | None = None,
411411
stripplot_kwargs: dict[str, Any] | None = None,
412-
):
412+
) -> Any:
413413
if "estimator" in report_type:
414-
self.facet_ = sns.catplot(
414+
facet_ = sns.catplot(
415415
data=frame,
416416
x="coefficient",
417417
y="feature",
@@ -421,7 +421,7 @@ def _categorical_plot(
421421
**barplot_kwargs,
422422
)
423423
else: # "cross-validation" in report_type
424-
self.facet_ = sns.catplot(
424+
facet_ = sns.catplot(
425425
data=frame,
426426
x="coefficient",
427427
y="feature",
@@ -451,14 +451,15 @@ def _categorical_plot(
451451
)
452452
for ax, n_feature in zip(self.ax_.flatten(), n_features, strict=True):
453453
_decorate_matplotlib_axis(
454-
ax=ax,
454+
ax=axis,
455455
add_background_features=add_background_features,
456456
n_features=n_feature,
457457
xlabel="Magnitude of coefficient",
458458
ylabel="",
459459
)
460-
if len(self.ax_.flatten()) == 1:
461-
self.ax_ = self.ax_.flatten()[0]
460+
if len(ax_.flatten()) == 1:
461+
ax_ = ax_.flatten()[0]
462+
return facet_
462463

463464
def _plot_single_estimator(
464465
self,
@@ -470,7 +471,7 @@ def _plot_single_estimator(
470471
barplot_kwargs: dict[str, Any],
471472
boxplot_kwargs: dict[str, Any],
472473
stripplot_kwargs: dict[str, Any],
473-
) -> None:
474+
) -> Any:
474475
"""Plot the coefficients for an `EstimatorReport` or a `CrossValidationReport`.
475476
476477
An `EstimatorReport` will use a bar plot while a `CrossValidationReport` will
@@ -528,7 +529,7 @@ def _plot_single_estimator(
528529
barplot_kwargs.pop("palette", None)
529530
stripplot_kwargs.pop("palette", None)
530531

531-
self._categorical_plot(
532+
facet_ = self._categorical_plot(
532533
frame=frame,
533534
report_type=report_type,
534535
hue=hue,
@@ -541,7 +542,8 @@ def _plot_single_estimator(
541542
title = f"Coefficients of {estimator_name}"
542543
if subplot_by is not None:
543544
title += f" by {subplot_by}"
544-
self.figure_.suptitle(title)
545+
facet_.figure.suptitle(title)
546+
return facet_
545547

546548
@staticmethod
547549
def _has_same_features(*, frame: pd.DataFrame) -> bool:
@@ -565,7 +567,7 @@ def _plot_comparison(
565567
barplot_kwargs: dict[str, Any],
566568
boxplot_kwargs: dict[str, Any],
567569
stripplot_kwargs: dict[str, Any],
568-
) -> None:
570+
) -> Any:
569571
"""Plot the coefficients for a `ComparisonReport`.
570572
571573
Parameters
@@ -657,7 +659,7 @@ def _plot_comparison(
657659
"different axis using `subplot_by='estimator'`."
658660
)
659661

660-
self._categorical_plot(
662+
facet_ = self._categorical_plot(
661663
frame=frame,
662664
report_type=report_type,
663665
hue=hue,
@@ -670,7 +672,8 @@ def _plot_comparison(
670672
title = "Coefficients"
671673
if subplot_by is not None:
672674
title += f" by {subplot_by}"
673-
self.figure_.suptitle(title)
675+
facet_.figure.suptitle(title)
676+
return facet_
674677

675678
@classmethod
676679
def _compute_data_for_display(

skore/src/skore/_sklearn/_plot/inspection/impurity_decrease.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,6 @@ class ImpurityDecreaseDisplay(DisplayMixin):
3030
"comparison-cross-validation"}
3131
Report type from which the display is created.
3232
33-
Attributes
34-
----------
35-
ax_ : matplotlib Axes
36-
Matplotlib Axes with the plot.
37-
38-
facet_ : seaborn FacetGrid
39-
FacetGrid containing the plot.
40-
41-
figure_ : matplotlib Figure
42-
Figure containing the plot.
43-
4433
Examples
4534
--------
4635
>>> from sklearn.datasets import load_iris
@@ -209,7 +198,7 @@ def plot(self) -> None:
209198
"""
210199
return self._plot()
211200

212-
def _plot_matplotlib(self) -> None:
201+
def _plot_matplotlib(self) -> Any:
213202
"""Dispatch the plotting function for matplotlib backend.
214203
215204
This method creates a bar plot showing the mean decrease in impurity for each
@@ -221,7 +210,7 @@ def _plot_matplotlib(self) -> None:
221210
boxplot_kwargs = self._default_boxplot_kwargs.copy()
222211
frame = self.frame()
223212

224-
self._plot_single_estimator(
213+
return self._plot_single_estimator(
225214
frame=frame,
226215
estimator_name=self.importances["estimator"].unique()[0],
227216
report_type=self.report_type,
@@ -300,7 +289,7 @@ def _plot_single_estimator(
300289
self.figure_, self.ax_ = self.facet_.figure, self.facet_.axes.squeeze()
301290
self.ax_ = self.ax_[()] # 0-d array
302291
_decorate_matplotlib_axis(
303-
ax=self.ax_,
292+
ax=ax_,
304293
add_background_features=False,
305294
n_features=frame["feature"].nunique(),
306295
xlabel="Mean Decrease in Impurity (MDI)",

skore/src/skore/_sklearn/_plot/inspection/permutation_importance.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,6 @@ class PermutationImportanceDisplay(DisplayMixin):
3434
report_type : {"estimator"}
3535
Report type from which the display is created.
3636
37-
Attributes
38-
----------
39-
facet_ : seaborn FacetGrid
40-
FacetGrid containing the permutation importance.
41-
42-
figure_ : matplotlib Figure
43-
Figure containing the permutation importance.
44-
45-
ax_ : matplotlib Axes
46-
Axes with permutation importance.
4737
"""
4838

4939
_default_boxplot_kwargs: dict[str, Any] = {
@@ -166,7 +156,7 @@ def plot(
166156
*,
167157
subplot_by: str | tuple[str, str] | None = "auto",
168158
metric: str | list[str] | None = None,
169-
) -> None:
159+
) -> Any:
170160
"""Plot the permutation importance.
171161
172162
Parameters
@@ -197,7 +187,7 @@ def _plot_matplotlib(
197187
*,
198188
subplot_by: str | tuple[str, str] | None = "auto",
199189
metric: str | list[str] | None = None,
200-
) -> None:
190+
) -> Any:
201191
"""Dispatch the plotting function for matplotlib backend."""
202192
boxplot_kwargs = self._default_boxplot_kwargs.copy()
203193
stripplot_kwargs = self._default_stripplot_kwargs.copy()
@@ -213,7 +203,7 @@ def _plot_matplotlib(
213203
elif "output" in frame.columns and frame["output"].isna().any():
214204
raise ValueError(err_msg.format("outputs"))
215205

216-
self._plot_single_estimator(
206+
return self._plot_single_estimator(
217207
subplot_by=subplot_by,
218208
frame=frame,
219209
estimator_name=self.importances["estimator"].unique()[0],
@@ -229,7 +219,7 @@ def _plot_single_estimator(
229219
estimator_name: str,
230220
boxplot_kwargs: dict[str, Any],
231221
stripplot_kwargs: dict[str, Any],
232-
) -> None:
222+
) -> Any:
233223
"""Plot the permutation importance for an `EstimatorReport`."""
234224
if subplot_by == "auto":
235225
is_multi_metric = frame["metric"].nunique() > 1
@@ -295,7 +285,7 @@ def _plot_single_estimator(
295285
# deprecation warning if passing palette without a hue
296286
stripplot_kwargs.pop("palette", None)
297287

298-
self.facet_ = sns.catplot(
288+
facet_ = sns.catplot(
299289
data=frame,
300290
x="value",
301291
y="feature",
@@ -317,9 +307,9 @@ def _plot_single_estimator(
317307
add_background_features = hue is not None
318308

319309
metrics = frame["metric"].unique()
320-
self.figure_, self.ax_ = self.facet_.figure, self.facet_.axes.squeeze()
321-
for row_index, row_axes in enumerate(self.facet_.axes):
322-
for col_index, ax in enumerate(row_axes):
310+
figure_, ax_ = facet_.figure, facet_.axes.squeeze()
311+
for row_index, row_axes in enumerate(facet_.axes):
312+
for col_index, axis in enumerate(row_axes):
323313
if len(metrics) > 1:
324314
if row == "metric":
325315
xlabel = f"Decrease in {metrics[row_index]}"
@@ -331,18 +321,19 @@ def _plot_single_estimator(
331321
xlabel = f"Decrease in {metrics[0]}"
332322

333323
_decorate_matplotlib_axis(
334-
ax=ax,
324+
ax=axis,
335325
add_background_features=add_background_features,
336326
n_features=frame["feature"].nunique(),
337327
xlabel=xlabel,
338328
ylabel="",
339329
)
340-
if len(self.ax_.flatten()) == 1:
341-
self.ax_ = self.ax_.flatten()[0]
330+
if len(ax_.flatten()) == 1:
331+
ax_ = ax_.flatten()[0]
342332
data_source = frame["data_source"].unique()[0]
343-
self.figure_.suptitle(
333+
figure_.suptitle(
344334
f"Permutation importance of {estimator_name} on {data_source} set"
345335
)
336+
return facet_
346337

347338
def frame(
348339
self,

0 commit comments

Comments
 (0)