29
29
from smac .stats .stats import Stats
30
30
from smac .tae import StatusType
31
31
32
+ from autoPyTorch .api .results_manager import ResultsManager , SearchResults
32
33
from autoPyTorch .automl_common .common .utils .backend import Backend , create
33
34
from autoPyTorch .constants import (
34
35
REGRESSION_TASKS ,
@@ -192,12 +193,13 @@ def __init__(
192
193
self .search_space : Optional [ConfigurationSpace ] = None
193
194
self ._dataset_requirements : Optional [List [FitRequirement ]] = None
194
195
self ._metric : Optional [autoPyTorchMetric ] = None
196
+ self ._scoring_functions : Optional [List [autoPyTorchMetric ]] = None
195
197
self ._logger : Optional [PicklableClientLogger ] = None
196
- self .run_history : RunHistory = RunHistory ()
197
- self .trajectory : Optional [List ] = None
198
198
self .dataset_name : Optional [str ] = None
199
199
self .cv_models_ : Dict = {}
200
200
201
+ self ._results_manager = ResultsManager ()
202
+
201
203
# By default try to use the TCP logging port or get a new port
202
204
self ._logger_port = logging .handlers .DEFAULT_TCP_LOGGING_PORT
203
205
@@ -240,6 +242,18 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
240
242
"""
241
243
raise NotImplementedError
242
244
245
@property
def run_history(self) -> RunHistory:
    """SMAC ``RunHistory`` of the optimization, read through the results manager."""
    return self._results_manager.run_history
248
+
249
@property
def ensemble_performance_history(self) -> List[Dict[str, Any]]:
    """Recorded ensemble performance entries, read through the results manager."""
    return self._results_manager.ensemble_performance_history
252
+
253
@property
def trajectory(self) -> Optional[List]:
    """Incumbent trajectory produced by SMBO, read through the results manager.

    NOTE(review): presumably ``None`` until a search has populated it — confirm
    against ``ResultsManager``.
    """
    return self._results_manager.trajectory
256
+
243
257
def set_pipeline_config (self , ** pipeline_config_kwargs : Any ) -> None :
244
258
"""
245
259
Check whether arguments are valid and
@@ -883,6 +897,12 @@ def _search(
883
897
884
898
self .pipeline_options ['optimize_metric' ] = optimize_metric
885
899
900
+ if all_supported_metrics :
901
+ self ._scoring_functions = get_metrics (dataset_properties = dataset_properties ,
902
+ all_supported_metrics = True )
903
+ else :
904
+ self ._scoring_functions = [self ._metric ]
905
+
886
906
self .search_space = self .get_search_space (dataset )
887
907
888
908
# Incorporate budget to pipeline config
@@ -1037,12 +1057,14 @@ def _search(
1037
1057
pynisher_context = self ._multiprocessing_context ,
1038
1058
)
1039
1059
try :
1040
- run_history , self .trajectory , budget_type = \
1060
+ run_history , self ._results_manager . trajectory , budget_type = \
1041
1061
_proc_smac .run_smbo (func = tae_func )
1042
1062
self .run_history .update (run_history , DataOrigin .INTERNAL )
1043
1063
trajectory_filename = os .path .join (
1044
1064
self ._backend .get_smac_output_directory_for_run (self .seed ),
1045
1065
'trajectory.json' )
1066
+
1067
+ assert self .trajectory is not None # mypy check
1046
1068
saveable_trajectory = \
1047
1069
[list (entry [:2 ]) + [entry [2 ].get_dictionary ()] + list (entry [3 :])
1048
1070
for entry in self .trajectory ]
@@ -1059,7 +1081,7 @@ def _search(
1059
1081
self ._logger .info ("Starting Shutdown" )
1060
1082
1061
1083
if proc_ensemble is not None :
1062
- self .ensemble_performance_history = list (proc_ensemble .history )
1084
+ self ._results_manager . ensemble_performance_history = list (proc_ensemble .history )
1063
1085
1064
1086
if len (proc_ensemble .futures ) > 0 :
1065
1087
# Also add ensemble runs that did not finish within smac time
@@ -1068,7 +1090,7 @@ def _search(
1068
1090
result = proc_ensemble .futures .pop ().result ()
1069
1091
if result :
1070
1092
ensemble_history , _ , _ , _ = result
1071
- self .ensemble_performance_history .extend (ensemble_history )
1093
+ self ._results_manager . ensemble_performance_history .extend (ensemble_history )
1072
1094
self ._logger .info ("Ensemble script finished, continue shutdown." )
1073
1095
1074
1096
# save the ensemble performance history file
@@ -1356,28 +1378,12 @@ def get_incumbent_results(
1356
1378
The incumbent configuration
1357
1379
Dict[str, Union[int, str, float]]:
1358
1380
Additional information about the run of the incumbent configuration.
1359
-
1360
1381
"""
1361
- assert self .run_history is not None , "No Run History found, search has not been called."
1362
- if self .run_history .empty ():
1363
- raise ValueError ("Run History is empty. Something went wrong, "
1364
- "smac was not able to fit any model?" )
1365
-
1366
- run_history_data = self .run_history .data
1367
- if not include_traditional :
1368
- # traditional classifiers have trainer_configuration in their additional info
1369
- run_history_data = dict (
1370
- filter (lambda elem : elem [1 ].status == StatusType .SUCCESS and elem [1 ].
1371
- additional_info is not None and elem [1 ].
1372
- additional_info ['configuration_origin' ] != 'traditional' ,
1373
- run_history_data .items ()))
1374
- run_history_data = dict (
1375
- filter (lambda elem : 'SUCCESS' in str (elem [1 ].status ), run_history_data .items ()))
1376
- sorted_runvalue_by_cost = sorted (run_history_data .items (), key = lambda item : item [1 ].cost )
1377
- incumbent_run_key , incumbent_run_value = sorted_runvalue_by_cost [0 ]
1378
- incumbent_config = self .run_history .ids_config [incumbent_run_key .config_id ]
1379
- incumbent_results = incumbent_run_value .additional_info
1380
- return incumbent_config , incumbent_results
1382
+
1383
+ if self ._metric is None :
1384
+ raise RuntimeError ("`search_results` is only available after a search has finished." )
1385
+
1386
+ return self ._results_manager .get_incumbent_results (metric = self ._metric , include_traditional = include_traditional )
1381
1387
1382
1388
def get_models_with_weights (self ) -> List :
1383
1389
if self .models_ is None or len (self .models_ ) == 0 or \
@@ -1417,3 +1423,43 @@ def _print_debug_info_to_log(self) -> None:
1417
1423
self ._logger .debug (' multiprocessing_context: %s' , str (self ._multiprocessing_context ))
1418
1424
for key , value in vars (self ).items ():
1419
1425
self ._logger .debug (f"\t { key } ->{ value } " )
1426
+
1427
def get_search_results(self) -> SearchResults:
    """Return a convenient interface over the results of the finished search.

    Returns:
        SearchResults: results object built by the results manager.

    Raises:
        RuntimeError: if no search has finished yet (metric or scoring
            functions are still unset).
    """
    metric, scorers = self._metric, self._scoring_functions
    if metric is None or scorers is None:
        raise RuntimeError("`search_results` is only available after a search has finished.")
    return self._results_manager.get_search_results(metric=metric,
                                                    scoring_functions=scorers)
1438
+
1439
def sprint_statistics(self) -> str:
    """
    Prints statistics about the SMAC search.

    These statistics include:

    1. Optimisation Metric
    2. Best Optimisation score achieved by individual pipelines
    3. Total number of target algorithm runs
    4. Total number of successful target algorithm runs
    5. Total number of crashed target algorithm runs
    6. Total number of target algorithm runs that exceeded the time limit
    7. Total number of successful target algorithm runs that exceeded the memory limit

    Returns:
        (str):
            Formatted string with statistics

    Raises:
        RuntimeError: if called before a search has finished.
    """
    if self._scoring_functions is None or self._metric is None:
        # Fixed copy-paste from get_search_results: the error previously
        # (and misleadingly) named `search_results`.
        raise RuntimeError("`sprint_statistics` is only available after a search has finished.")

    assert self.dataset_name is not None  # mypy check
    return self._results_manager.sprint_statistics(
        dataset_name=self.dataset_name,
        scoring_functions=self._scoring_functions,
        metric=self._metric
    )
0 commit comments