Commit eda89f4
[ADD] Enable long running regression (#251)
* Early stop on metric
* Enable long run regression
* Move from deterministic score to lower bound
1 parent 6a8155f commit eda89f4
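Moving from a deterministic score to a lower bound means each regression test asserts only that the achieved score clears a per-configuration threshold, instead of matching one exact value. A minimal sketch of that assertion style (values are illustrative, not CI results):

```
# Lower-bound style check: pass as long as the score clears the bound,
# tolerating benign run-to-run noise instead of pinning an exact value.
score, lower_bound_score = 0.87, 0.85  # illustrative values
assert score >= lower_bound_score
```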

6 files changed: +267 additions, −6 deletions
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
```
name: Tests

on:
  schedule:
    # Every Tuesday at 7AM UTC
    # TODO temporarily set to every day just for the PR
    #- cron: '0 07 * * 2'
    - cron: '0 07 * * *'


jobs:
  ubuntu:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
      fail-fast: false

    steps:
    - uses: actions/checkout@v2
      with:
        ref: development
    - name: Setup Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install test dependencies
      run: |
        git submodule update --init --recursive
        python -m pip install --upgrade pip
        pip install -e .[test]
    - name: Run tests
      run: |
        python -m pytest --durations=200 cicd/test_preselected_configs.py -vs
```

autoPyTorch/api/base_task.py

Lines changed: 2 additions & 0 deletions
```
@@ -838,6 +838,8 @@ def _search(
         self._metric = get_metrics(
             names=[optimize_metric], dataset_properties=dataset_properties)[0]
 
+        self.pipeline_options['optimize_metric'] = optimize_metric
+
         self.search_space = self.get_search_space(dataset)
 
         budget_config: Dict[str, Union[float, str]] = {}
```
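The stored option travels with the rest of `pipeline_options` into the fit dictionary `X` that pipeline components receive, which is how the trainer change below can read the metric back out. A minimal sketch of that hand-off, with a plain dict standing in for the real fit dictionary (an assumption for illustration):

```
# Hypothetical sketch: the API stores the metric in pipeline_options and the
# trainer later reads it from the fit dictionary X (see the trainer diff below).
pipeline_options = {'optimize_metric': 'balanced_accuracy'}

X = {'metrics_during_training': True, **pipeline_options}
optimize_metric = None if not X['metrics_during_training'] else X.get('optimize_metric')
assert optimize_metric == 'balanced_accuracy'
```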

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 3 additions & 2 deletions
```
@@ -271,6 +271,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
         self.run_summary = RunSummary(
             total_parameter_count,
             trainable_parameter_count,
+            optimize_metric=None if not X['metrics_during_training'] else X.get('optimize_metric'),
         )
 
         epoch = 1
@@ -329,9 +330,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
 
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
-            val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'])
+            val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
             if 'test_data_loader' in X and X['val_data_loader']:
-                test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'])
+                test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'], epoch, writer)
             self.run_summary.add_performance(
                 epoch=epoch,
                 start_time=start_time,
```

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 23 additions & 4 deletions
```
@@ -15,6 +15,7 @@
 
 from autoPyTorch.constants import REGRESSION_TASKS
 from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
+from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
 from autoPyTorch.utils.implementations import get_loss_weight_strategy
 
@@ -61,6 +62,7 @@ def __init__(
         self,
         total_parameter_count: float,
         trainable_parameter_count: float,
+        optimize_metric: Optional[str] = None,
     ):
         """
         A useful object to track performance per epoch.
@@ -77,6 +79,7 @@ def __init__(
 
         self.total_parameter_count = total_parameter_count
         self.trainable_parameter_count = trainable_parameter_count
+        self.optimize_metric = optimize_metric
 
         # Allow to track the training performance
         self.performance_tracker['train_loss'] = {}
@@ -116,10 +119,26 @@ def add_performance(self,
         self.performance_tracker['test_metrics'][epoch] = test_metrics
 
     def get_best_epoch(self, loss_type: str = 'val_loss') -> int:
-        return np.argmin(
-            [self.performance_tracker[loss_type][e]
-             for e in range(1, len(self.performance_tracker[loss_type]) + 1)]
-        ) + 1  # Epochs start at 1
+
+        # If we compute validation scores, prefer the performance
+        # metric to the loss
+        if self.optimize_metric is not None:
+            scorer = CLASSIFICATION_METRICS[
+                self.optimize_metric
+            ] if self.optimize_metric in CLASSIFICATION_METRICS else REGRESSION_METRICS[
+                self.optimize_metric
+            ]
+            # Some metrics maximize, others minimize!
+            opt_func = np.argmax if scorer._sign > 0 else np.argmin
+            return opt_func(
+                [self.performance_tracker['val_metrics'][e][self.optimize_metric]
+                 for e in range(1, len(self.performance_tracker['val_metrics']) + 1)]
+            ) + 1  # Epochs start at 1
+        else:
+            return np.argmin(
+                [self.performance_tracker[loss_type][e]
+                 for e in range(1, len(self.performance_tracker[loss_type]) + 1)]
+            ) + 1  # Epochs start at 1
 
     def get_last_epoch(self) -> int:
         if 'train_loss' not in self.performance_tracker:
```
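A minimal, self-contained sketch of the sign-aware selection above, assuming only that a scorer's `_sign` is positive for metrics to maximize and negative for metrics to minimize (the convention the diff relies on):

```
import numpy as np

def best_epoch(val_metric_per_epoch, sign):
    # val_metric_per_epoch[i] holds the validation metric of epoch i + 1
    opt_func = np.argmax if sign > 0 else np.argmin
    return int(opt_func(val_metric_per_epoch)) + 1  # epochs start at 1

assert best_epoch([0.70, 0.85, 0.80], sign=1) == 2   # accuracy-like: maximize
assert best_epoch([0.90, 0.40, 0.55], sign=-1) == 2  # loss-like: minimize
```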

cicd/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@

###########################################################
# Continuous integration and continuous delivery/deployment
###########################################################

This part of the code is tasked with making sure that we can perform reliable NAS.
To this end, we rely on pytest to run some long-running configurations from both
the greedy portfolio and the default configuration.

```
python -m pytest cicd/test_preselected_configs.py -vs
```
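To narrow the run to a single task, pytest's `-k` filter can match against the parametrized test ids; the id below is an assumption based on the task ids in `test_preselected_configs.py`:

```
python -m pytest cicd/test_preselected_configs.py -vs -k '146818'
```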

cicd/test_preselected_configs.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@

```
import copy
import logging.handlers
import os
import random
import tempfile
import time

import numpy as np

import openml

import pytest

import sklearn.datasets
import sklearn.model_selection

import torch

from autoPyTorch.automl_common.common.utils.backend import create
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.optimizer.utils import read_return_initial_configurations
from autoPyTorch.pipeline.components.training.metrics.metrics import (
    accuracy,
    balanced_accuracy,
    roc_auc,
)
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.pipeline import get_dataset_requirements


def get_backend_dirs_for_openml_task(openml_task_id):
    temporary_directory = os.path.join(tempfile.gettempdir(), f"tmp_{openml_task_id}_{time.time()}")
    output_directory = os.path.join(tempfile.gettempdir(), f"out_{openml_task_id}_{time.time()}")
    return temporary_directory, output_directory


def get_fit_dictionary(openml_task_id):
    # Make sure everything from here onwards is reproducible
    # Add CUDA for future testing also
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

    task = openml.tasks.get_task(openml_task_id)
    temporary_directory, output_directory = get_backend_dirs_for_openml_task(openml_task_id)
    backend = create(
        temporary_directory=temporary_directory,
        output_directory=output_directory,
        delete_tmp_folder_after_terminate=False,
        delete_output_folder_after_terminate=False,
        prefix='autoPyTorch'
    )
    X, y = sklearn.datasets.fetch_openml(data_id=task.dataset_id, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, random_state=seed)
    validator = TabularInputValidator(
        is_classification='classification' in task.task_type.lower()).fit(X.copy(), y.copy())
    datamanager = TabularDataset(
        dataset_name=openml.datasets.get_dataset(task.dataset_id, download_data=False).name,
        X=X_train, Y=y_train,
        validator=validator,
        X_test=X_test, Y_test=y_test,
        resampling_strategy=CrossValTypes.stratified_k_fold_cross_validation
        if 'cross' in str(task.estimation_procedure) else HoldoutValTypes.holdout_validation
    )

    info = datamanager.get_required_dataset_info()

    dataset_properties = datamanager.get_dataset_properties(get_dataset_requirements(info))
    fit_dictionary = {
        'X_train': datamanager.train_tensors[0],
        'y_train': datamanager.train_tensors[1],
        'train_indices': datamanager.splits[0][0],
        'val_indices': datamanager.splits[0][1],
        'dataset_properties': dataset_properties,
        'num_run': openml_task_id,
        'device': 'cpu',
        'budget_type': 'epochs',
        'epochs': 200,
        'torch_num_threads': 1,
        'early_stopping': 100,
        'working_dir': '/tmp',
        'use_tensorboard_logger': False,
        'metrics_during_training': True,
        'split_id': 0,
        'backend': backend,
        'logger_port': logging.handlers.DEFAULT_TCP_LOGGING_PORT,
    }
    backend.save_datamanager(datamanager)
    return fit_dictionary


@pytest.mark.parametrize(
    'openml_task_id,configuration,scorer,lower_bound_score',
    (
        # Australian
        (146818, 0, balanced_accuracy, 0.85),
        (146818, 1, roc_auc, 0.90),
        (146818, 2, balanced_accuracy, 0.80),
        (146818, 3, balanced_accuracy, 0.85),
        # credit-g
        (31, 0, accuracy, 0.75),
        (31, 1, accuracy, 0.75),
        (31, 2, accuracy, 0.75),
        (31, 3, accuracy, 0.70),
        (31, 4, accuracy, 0.70),
        # segment
        (146822, 'default', accuracy, 0.90),
        # kr-vs-kp
        (3, 'default', accuracy, 0.90),
        # vehicle
        (53, 'default', accuracy, 0.75),
    ),
)
def test_can_properly_fit_a_config(openml_task_id, configuration, scorer, lower_bound_score):

    fit_dictionary = get_fit_dictionary(openml_task_id)
    fit_dictionary['additional_metrics'] = [scorer.name]
    fit_dictionary['optimize_metric'] = scorer.name

    pipeline = TabularClassificationPipeline(
        dataset_properties=fit_dictionary['dataset_properties'])
    cs = pipeline.get_hyperparameter_search_space()
    if configuration == 'default':
        config = cs.get_default_configuration()
    else:
        # Else configuration indicates the index within the greedy portfolio
        config = read_return_initial_configurations(
            config_space=cs,
            portfolio_selection="greedy",
        )[configuration]
    pipeline.set_hyperparameters(config)
    pipeline.fit(copy.deepcopy(fit_dictionary))

    # First we make sure the achieved performance clears the lower bound.
    # As we use the validation performance for early stopping, this is
    # not the true generalization performance, but our goal is to test
    # that we can learn the data and catch wrong configurations

    # Sadly, when using batch norm we have results that are dependent on the current
    # torch manual seed. Set seed zero here to make this test reproducible
    torch.manual_seed(0)
    val_indices = fit_dictionary['val_indices']
    train_data, target_data = fit_dictionary['backend'].load_datamanager().train_tensors
    predictions = pipeline.predict(train_data[val_indices])
    score = scorer(fit_dictionary['y_train'][val_indices], predictions)
    assert pytest.approx(score) >= lower_bound_score

    # Check that we reverted to the best score
    run_summary = pipeline.named_steps['trainer'].run_summary

    # Then check that the training progressed nicely
    # We fit a line to capture the trajectory tendency:
    # some epochs might be bad, but overall we should make progress
    train_scores = [run_summary.performance_tracker['train_metrics'][e][scorer.name]
                    for e in range(1, len(run_summary.performance_tracker['train_metrics']) + 1)]
    slope, intercept = np.polyfit(np.arange(len(train_scores)), train_scores, 1)
    if scorer._sign > 0:
        # We expect an increasing trajectory of training
        assert train_scores[0] < train_scores[-1]
        assert slope > 0
    else:
        # We expect a decreasing trajectory of training
        assert train_scores[0] > train_scores[-1]
        assert slope < 0

    # We do not expect the network to output zeros during training.
    # We add this check to prevent a dropout bug we had, where the dropout
    # probability was a bool, not a float
    network = pipeline.named_steps['network'].network
    network.train()
    global_accumulator = {}

    def forward_hook(module, X_in, X_out):
        global_accumulator[f"{id(module)}_{module.__class__.__name__}"] = torch.mean(X_out)

    for hierarchy, module in network.named_modules():
        module.register_forward_hook(forward_hook)
    pipeline.predict(train_data[val_indices])
    for module_name, mean_tensor in global_accumulator.items():
        # The global accumulator has the output of each layer
        # of the network. If the output of any layer is zero, this
        # check will fail
        assert mean_tensor != 0, module_name
```
