
Commit b88f38c

auguste-probabl authored and direkkakkar319-ops committed
feat(skore/EstimatorReport)!: Move arguments from .summarize() to .frame() (probabl-ai#2536)
<!-- 🙌 Thanks for contributing a pull request! If this is your first contribution, take a look at our contribution guidelines: https://docs.skore.probabl.ai/dev/contributing.html --> #### Change description <!-- Please describe what your contribution changes for skore. Here is some inspiration for what to write here: - Are you adding a new feature? Fixing a bug? - Can you give an example of your change in action, e.g. a snippet of code or a plot? - Is your change likely to break users' code? - Are there any other details the reviewer should be aware of, such as API design choices, performance characteristics or edge cases? Please reference issues/PRs when possible, e.g. "Fixes probabl-ai#1234", "Closes #3456", "See also #7890". More information [here](https://github.com/blog/1506-closing-issues-via-pull-requests). --> `EstimatorReport.metrics.summarize()` no longer accepts arguments `flat_index` or `favorability`. These have been moved to `MetricsSummaryDisplay.frame()`. This moves the responsibility of displaying things from `EstimatorReport._MetricsAccessor` to `MetricsSummaryDisplay`. ```python # Before report.metrics.summarize(flat_index=True, favorability=True).frame() # After report.metrics.summarize().frame(flat_index=True, favorability=True) ``` This is a breaking change. The rest of the PR consists in various refactorings, in particular the tests have been updated to reflect the change in responsibility: - Tests of `summarize()` were extracted from `tests/unit/reports/estimator/metrics/test_numeric.py` to a new file, `tests/unit/reports/estimator/metrics/test_summarize.py`. - `test_summarize.py` now tests that `summarize()` behaves well, and that the output of `summarize()` has a well-formed DataFrame. `.frame()` is never called. - `displays/metrics_summary/test_estimator.py` has been rewritten to specifically test `.frame()` arguments. A number of commits are included that could be pulled out into a separate PR if needed. 
Closes probabl-ai#2533 Supersedes probabl-ai#1839 in part #### Contribution checklist <!-- Below are some of the criteria that the review will include. Feel free to use it as a checklist to ensure that your contribution is high-quality. --> - [x] Unit tests were added or updated (if necessary) - [x] Documentation was added or updated (if necessary) #### AI usage disclosure <!-- If AI tools were involved in creating this PR, please check all boxes that apply below and make sure you understand our [automated contributions policy](https://docs.skore.probabl.ai/dev/contributing.html#automated-contributions-policy) --> AI tools were involved for: - [ ] Code generation (e.g., when writing an implementation or fixing a bug) - [x] Test/benchmark generation - [ ] Documentation (including examples) - [ ] Research and understanding In particular I used Claude to increase coverage. <!-- Any other comments can go here. Thanks again for contributing! -->
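To make the new division of responsibilities concrete, here is a minimal, self-contained sketch of a display object in the spirit of `MetricsSummaryDisplay`, where the data is computed once by `summarize()` and `flat_index` / `favorability` only affect how `.frame()` renders it. The class name, behavior, and column names below are simplified assumptions for illustration, not skore's actual implementation:

```python
import pandas as pd


class ToyMetricsSummaryDisplay:
    """Toy stand-in (hypothetical) for a display holding precomputed metrics."""

    def __init__(self, data: pd.DataFrame):
        # Computed once by `summarize()`; `.frame()` only reshapes it.
        self._data = data

    def frame(
        self, flat_index: bool = False, favorability: bool = False
    ) -> pd.DataFrame:
        result = self._data.copy()
        if flat_index and isinstance(result.index, pd.MultiIndex):
            # Collapse the MultiIndex into single flat labels.
            result.index = ["_".join(map(str, keys)) for keys in result.index]
        if not favorability and "Favorability" in result.columns:
            # Favorability is opt-in at display time.
            result = result.drop(columns=["Favorability"])
        return result


index = pd.MultiIndex.from_tuples([("precision", "macro"), ("recall", "macro")])
data = pd.DataFrame(
    {"Score": [0.9, 0.8], "Favorability": ["higher is better"] * 2}, index=index
)
display = ToyMetricsSummaryDisplay(data)
flat = display.frame(flat_index=True)  # flat labels, no Favorability column
full = display.frame(favorability=True)  # MultiIndex kept, Favorability kept
```

The point of the design is that the expensive part (metric computation) happens once, while presentation options can be varied freely on the resulting display.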
1 parent c097861 commit b88f38c

File tree

14 files changed: +1217 −948 lines changed

examples/use_cases/plot_employee_salaries.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -156,7 +156,7 @@

 # %%
 hgbt_split_1 = hgbt_model_report.estimator_reports_[0]
-hgbt_split_1.metrics.summarize(favorability=True).frame()
+hgbt_split_1.metrics.summarize().frame(favorability=True)

 # %%
 # The favorability of each metric indicates whether the metric is better
```

skore/src/skore/_sklearn/_comparison/metrics_accessor.py

Lines changed: 27 additions & 8 deletions
```diff
@@ -169,7 +169,9 @@ class is set to the one provided when creating the report. If `None`,
         results.index = results.index.str.replace(
             r"\((.*)\)$", r"\1", regex=True
         )
-        return MetricsSummaryDisplay(results)
+        return MetricsSummaryDisplay(
+            data=results, report_type=self._parent._report_type
+        )

     def _compute_metric_scores(
         self,
@@ -199,15 +201,32 @@ def _compute_metric_scores(
             )
         )

-        kwargs = dict(
-            data_source=data_source,
-            **metric_kwargs,
-        )
+        kwargs = dict(data_source=data_source, **metric_kwargs)
         if is_cv_report:
             kwargs["aggregate"] = None

+        # FIXME(#1837)
+        # "favorability" and "flat_index" are passed to `.frame()` for
+        # EstimatorReports while they are passed to `.summarize()` for
+        # CrossValidationReports
+        frame_kwargs = {}
+        if not is_cv_report:
+            for key in ("favorability", "flat_index"):
+                if key in kwargs:
+                    frame_kwargs[key] = kwargs.pop(key)
+
+        if "favorability" in kwargs:
+            favorability = kwargs["favorability"]
+        elif "favorability" in frame_kwargs:
+            favorability = frame_kwargs["favorability"]
+        else:
+            favorability = False
+        favorability = cast(bool, favorability)
+
         individual_results = [
-            result.frame() if report_metric_name == "summarize" else result
+            result.frame(**frame_kwargs)
+            if report_metric_name == "summarize"
+            else result
             for result in track(
                 parallel(
                     joblib.delayed(getattr(report.metrics, report_metric_name))(
@@ -224,14 +243,14 @@ def _compute_metric_scores(
             results = _combine_estimator_results(
                 individual_results,
                 estimator_names=self._parent.reports_.keys(),
-                favorability=metric_kwargs.get("favorability", False),
+                favorability=favorability,
                 data_source=data_source,
             )
         else:  # "CrossValidationReport"
             results = _combine_cross_validation_results(
                 individual_results,
                 estimator_names=self._parent.reports_.keys(),
-                favorability=metric_kwargs.get("favorability", False),
+                favorability=favorability,
                 aggregate=aggregate,
             )
```
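The kwarg-routing pattern in the hunk above can be isolated into a small, self-contained sketch: display-only options are popped out of the metric kwargs so they can be forwarded to `.frame()`, while `favorability` is resolved from whichever dict holds it. The function name below is hypothetical; the logic mirrors the diff under the assumption that kwargs are plain dicts:

```python
def split_frame_kwargs(kwargs: dict, is_cv_report: bool) -> tuple[dict, bool]:
    """Pop display-only keys out of `kwargs` (mutated in place).

    Returns the kwargs destined for `.frame()` and the resolved
    `favorability` flag, wherever it was supplied.
    """
    frame_kwargs: dict = {}
    if not is_cv_report:
        # For estimator reports these keys belong to `.frame()`.
        for key in ("favorability", "flat_index"):
            if key in kwargs:
                frame_kwargs[key] = kwargs.pop(key)
    # `favorability` may live in either dict depending on the report type.
    if "favorability" in kwargs:
        favorability = kwargs["favorability"]
    elif "favorability" in frame_kwargs:
        favorability = frame_kwargs["favorability"]
    else:
        favorability = False
    return frame_kwargs, bool(favorability)


est_kwargs = {"data_source": "test", "favorability": True, "flat_index": True}
frame_kwargs, fav = split_frame_kwargs(est_kwargs, is_cv_report=False)
# Estimator case: display keys moved out, only compute kwargs remain.

cv_kwargs = {"data_source": "test", "favorability": True}
cv_frame_kwargs, cv_fav = split_frame_kwargs(cv_kwargs, is_cv_report=True)
# CV case: keys stay in `kwargs` (they are passed to `summarize()` there).
```

This keeps the FIXME(#1837) asymmetry between report types in one place rather than scattering `if is_cv_report` checks through the call sites.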
skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -168,7 +168,7 @@ class is set to the one provided when creating the report. If `None`,
         results.index = results.index.str.replace(
             r"\((.*)\)$", r"\1", regex=True
         )
-        return MetricsSummaryDisplay(summarize_data=results)
+        return MetricsSummaryDisplay(data=results, report_type="cross-validation")

     def _compute_metric_scores(
         self,
@@ -196,8 +196,15 @@ def _compute_metric_scores(
             )
         )

+        frame_kwargs = {}
+        for key in ("favorability", "flat_index"):
+            if key in metric_kwargs:
+                frame_kwargs[key] = metric_kwargs.pop(key)
+
         results = [
-            result.frame() if report_metric_name == "summarize" else result
+            result.frame(**frame_kwargs)
+            if report_metric_name == "summarize"
+            else result
             for result in track(
                 parallel(
                     delayed(getattr(report.metrics, report_metric_name))(
@@ -220,7 +227,7 @@ def _compute_metric_scores(
         # Pop the favorability column if it exists, to:
         # - not use it in the aggregate operation
         # - later to only report a single column and not by split columns
-        if metric_kwargs.get("favorability", False):
+        if frame_kwargs.get("favorability", False):
             favorability = results.pop("Favorability").iloc[:, 0]
         else:
             favorability = None
```
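The "pop Favorability before aggregating" step in the last hunk can be sketched in plain pandas. This is a simplified assumption about the data layout (the real code uses `.iloc[:, 0]` because its columns carry extra structure); the idea is that the non-numeric column is set aside so split-wise aggregation ignores it, then re-attached as a single column:

```python
import pandas as pd

# Hypothetical per-split metric scores, plus a non-numeric Favorability column.
results = pd.DataFrame(
    {
        "Split #0": [0.91, 0.82],
        "Split #1": [0.89, 0.84],
        "Favorability": ["higher is better"] * 2,
    },
    index=["precision", "recall"],
)

# `pop` removes the column in place and returns it as a Series,
# so the aggregation below only sees numeric split columns.
favorability = results.pop("Favorability")
aggregated = results.mean(axis=1).to_frame("mean")

# Re-attach a single Favorability column (index-aligned), rather than
# one copy per split.
aggregated["Favorability"] = favorability
```

Popping first avoids both aggregating over a string column and duplicating the favorability marker across split columns.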
