Skip to content

Commit 7919802

Browse files
committed
Reg cocktails common paper modifications 2 (#417)
* remove remaining differences * Reg cocktails common paper modifications 5 (#418) * add hasttr * fix run summary
1 parent 0cd5cad commit 7919802

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
from multiprocessing.queues import Queue
3+
import os
24
from typing import Any, Dict, List, Optional, Tuple, Union
35

46
from ConfigSpace.configuration_space import Configuration
@@ -21,6 +23,7 @@
2123
)
2224
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
2325
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
26+
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
2427
from autoPyTorch.utils.common import dict_repr, subsampler
2528
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
2629

@@ -193,6 +196,25 @@ def fit_predict_and_loss(self) -> None:
193196
additional_run_info = pipeline.get_additional_run_info() if hasattr(
194197
pipeline, 'get_additional_run_info') else {}
195198

199+
# add learning curve of configurations to additional_run_info
200+
if isinstance(pipeline, TabularClassificationPipeline):
201+
if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
202+
run_summary = pipeline.named_steps['trainer'].run_summary
203+
split_types = ['train', 'val', 'test']
204+
run_summary_dict = dict(
205+
run_summary={},
206+
budget=self.budget,
207+
seed=self.seed,
208+
config_id=self.configuration.config_id,
209+
num_run=self.num_run
210+
)
211+
for split_type in split_types:
212+
run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(f'{split_type}_loss', None)
213+
run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(f'{split_type}_metrics', None)
214+
self.logger.debug(f"run_summary_dict {json.dumps(run_summary_dict)}")
215+
with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
216+
file.write(f"{json.dumps(run_summary_dict)}\n")
217+
196218
status = StatusType.SUCCESS
197219

198220
self.logger.debug("In train evaluator.fit_predict_and_loss, num_run: {} loss:{},"

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,13 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac
351351
if cyclic_lr_name in available_schedulers:
352352
# disable snapshot ensembles and stochastic weight averaging
353353
snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble')
354-
if True in snapshot_ensemble_hyperparameter.choices:
354+
if hasattr(snapshot_ensemble_hyperparameter, 'choices') and True in snapshot_ensemble_hyperparameter.choices:
355355
cs.add_forbidden_clause(ForbiddenAndConjunction(
356356
ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True),
357357
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
358358
))
359359
swa_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_stochastic_weight_averaging')
360-
if True in swa_hyperparameter.choices:
360+
if hasattr(swa_hyperparameter, 'choices') and True in swa_hyperparameter.choices:
361361
cs.add_forbidden_clause(ForbiddenAndConjunction(
362362
ForbiddenEqualsClause(swa_hyperparameter, True),
363363
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def get_hyperparameter_search_space(self,
8686
"choices in {} got {}".format(self.__class__.__name__,
8787
available_preprocessors,
8888
choice_hyperparameter.value_range))
89-
if len(choice_hyperparameter) == 0:
89+
if len(categorical_columns) == 0:
9090
assert len(choice_hyperparameter.value_range) == 1
9191
assert 'NoEncoder' in choice_hyperparameter.value_range, \
9292
"Provided {} in choices, however, the dataset " \

0 commit comments

Comments
 (0)