diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index 8acaa65031..d73a4a53ea 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -165,7 +165,7 @@ # for example getting the report metrics for the first split only: # %% -cv_report.estimator_reports_[0].metrics.summarize().frame() +cv_report.reports_[0].metrics.summarize().frame() # %% # .. seealso:: diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py index 83fc2e4a16..dd1f900b10 100644 --- a/examples/use_cases/plot_employee_salaries.py +++ b/examples/use_cases/plot_employee_salaries.py @@ -165,7 +165,7 @@ # :class:`skore.EstimatorReport` for each split. # %% -hgbt_split_1 = hgbt_model_report.estimator_reports_[0] +hgbt_split_1 = hgbt_model_report.reports_[0] hgbt_split_1.metrics.summarize(indicator_favorability=True).frame() # %% diff --git a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py index eaa8b4de56..8abd545cbb 100644 --- a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py @@ -1272,7 +1272,7 @@ def _get_display( else: for report_name, report in self._parent.reports_.items(): - for split, estimator_report in enumerate(report.estimator_reports_): + for split, estimator_report in enumerate(report.reports_): report_X, report_y, _ = ( estimator_report.metrics._get_X_y_and_data_source_hash( data_source=data_source, @@ -1320,7 +1320,7 @@ def _get_display( estimators=[ estimator_report.estimator_ for report in self._parent.reports_.values() - for estimator_report in report.estimator_reports_ + for estimator_report in report.reports_ ], ml_task=self._parent._ml_task, data_source=data_source, diff --git a/skore/src/skore/_sklearn/_cross_validation/feature_importance_accessor.py 
b/skore/src/skore/_sklearn/_cross_validation/feature_importance_accessor.py index 92ef345411..73e79b0a1d 100644 --- a/skore/src/skore/_sklearn/_cross_validation/feature_importance_accessor.py +++ b/skore/src/skore/_sklearn/_cross_validation/feature_importance_accessor.py @@ -50,7 +50,7 @@ def coefficients(self) -> FeatureImportanceDisplay: for split, df in enumerate( list( report.feature_importance.coefficients().frame() - for report in self._parent.estimator_reports_ + for report in self._parent.reports_ ) ) }, diff --git a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py index 34f1c7d30e..e2949c8a8e 100644 --- a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py @@ -218,7 +218,7 @@ def _compute_metric_scores( progress = self._parent._progress_info["current_progress"] main_task = self._parent._progress_info["current_task"] - total_estimators = len(self._parent.estimator_reports_) + total_estimators = len(self._parent.reports_) progress.update(main_task, total=total_estimators) if cache_key in self._parent._cache: @@ -233,7 +233,7 @@ def _compute_metric_scores( delayed(getattr(report.metrics, report_metric_name))( data_source=data_source, X=X, y=y, **metric_kwargs ) - for report in self._parent.estimator_reports_ + for report in self._parent.reports_ ) results = [] for result in generator: @@ -315,10 +315,10 @@ def timings( timings: pd.DataFrame = pd.concat( [ pd.Series(report.metrics.timings()) - for report in self._parent.estimator_reports_ + for report in self._parent.reports_ ], axis=1, - keys=[f"Split #{i}" for i in range(len(self._parent.estimator_reports_))], + keys=[f"Split #{i}" for i in range(len(self._parent.reports_))], ) if aggregate: if isinstance(aggregate, str): @@ -1128,7 +1128,7 @@ def _get_display( assert self._parent._progress_info is not None, "Progress info not set" progress = 
self._parent._progress_info["current_progress"] main_task = self._parent._progress_info["current_task"] - total_estimators = len(self._parent.estimator_reports_) + total_estimators = len(self._parent.reports_) progress.update(main_task, total=total_estimators) if cache_key in self._parent._cache: @@ -1136,7 +1136,7 @@ def _get_display( else: y_true: list[YPlotData] = [] y_pred: list[YPlotData] = [] - for report_idx, report in enumerate(self._parent.estimator_reports_): + for report_idx, report in enumerate(self._parent.reports_): if data_source != "X_y": # only retrieve data stored in the reports when we don't want to # use an external common X and y @@ -1178,7 +1178,7 @@ def _get_display( y_pred=y_pred, report_type="cross-validation", estimators=[ - report.estimator_ for report in self._parent.estimator_reports_ + report.estimator_ for report in self._parent.reports_ ], ml_task=self._parent._ml_task, data_source=data_source, diff --git a/skore/src/skore/_sklearn/_cross_validation/report.py b/skore/src/skore/_sklearn/_cross_validation/report.py index ecd730d9ba..b34859d2ce 100644 --- a/skore/src/skore/_sklearn/_cross_validation/report.py +++ b/skore/src/skore/_sklearn/_cross_validation/report.py @@ -120,7 +120,7 @@ class CrossValidationReport(_BaseReport, DirNamesMixin): estimator_name_ : str The name of the estimator. - estimator_reports_ : list of EstimatorReport + reports_ : list of EstimatorReport The estimator reports for each split. 
See Also @@ -170,16 +170,14 @@ def __init__( self._split_indices = tuple(self._splitter.split(self._X, self._y)) self.n_jobs = n_jobs - self.estimator_reports_: list[EstimatorReport] = self._fit_estimator_reports() + self.reports_: list[EstimatorReport] = self._fit_estimator_reports() self._rng = np.random.default_rng(time.time_ns()) self._hash = self._rng.integers( low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max ) self._cache: dict[tuple[Any, ...], Any] = {} - self._ml_task = _find_ml_task( - y, estimator=self.estimator_reports_[0]._estimator - ) + self._ml_task = _find_ml_task(y, estimator=self.reports_[0]._estimator) @progress_decorator( description=lambda self: ( @@ -293,7 +291,7 @@ def clear_cache(self) -> None: >>> report._cache {} """ - for report in self.estimator_reports_: + for report in self.reports_: report.clear_cache() self._cache = {} @@ -336,10 +334,10 @@ def cache_predictions( progress = self._progress_info["current_progress"] main_task = self._progress_info["current_task"] - total_estimators = len(self.estimator_reports_) + total_estimators = len(self.reports_) progress.update(main_task, total=total_estimators) - for split_idx, estimator_report in enumerate(self.estimator_reports_, 1): + for split_idx, estimator_report in enumerate(self.reports_, 1): # Share the parent's progress bar with child report estimator_report._progress_info = { "current_progress": progress, @@ -439,7 +437,7 @@ def get_predictions( X=X, pos_label=pos_label, ) - for report in self.estimator_reports_ + for report in self.reports_ ] @property @@ -489,6 +487,17 @@ def pos_label(self, value: PositiveLabel | None) -> None: f"Call the constructor of {self.__class__.__name__} to create a new report." ) + @property + def estimator_reports_(self) -> list[EstimatorReport]: + """ + The estimator reports for each split. + + .. deprecated:: + The ``report.estimator_reports_`` property will be removed in favor of + ``report.reports_`` in the near future.
+ """ + return self.reports_ + #################################################################################### # Methods related to the help and repr #################################################################################### diff --git a/skore/src/skore/_utils/_accessor.py b/skore/src/skore/_utils/_accessor.py index ebf374b232..7c6de4f628 100644 --- a/skore/src/skore/_utils/_accessor.py +++ b/skore/src/skore/_utils/_accessor.py @@ -140,7 +140,7 @@ def _check_estimator_report_has_method( method_name: str, ) -> Callable: def check(accessor: Any) -> bool: - estimator_report = accessor._parent.estimator_reports_[0] + estimator_report = accessor._parent.reports_[0] if not hasattr(estimator_report, accessor_name): raise AttributeError( @@ -163,7 +163,7 @@ def check(accessor: Any) -> bool: def _check_cross_validation_sub_estimator_has_coef() -> Callable: def check(accessor: Any) -> bool: """Check if the underlying estimator has a `coef_` attribute.""" - return _check_has_coef(accessor._parent.estimator_reports_[0].estimator) + return _check_has_coef(accessor._parent.reports_[0].estimator) return check @@ -183,7 +183,7 @@ def check(accessor: Any) -> bool: for parent_report in parent.reports_.values(): if parent._reports_type == "CrossValidationReport": parent_report = cast(CrossValidationReport, parent_report) - parent_estimators.append(parent_report.estimator_reports_[0].estimator_) + parent_estimators.append(parent_report.reports_[0].estimator_) elif parent._reports_type == "EstimatorReport": parent_estimators.append(parent_report.estimator_) else: diff --git a/skore/src/skore/_utils/_progress_bar.py b/skore/src/skore/_utils/_progress_bar.py index 19a467875c..8a2b8fbf72 100644 --- a/skore/src/skore/_utils/_progress_bar.py +++ b/skore/src/skore/_utils/_progress_bar.py @@ -38,17 +38,23 @@ def progress_decorator( def decorator(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: + # Avoid circular import + 
from skore import ComparisonReport + + # The object to which a `rich.Progress` instance will be attached. + # Expected to be a Report + # (EstimatorReport | CrossValidationReport | ComparisonReport) self_obj: Any = args[0] + # If the decorated method is in an Accessor (e.g. MetricsAccessor), + # then make sure `self_obj` is the Report, not the Accessor. if hasattr(self_obj, "_parent"): # self_obj is an accessor self_obj = self_obj._parent - desc = description(self_obj) if callable(description) else description - created_progress = False - if getattr(self_obj, "_progress_info", None) is not None: + if self_obj._progress_info is not None: progress = self_obj._progress_info["current_progress"] else: progress = Progress( @@ -67,15 +73,20 @@ def wrapper(*args: Any, **kwargs: Any) -> T: progress.start() created_progress = True - # assigning progress to child reports + # Make child reports share their parent's Progress instance + # so that there is only one Progress instance at any given point reports_to_cleanup: list[Any] = [] - if hasattr(self_obj, "reports_"): + if isinstance(self_obj, ComparisonReport): for report in self_obj.reports_.values(): - if hasattr(report, "_progress_info"): - report._progress_info = {"current_progress": progress} - reports_to_cleanup.append(report) - - task = progress.add_task(desc, total=None) + report._progress_info = {"current_progress": progress} + reports_to_cleanup.append(report) + + task = progress.add_task( + description=( + description(self_obj) if callable(description) else description + ), + total=None, + ) self_obj._progress_info = { "current_progress": progress, "current_task": task, diff --git a/skore/tests/unit/displays/metrics_summary/test_cross_validation.py b/skore/tests/unit/displays/metrics_summary/test_cross_validation.py index bbdbffd248..92bc1bdf9d 100644 --- a/skore/tests/unit/displays/metrics_summary/test_cross_validation.py +++ b/skore/tests/unit/displays/metrics_summary/test_cross_validation.py @@ -56,7 +56,7 @@ 
def test_data_source_external( for split_idx in range(splitter): # check that it is equivalent to call the individual estimator report report_result = ( - report.estimator_reports_[split_idx] + report.reports_[split_idx] .metrics.summarize(data_source="X_y", X=X, y=y) .frame() ) @@ -389,7 +389,7 @@ def test_scorer(linear_regression_data): est_rep.y_test, est_rep.estimator_.predict(est_rep.X_test) ), ] - for est_rep in report.estimator_reports_ + for est_rep in report.reports_ ] np.testing.assert_allclose( result.to_numpy(), diff --git a/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py b/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py index 87d60c22c4..d85cda83d9 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py @@ -28,7 +28,7 @@ def test_binary_classification( pos_label = 1 n_reports = len(report.reports_) - n_splits = len(list(report.reports_.values())[0].estimator_reports_) + n_splits = len(list(report.reports_.values())[0].reports_) display.plot() assert isinstance(display.lines_, list) @@ -76,7 +76,7 @@ def test_multiclass_classification( labels = display.precision_recall["label"].cat.categories n_reports = len(report.reports_) - n_splits = len(list(report.reports_.values())[0].estimator_reports_) + n_splits = len(list(report.reports_.values())[0].reports_) display.plot() assert isinstance(display.lines_, list) diff --git a/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py b/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py index 7b8c0b7b00..141abd4de7 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py @@ -32,7 +32,7 @@ def test_binary_classification( display.plot() - pos_label = 
report.estimator_reports_[0].estimator_.classes_[1] + pos_label = report.reports_[0].estimator_.classes_[1] assert hasattr(display, "ax_") assert hasattr(display, "figure_") @@ -88,7 +88,7 @@ def test_multiclass_classification( display.plot() - class_labels = report.estimator_reports_[0].estimator_.classes_ + class_labels = report.reports_[0].estimator_.classes_ assert isinstance(display.lines_, list) assert len(display.lines_) == len(class_labels) * cv diff --git a/skore/tests/unit/displays/roc_curve/test_cross_validation.py b/skore/tests/unit/displays/roc_curve/test_cross_validation.py index fa8c52c4c9..83f052b95c 100644 --- a/skore/tests/unit/displays/roc_curve/test_cross_validation.py +++ b/skore/tests/unit/displays/roc_curve/test_cross_validation.py @@ -27,7 +27,7 @@ def test_binary_classification( assert isinstance(display, RocCurveDisplay) check_display_data(display) - pos_label = report.estimator_reports_[0].estimator_.classes_[1] + pos_label = report.reports_[0].estimator_.classes_[1] assert ( list(display.roc_curve["label"].unique()) == list(display.roc_auc["label"].unique()) @@ -88,7 +88,7 @@ def test_multiclass_classification( # check the structure of the attributes check_display_data(display) - class_labels = report.estimator_reports_[0].estimator_.classes_ + class_labels = report.reports_[0].estimator_.classes_ assert ( list(display.roc_curve["label"].unique()) == list(display.roc_auc["label"].unique()) diff --git a/skore/tests/unit/reports/comparison/cross_validation/test_report.py b/skore/tests/unit/reports/comparison/cross_validation/test_report.py index 777087c53b..e2dcb5ec28 100644 --- a/skore/tests/unit/reports/comparison/cross_validation/test_report.py +++ b/skore/tests/unit/reports/comparison/cross_validation/test_report.py @@ -62,3 +62,9 @@ def test_get_predictions( assert len(predictions) == len(report.reports_) for i, cv_report in enumerate(report.reports_.values()): assert len(predictions[i]) == cv_report._splitter.n_splits + + +def 
test_estimator_reports_(cross_validation_reports_regression): + cv, _ = cross_validation_reports_regression + + assert id(cv.estimator_reports_) == id(cv.reports_) diff --git a/skore/tests/unit/reports/cross_validation/test_report.py b/skore/tests/unit/reports/cross_validation/test_report.py index ad84f38eb9..4a79346760 100644 --- a/skore/tests/unit/reports/cross_validation/test_report.py +++ b/skore/tests/unit/reports/cross_validation/test_report.py @@ -74,13 +74,13 @@ def test_attributes(fixture_name, request, cv, n_jobs): estimator, X, y = request.getfixturevalue(fixture_name) report = CrossValidationReport(estimator, X, y, splitter=cv, n_jobs=n_jobs) assert isinstance(report, CrossValidationReport) - assert isinstance(report.estimator_reports_, list) - for estimator_report in report.estimator_reports_: + assert isinstance(report.reports_, list) + for estimator_report in report.reports_: assert isinstance(estimator_report, EstimatorReport) assert report.X is X assert report.y is y assert report.n_jobs == n_jobs - assert len(report.estimator_reports_) == cv + assert len(report.reports_) == cv if isinstance(estimator, Pipeline): assert report.estimator_name_ == estimator[-1].__class__.__name__ else: @@ -137,12 +137,12 @@ def test_cache_predictions(request, fixture_name, expected_n_keys, n_jobs): # underlying estimator reports assert report._cache == {} - for estimator_report in report.estimator_reports_: + for estimator_report in report.reports_: assert len(estimator_report._cache) == expected_n_keys report.clear_cache() assert report._cache == {} - for estimator_report in report.estimator_reports_: + for estimator_report in report.reports_: assert estimator_report._cache == {} @@ -174,9 +174,9 @@ def test_get_predictions( assert len(predictions) == 2 for split_idx, split_predictions in enumerate(predictions): if data_source == "train": - expected_shape = report.estimator_reports_[split_idx].y_train.shape + expected_shape = report.reports_[split_idx].y_train.shape 
elif data_source == "test": - expected_shape = report.estimator_reports_[split_idx].y_test.shape + expected_shape = report.reports_[split_idx].y_test.shape else: # data_source == "X_y" expected_shape = (X.shape[0],) assert split_predictions.shape == expected_shape diff --git a/skore/tests/unit/utils/test_accessors.py b/skore/tests/unit/utils/test_accessors.py index b807f7553c..f9ac95565a 100644 --- a/skore/tests/unit/utils/test_accessors.py +++ b/skore/tests/unit/utils/test_accessors.py @@ -130,7 +130,7 @@ def __init__(self, estimator): class MockParent: def __init__(self, estimator): - self.estimator_reports_ = [MockReport(estimator)] + self.reports_ = [MockReport(estimator)] class MockAccessor: def __init__(self, parent): diff --git a/skore/tests/unit/utils/test_progress_bar.py b/skore/tests/unit/utils/test_progress_bar.py index 22a1d11626..8905beffb8 100644 --- a/skore/tests/unit/utils/test_progress_bar.py +++ b/skore/tests/unit/utils/test_progress_bar.py @@ -125,41 +125,3 @@ def run(self): # Verify progress bar was cleaned up assert task._progress_info is None - - -def test_child_report_cleanup(): - """Ensure that child reports in reports_ get progress assigned and then cleaned - up.""" - - class Child: - def __init__(self): - self._progress_info = None - self.called = False - - @progress_decorator("Child Process") - def process(self): - self.called = True - return "child_done" - - class Parent: - def __init__(self): - self._progress_info = None - self.reports_ = {"child1": Child(), "child2": Child()} - - @progress_decorator("Parent Process") - def run(self): - results = [] - for rpt in self.reports_.values(): - results.append(rpt.process()) - return results - - parent = Parent() - results = parent.run() - - assert results == ["child_done", "child_done"] - assert all(rp.called for rp in parent.reports_.values()) - # Verify that progress attributes are cleaned for each child report - for rp in parent.reports_.values(): - assert rp._progress_info is None - # Also 
verify if parent reports are cleaned up as well - assert parent._progress_info is None diff --git a/sphinx/user_guide/reporters.rst b/sphinx/user_guide/reporters.rst index 064660f9d5..4d2c1f275b 100644 --- a/sphinx/user_guide/reporters.rst +++ b/sphinx/user_guide/reporters.rst @@ -120,7 +120,7 @@ difference is in the initialization. It accepts an estimator, a dataset (i.e. `X `y`) and a cross-validation strategy. Internally, the dataset is split according to the cross-validation strategy and an estimator report is created for each split. Therefore, a :class:`CrossValidationReport` is a collection of :class:`EstimatorReport` instances, -available through the :obj:`CrossValidationReport.estimator_reports_` attribute. +available through the :obj:`CrossValidationReport.reports_` attribute. For metrics and displays, the same API is exposed with an extra parameter, `aggregate`, to aggregate the metrics across the splits.