diff --git a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py index f50997181..0865e2a59 100644 --- a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py @@ -196,22 +196,35 @@ def get_hyperparameter_search_space( add_hyperparameter(cs, epsilon, UniformFloatHyperparameter) add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter) + snapshot_ensemble_flag = False + if any(use_snapshot_ensemble.value_range): + snapshot_ensemble_flag = True + use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter) - se_lastk = get_hyperparameter(se_lastk, Constant) - cs.add_hyperparameters([use_snapshot_ensemble, se_lastk]) - cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True) - cs.add_condition(cond) + cs.add_hyperparameter(use_snapshot_ensemble) + + if snapshot_ensemble_flag: + se_lastk = get_hyperparameter(se_lastk, Constant) + cs.add_hyperparameter(se_lastk) + cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True) + cs.add_condition(cond) + + lookahead_flag = False + if any(use_lookahead_optimizer.value_range): + lookahead_flag = True use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter) cs.add_hyperparameter(use_lookahead_optimizer) - la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps, - la_alpha=la_alpha) - parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True} - cs.add_configuration_space( - Lookahead.__name__, - la_config_space, - parent_hyperparameter=parent_hyperparameter - ) + + if lookahead_flag: + la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps, + la_alpha=la_alpha) + parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True} + cs.add_configuration_space( + Lookahead.__name__, + la_config_space, + parent_hyperparameter=parent_hyperparameter + ) """ if dataset_properties is not None: