Commit f0994c0

Fix bug in getting dataset requirements
1 parent 65fbe33 commit f0994c0

3 files changed (+64, -60 lines)

autoPyTorch/api/base_task.py

Lines changed: 16 additions & 4 deletions
@@ -259,7 +259,10 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
             return self.search_space
         elif dataset is not None:
             dataset_requirements = get_dataset_requirements(
-                info=dataset.get_required_dataset_info())
+                info=dataset.get_required_dataset_info(),
+                include=self.include_components,
+                exclude=self.exclude_components,
+                search_space_updates=self.search_space_updates)
             return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements),
                                            include=self.include_components,
                                            exclude=self.exclude_components,
@@ -771,7 +774,10 @@ def _search(
         # Initialise information needed for the experiment
         experiment_task_name = 'runSearch'
         dataset_requirements = get_dataset_requirements(
-            info=dataset.get_required_dataset_info())
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         self._dataset_requirements = dataset_requirements
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
@@ -1027,7 +1033,10 @@ def refit(
         self._logger = self._get_logger(self.dataset_name)

         dataset_requirements = get_dataset_requirements(
-            info=dataset.get_required_dataset_info())
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._backend.save_datamanager(dataset)

@@ -1098,7 +1107,10 @@ def fit(self,

         # get dataset properties
         dataset_requirements = get_dataset_requirements(
-            info=dataset.get_required_dataset_info())
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._backend.save_datamanager(dataset)
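
Note on the fix: `BaseTask` previously called `get_dataset_requirements` with only the dataset info, so any user-supplied `include_components`, `exclude_components`, or `search_space_updates` were ignored when computing dataset requirements. The following is a minimal sketch of the corrected call pattern, not a verbatim excerpt; the helper name `build_dataset_properties` and the example dictionary value in the comment are hypothetical, while the keyword arguments mirror the diff above.

    # Sketch only: mirrors the corrected call in BaseTask.
    from autoPyTorch.utils.pipeline import get_dataset_requirements

    def build_dataset_properties(task, dataset):
        # `task` stands in for a BaseTask instance, `dataset` for a BaseDataset.
        dataset_requirements = get_dataset_requirements(
            info=dataset.get_required_dataset_info(),
            include=task.include_components,    # e.g. {'classifier': [...]}; placeholder value
            exclude=task.exclude_components,
            search_space_updates=task.search_space_updates)
        # The requirements now reflect the restricted component space, so the derived
        # dataset properties match the pipeline that will actually be built.
        return dataset.get_dataset_properties(dataset_requirements)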

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 16 additions & 9 deletions
@@ -70,6 +70,7 @@ class MyTraditionalTabularClassificationPipeline(BaseEstimator):
             An optional dictionary that is passed to the pipeline's steps. It complies
             a similar function as the kwargs
     """
+
     def __init__(self, config: str,
                  dataset_properties: Dict[str, Any],
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
@@ -78,7 +79,7 @@ def __init__(self, config: str,
         self.dataset_properties = dataset_properties
         self.random_state = random_state
         self.init_params = init_params
-        self.pipeline = autoPyTorch.pipeline.traditional_tabular_classification.\
+        self.pipeline = autoPyTorch.pipeline.traditional_tabular_classification. \
            TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties)
         configuration_space = self.pipeline.get_hyperparameter_search_space()
         default_configuration = configuration_space.get_default_configuration().get_dictionary()
@@ -129,6 +130,7 @@ class DummyClassificationPipeline(DummyClassifier):
             An optional dictionary that is passed to the pipeline's steps. It complies
             a similar function as the kwargs
     """
+
     def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None
@@ -194,6 +196,7 @@ class DummyRegressionPipeline(DummyRegressor):
             An optional dictionary that is passed to the pipeline's steps. It complies
             a similar function as the kwargs
     """
+
     def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None) -> None:
@@ -339,7 +342,11 @@ def __init__(self, backend: Backend,
             raise ValueError('task {} not available'.format(self.task_type))
         self.predict_function = self._predict_proba
         self.dataset_properties = self.datamanager.get_dataset_properties(
-            get_dataset_requirements(self.datamanager.get_required_dataset_info()))
+            get_dataset_requirements(info=self.datamanager.get_required_dataset_info(),
+                                     include=self.include,
+                                     exclude=self.exclude,
+                                     search_space_updates=self.search_space_updates
+                                     ))

         self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
         if all_supported_metrics:
@@ -483,9 +490,9 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
         return None

     def calculate_auxiliary_losses(
-        self,
-        Y_valid_pred: np.ndarray,
-        Y_test_pred: np.ndarray,
+            self,
+            Y_valid_pred: np.ndarray,
+            Y_test_pred: np.ndarray,
     ) -> Tuple[Optional[float], Optional[float]]:

         validation_loss: Optional[float] = None
@@ -504,10 +511,10 @@ def calculate_auxiliary_losses(
         return validation_loss, test_loss

     def file_output(
-        self,
-        Y_optimization_pred: np.ndarray,
-        Y_valid_pred: np.ndarray,
-        Y_test_pred: np.ndarray
+            self,
+            Y_optimization_pred: np.ndarray,
+            Y_valid_pred: np.ndarray,
+            Y_test_pred: np.ndarray
     ) -> Tuple[Optional[float], Dict]:
         # Abort if self.Y_optimization is None
         # self.Y_optimization can be None if we use partial-cv, then,

autoPyTorch/utils/pipeline.py

Lines changed: 32 additions & 47 deletions
@@ -4,7 +4,6 @@
 from ConfigSpace.configuration_space import ConfigurationSpace

 from autoPyTorch.constants import (
-    CLASSIFICATION_TASKS,
     IMAGE_TASKS,
     REGRESSION_TASKS,
     STRING_TO_TASK_TYPES,
@@ -23,75 +22,61 @@


 def get_dataset_requirements(info: Dict[str, Any],
-                             include_estimators: Optional[List[str]] = None,
-                             exclude_estimators: Optional[List[str]] = None,
-                             include_preprocessors: Optional[List[str]] = None,
-                             exclude_preprocessors: Optional[List[str]] = None
+                             include: Optional[Dict] = None,
+                             exclude: Optional[Dict] = None,
+                             search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
                              ) -> List[FitRequirement]:
-    exclude = dict()
-    include = dict()
-    if include_preprocessors is not None and \
-            exclude_preprocessors is not None:
-        raise ValueError('Cannot specify include_preprocessors and '
-                         'exclude_preprocessors.')
-    elif include_preprocessors is not None:
-        include['feature_preprocessor'] = include_preprocessors
-    elif exclude_preprocessors is not None:
-        exclude['feature_preprocessor'] = exclude_preprocessors
-
     task_type: int = STRING_TO_TASK_TYPES[info['task_type']]
-    if include_estimators is not None and \
-            exclude_estimators is not None:
-        raise ValueError('Cannot specify include_estimators and '
-                         'exclude_estimators.')
-    elif include_estimators is not None:
-        if task_type in CLASSIFICATION_TASKS:
-            include['classifier'] = include_estimators
-        elif task_type in REGRESSION_TASKS:
-            include['regressor'] = include_estimators
-        else:
-            raise ValueError(info['task_type'])
-    elif exclude_estimators is not None:
-        if task_type in CLASSIFICATION_TASKS:
-            exclude['classifier'] = exclude_estimators
-        elif task_type in REGRESSION_TASKS:
-            exclude['regressor'] = exclude_estimators
-        else:
-            raise ValueError(info['task_type'])
-
     if task_type in REGRESSION_TASKS:
-        return _get_regression_dataset_requirements(info, include, exclude)
+        return _get_regression_dataset_requirements(info,
+                                                    include if include is not None else {},
+                                                    exclude if exclude is not None else {},
+                                                    search_space_updates=search_space_updates
+                                                    )
     else:
-        return _get_classification_dataset_requirements(info, include, exclude)
-
-
-def _get_regression_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]],
-                                          exclude: Dict[str, List[str]]) -> List[FitRequirement]:
+        return _get_classification_dataset_requirements(info,
+                                                        include if include is not None else {},
+                                                        exclude if exclude is not None else {},
+                                                        search_space_updates=search_space_updates
+                                                        )
+
+
+def _get_regression_dataset_requirements(info: Dict[str, Any],
+                                         include: Optional[Dict] = None,
+                                         exclude: Optional[Dict] = None,
+                                         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+                                         ) -> List[FitRequirement]:
     task_type = STRING_TO_TASK_TYPES[info['task_type']]
     if task_type in TABULAR_TASKS:
         fit_requirements = TabularRegressionPipeline(
             dataset_properties=info,
             include=include,
-            exclude=exclude
+            exclude=exclude,
+            search_space_updates=search_space_updates
         ).get_dataset_requirements()
         return fit_requirements
     else:
         raise ValueError("Task_type not supported")


-def _get_classification_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]],
-                                             exclude: Dict[str, List[str]]) -> List[FitRequirement]:
+def _get_classification_dataset_requirements(info: Dict[str, Any],
+                                             include: Optional[Dict] = None,
+                                             exclude: Optional[Dict] = None,
+                                             search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+                                             ) -> List[FitRequirement]:
     task_type = STRING_TO_TASK_TYPES[info['task_type']]

     if task_type in TABULAR_TASKS:
         return TabularClassificationPipeline(
             dataset_properties=info,
-            include=include, exclude=exclude).\
+            include=include, exclude=exclude,
+            search_space_updates=search_space_updates). \
             get_dataset_requirements()
     elif task_type in IMAGE_TASKS:
         return ImageClassificationPipeline(
             dataset_properties=info,
-            include=include, exclude=exclude).\
+            include=include, exclude=exclude,
+            search_space_updates=search_space_updates). \
             get_dataset_requirements()
     else:
         raise ValueError("Task_type not supported")
@@ -147,7 +132,7 @@ def _get_classification_configuration_space(info: Dict[str, Any], include: Dict[
         return ImageClassificationPipeline(
             dataset_properties=info,
             include=include, exclude=exclude,
-            search_space_updates=search_space_updates).\
+            search_space_updates=search_space_updates). \
             get_hyperparameter_search_space()
     else:
         raise ValueError("Task_type not supported")
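
Note on the new signature: `get_dataset_requirements` no longer accepts the separate `include_estimators`/`exclude_estimators` and `include_preprocessors`/`exclude_preprocessors` lists; callers now pass `include`/`exclude` dictionaries keyed by pipeline node (the removed mapping code used the keys 'classifier'/'regressor' and 'feature_preprocessor') plus an optional `HyperparameterSearchSpaceUpdates`. A hedged sketch of a caller under the new signature; the component names in the dictionary are placeholders, not confirmed component identifiers.

    # Sketch of the signature change, not a verbatim excerpt.
    from autoPyTorch.utils.pipeline import get_dataset_requirements

    info = dataset.get_required_dataset_info()  # assumes an existing BaseDataset instance

    # Old API (removed): estimator/preprocessor lists mapped to dict keys internally.
    # requirements = get_dataset_requirements(info,
    #                                         include_estimators=[...],
    #                                         include_preprocessors=[...])

    # New API: include/exclude are dicts keyed by pipeline node, and search-space
    # updates are forwarded to the pipeline so requirements match the final search space.
    requirements = get_dataset_requirements(
        info=info,
        include={'classifier': ['SomeClassifier'],                  # placeholder component names
                 'feature_preprocessor': ['SomePreprocessor']},
        exclude=None,
        search_space_updates=None)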
