
Commit 1e06cce

[feat] Support statistics print by adding results manager object (#334)
* [feat] Support statistics print by adding results manager object
* [refactor] Make SearchResults extract run_history at __init__. Since the search results should not be kept around eternally, this class now takes run_history in __init__ so that the extraction is called implicitly inside. From this change, calling the extraction from outside is no longer recommended; it is still possible, but self.clear() will be called first to prevent the state from being mixed up (a minimal sketch of this pattern follows the list).
* [fix] Separate those changes into PR#336
* [fix] Fix so that test_loss includes all the metrics
* [enhance] Strengthen the tests for sprint and SearchResults
* [fix] Fix an issue in the documentation
* [enhance] Increase the coverage
* [refactor] Separate the test for results_manager to organize the structure
* [test] Add the test for get_incumbent_Result
* [test] Remove the previous test_get_incumbent and see the coverage
* [fix] [test] Fix reversion of metric and strengthen the test cases
* [fix] Fix flake8 issues and increase coverage
* [fix] Address Ravin's comments
* [enhance] Increase the coverage
* [fix] Fix a flake8 issue
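A minimal sketch of the pattern described in the [refactor] bullet: SearchResults receives run_history in __init__ so the extraction is triggered implicitly, and an external call to the extraction clears previously extracted state via self.clear(). Only clear() and the __init__-time extraction are taken from the commit message; every other name here is an assumption, the real class lives in autoPyTorch/api/results_manager.py.

from smac.runhistory.runhistory import RunHistory


class SearchResultsSketch:
    def __init__(self, run_history: RunHistory) -> None:
        self._extracted = False
        # extraction happens implicitly at construction time
        self.extract_results_from_run_history(run_history)

    def clear(self) -> None:
        # reset any previously extracted state so two run histories never mix
        self._extracted = False

    def extract_results_from_run_history(self, run_history: RunHistory) -> None:
        if self._extracted:
            # external re-extraction: clear first, as the commit message describes
            self.clear()
        # ... populate per-run scores, statuses and configs from run_history ...
        self._extracted = True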
1 parent 2d2f6d1 commit 1e06cce

15 files changed: +2505 -118 lines changed
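For orientation before the diff: a hedged usage sketch of the interface this commit adds. Only sprint_statistics(), get_search_results(), get_incumbent_results() and the run_history property come from this change; the task class, dataset, and search arguments below are illustrative assumptions.

import sklearn.datasets
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask

# Illustrative data and search budget (assumptions, not part of this commit).
X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, random_state=1)

api = TabularClassificationTask(seed=1)
api.search(
    X_train=X_train, y_train=y_train,
    X_test=X_test, y_test=y_test,
    optimize_metric='accuracy',
    total_walltime_limit=300,
)

# New in this commit: statistics and results are served by a ResultsManager.
print(api.sprint_statistics())                 # formatted SMAC run statistics
search_results = api.get_search_results()      # SearchResults interface
config, additional_info = api.get_incumbent_results()
print(api.run_history)                         # now a property backed by the manager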

autoPyTorch/api/base_task.py

Lines changed: 72 additions & 26 deletions
@@ -29,6 +29,7 @@
 from smac.stats.stats import Stats
 from smac.tae import StatusType

+from autoPyTorch.api.results_manager import ResultsManager, SearchResults
 from autoPyTorch.automl_common.common.utils.backend import Backend, create
 from autoPyTorch.constants import (
     REGRESSION_TASKS,
@@ -192,12 +193,13 @@ def __init__(
         self.search_space: Optional[ConfigurationSpace] = None
         self._dataset_requirements: Optional[List[FitRequirement]] = None
         self._metric: Optional[autoPyTorchMetric] = None
+        self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
         self._logger: Optional[PicklableClientLogger] = None
-        self.run_history: RunHistory = RunHistory()
-        self.trajectory: Optional[List] = None
         self.dataset_name: Optional[str] = None
         self.cv_models_: Dict = {}

+        self._results_manager = ResultsManager()
+
         # By default try to use the TCP logging port or get a new port
         self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

@@ -240,6 +242,18 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
         """
         raise NotImplementedError

+    @property
+    def run_history(self) -> RunHistory:
+        return self._results_manager.run_history
+
+    @property
+    def ensemble_performance_history(self) -> List[Dict[str, Any]]:
+        return self._results_manager.ensemble_performance_history
+
+    @property
+    def trajectory(self) -> Optional[List]:
+        return self._results_manager.trajectory
+
     def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None:
         """
         Check whether arguments are valid and
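The three properties above delegate to the new ResultsManager. Its code is part of this commit (autoPyTorch/api/results_manager.py) but not shown in this excerpt; a minimal sketch of the state it owns, inferred from the properties, could look like the following, where everything beyond the three attribute names is an assumption.

from typing import Any, Dict, List, Optional

from smac.runhistory.runhistory import RunHistory


class ResultsManager:
    def __init__(self) -> None:
        # single owner of the artefacts produced by a SMAC run
        self.run_history: RunHistory = RunHistory()
        self.ensemble_performance_history: List[Dict[str, Any]] = []
        self.trajectory: Optional[List] = None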
@@ -883,6 +897,12 @@ def _search(

         self.pipeline_options['optimize_metric'] = optimize_metric

+        if all_supported_metrics:
+            self._scoring_functions = get_metrics(dataset_properties=dataset_properties,
+                                                  all_supported_metrics=True)
+        else:
+            self._scoring_functions = [self._metric]
+
         self.search_space = self.get_search_space(dataset)

         # Incorporate budget to pipeline config
@@ -1037,12 +1057,14 @@ def _search(
             pynisher_context=self._multiprocessing_context,
         )
         try:
-            run_history, self.trajectory, budget_type = \
+            run_history, self._results_manager.trajectory, budget_type = \
                 _proc_smac.run_smbo(func=tae_func)
             self.run_history.update(run_history, DataOrigin.INTERNAL)
             trajectory_filename = os.path.join(
                 self._backend.get_smac_output_directory_for_run(self.seed),
                 'trajectory.json')
+
+            assert self.trajectory is not None  # mypy check
             saveable_trajectory = \
                 [list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
                  for entry in self.trajectory]
@@ -1059,7 +1081,7 @@ def _search(
         self._logger.info("Starting Shutdown")

         if proc_ensemble is not None:
-            self.ensemble_performance_history = list(proc_ensemble.history)
+            self._results_manager.ensemble_performance_history = list(proc_ensemble.history)

             if len(proc_ensemble.futures) > 0:
                 # Also add ensemble runs that did not finish within smac time
@@ -1068,7 +1090,7 @@ def _search(
                 result = proc_ensemble.futures.pop().result()
                 if result:
                     ensemble_history, _, _, _ = result
-                    self.ensemble_performance_history.extend(ensemble_history)
+                    self._results_manager.ensemble_performance_history.extend(ensemble_history)
                 self._logger.info("Ensemble script finished, continue shutdown.")

             # save the ensemble performance history file
@@ -1356,28 +1378,12 @@ def get_incumbent_results(
                 The incumbent configuration
             Dict[str, Union[int, str, float]]:
                 Additional information about the run of the incumbent configuration.
-
         """
-        assert self.run_history is not None, "No Run History found, search has not been called."
-        if self.run_history.empty():
-            raise ValueError("Run History is empty. Something went wrong, "
-                             "smac was not able to fit any model?")
-
-        run_history_data = self.run_history.data
-        if not include_traditional:
-            # traditional classifiers have trainer_configuration in their additional info
-            run_history_data = dict(
-                filter(lambda elem: elem[1].status == StatusType.SUCCESS and elem[1].
-                       additional_info is not None and elem[1].
-                       additional_info['configuration_origin'] != 'traditional',
-                       run_history_data.items()))
-        run_history_data = dict(
-            filter(lambda elem: 'SUCCESS' in str(elem[1].status), run_history_data.items()))
-        sorted_runvalue_by_cost = sorted(run_history_data.items(), key=lambda item: item[1].cost)
-        incumbent_run_key, incumbent_run_value = sorted_runvalue_by_cost[0]
-        incumbent_config = self.run_history.ids_config[incumbent_run_key.config_id]
-        incumbent_results = incumbent_run_value.additional_info
-        return incumbent_config, incumbent_results
+
+        if self._metric is None:
+            raise RuntimeError("`search_results` is only available after a search has finished.")
+
+        return self._results_manager.get_incumbent_results(metric=self._metric, include_traditional=include_traditional)

     def get_models_with_weights(self) -> List:
         if self.models_ is None or len(self.models_) == 0 or \
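The incumbent-selection logic removed above (keep only SUCCESS runs, optionally drop traditional classifiers, pick the lowest-cost entry) now lives in ResultsManager.get_incumbent_results. A hedged re-statement of that logic as a standalone helper; the signature is illustrative, and the real method additionally receives the optimisation metric.

from typing import Dict, Tuple, Union

from ConfigSpace.configuration_space import Configuration
from smac.runhistory.runhistory import RunHistory
from smac.tae import StatusType


def _select_incumbent(run_history: RunHistory,
                      include_traditional: bool = False
                      ) -> Tuple[Configuration, Dict[str, Union[int, str, float]]]:
    if run_history.empty():
        raise ValueError("Run History is empty. Something went wrong, "
                         "smac was not able to fit any model?")

    data = run_history.data
    if not include_traditional:
        # traditional classifiers are tagged via their configuration_origin
        data = {
            key: val for key, val in data.items()
            if val.status == StatusType.SUCCESS
            and val.additional_info is not None
            and val.additional_info.get('configuration_origin') != 'traditional'
        }
    data = {key: val for key, val in data.items() if 'SUCCESS' in str(val.status)}

    # SMAC minimises cost, so the incumbent is the lowest-cost successful run
    run_key, run_value = min(data.items(), key=lambda item: item[1].cost)
    return run_history.ids_config[run_key.config_id], run_value.additional_info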
@@ -1417,3 +1423,43 @@ def _print_debug_info_to_log(self) -> None:
         self._logger.debug(' multiprocessing_context: %s', str(self._multiprocessing_context))
         for key, value in vars(self).items():
             self._logger.debug(f"\t{key}->{value}")
+
+    def get_search_results(self) -> SearchResults:
+        """
+        Get the interface to obtain the search results easily.
+        """
+        if self._scoring_functions is None or self._metric is None:
+            raise RuntimeError("`search_results` is only available after a search has finished.")
+
+        return self._results_manager.get_search_results(
+            metric=self._metric,
+            scoring_functions=self._scoring_functions
+        )
+
+    def sprint_statistics(self) -> str:
+        """
+        Prints statistics about the SMAC search.
+
+        These statistics include:
+
+        1. Optimisation Metric
+        2. Best Optimisation score achieved by individual pipelines
+        3. Total number of target algorithm runs
+        4. Total number of successful target algorithm runs
+        5. Total number of crashed target algorithm runs
+        6. Total number of target algorithm runs that exceeded the time limit
+        7. Total number of target algorithm runs that exceeded the memory limit
+
+        Returns:
+            (str):
+                Formatted string with statistics
+        """
+        if self._scoring_functions is None or self._metric is None:
+            raise RuntimeError("`search_results` is only available after a search has finished.")
+
+        assert self.dataset_name is not None  # mypy check
+        return self._results_manager.sprint_statistics(
+            dataset_name=self.dataset_name,
+            scoring_functions=self._scoring_functions,
+            metric=self._metric
+        )
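sprint_statistics itself is implemented inside the ResultsManager (not shown in this excerpt). A hedged sketch of how such a string could be assembled from the run history, following the items enumerated in the docstring above; the helper name, exact wording, and layout of the real output are assumptions.

import io

from smac.runhistory.runhistory import RunHistory
from smac.tae import StatusType


def sprint_statistics_sketch(dataset_name: str, metric_name: str,
                             best_score: float, run_history: RunHistory) -> str:
    # count runs per SMAC status, as listed in items 3-7 of the docstring
    statuses = [run_value.status for run_value in run_history.data.values()]
    sio = io.StringIO()
    sio.write("autoPyTorch results:\n")
    sio.write(f"\tDataset name: {dataset_name}\n")
    sio.write(f"\tOptimisation metric: {metric_name}\n")
    sio.write(f"\tBest optimisation score: {best_score}\n")
    sio.write(f"\tNumber of target algorithm runs: {len(statuses)}\n")
    sio.write(f"\tNumber of successful target algorithm runs: "
              f"{statuses.count(StatusType.SUCCESS)}\n")
    sio.write(f"\tNumber of crashed target algorithm runs: "
              f"{statuses.count(StatusType.CRASHED)}\n")
    sio.write(f"\tNumber of target algorithm runs that exceeded the time limit: "
              f"{statuses.count(StatusType.TIMEOUT)}\n")
    sio.write(f"\tNumber of target algorithm runs that exceeded the memory limit: "
              f"{statuses.count(StatusType.MEMOUT)}\n")
    return sio.getvalue()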
