|
4 | 4 | from ConfigSpace.configuration_space import ConfigurationSpace
|
5 | 5 |
|
6 | 6 | from autoPyTorch.constants import (
|
7 |
| - CLASSIFICATION_TASKS, |
8 | 7 | IMAGE_TASKS,
|
9 | 8 | REGRESSION_TASKS,
|
10 | 9 | STRING_TO_TASK_TYPES,
|
|
23 | 22 |
|
24 | 23 |
|
25 | 24 | def get_dataset_requirements(info: Dict[str, Any],
|
26 |
| - include_estimators: Optional[List[str]] = None, |
27 |
| - exclude_estimators: Optional[List[str]] = None, |
28 |
| - include_preprocessors: Optional[List[str]] = None, |
29 |
| - exclude_preprocessors: Optional[List[str]] = None |
| 25 | + include: Optional[Dict] = None, |
| 26 | + exclude: Optional[Dict] = None, |
| 27 | + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None |
30 | 28 | ) -> List[FitRequirement]:
|
31 |
| - exclude = dict() |
32 |
| - include = dict() |
33 |
| - if include_preprocessors is not None and \ |
34 |
| - exclude_preprocessors is not None: |
35 |
| - raise ValueError('Cannot specify include_preprocessors and ' |
36 |
| - 'exclude_preprocessors.') |
37 |
| - elif include_preprocessors is not None: |
38 |
| - include['feature_preprocessor'] = include_preprocessors |
39 |
| - elif exclude_preprocessors is not None: |
40 |
| - exclude['feature_preprocessor'] = exclude_preprocessors |
41 |
| - |
42 | 29 | task_type: int = STRING_TO_TASK_TYPES[info['task_type']]
|
43 |
| - if include_estimators is not None and \ |
44 |
| - exclude_estimators is not None: |
45 |
| - raise ValueError('Cannot specify include_estimators and ' |
46 |
| - 'exclude_estimators.') |
47 |
| - elif include_estimators is not None: |
48 |
| - if task_type in CLASSIFICATION_TASKS: |
49 |
| - include['classifier'] = include_estimators |
50 |
| - elif task_type in REGRESSION_TASKS: |
51 |
| - include['regressor'] = include_estimators |
52 |
| - else: |
53 |
| - raise ValueError(info['task_type']) |
54 |
| - elif exclude_estimators is not None: |
55 |
| - if task_type in CLASSIFICATION_TASKS: |
56 |
| - exclude['classifier'] = exclude_estimators |
57 |
| - elif task_type in REGRESSION_TASKS: |
58 |
| - exclude['regressor'] = exclude_estimators |
59 |
| - else: |
60 |
| - raise ValueError(info['task_type']) |
61 |
| - |
62 | 30 | if task_type in REGRESSION_TASKS:
|
63 |
| - return _get_regression_dataset_requirements(info, include, exclude) |
| 31 | + return _get_regression_dataset_requirements(info, |
| 32 | + include if include is not None else {}, |
| 33 | + exclude if exclude is not None else {}, |
| 34 | + search_space_updates=search_space_updates |
| 35 | + ) |
64 | 36 | else:
|
65 |
| - return _get_classification_dataset_requirements(info, include, exclude) |
66 |
| - |
67 |
| - |
68 |
| -def _get_regression_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], |
69 |
| - exclude: Dict[str, List[str]]) -> List[FitRequirement]: |
| 37 | + return _get_classification_dataset_requirements(info, |
| 38 | + include if include is not None else {}, |
| 39 | + exclude if exclude is not None else {}, |
| 40 | + search_space_updates=search_space_updates |
| 41 | + ) |
| 42 | + |
| 43 | + |
| 44 | +def _get_regression_dataset_requirements(info: Dict[str, Any], |
| 45 | + include: Optional[Dict] = None, |
| 46 | + exclude: Optional[Dict] = None, |
| 47 | + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None |
| 48 | + ) -> List[FitRequirement]: |
70 | 49 | task_type = STRING_TO_TASK_TYPES[info['task_type']]
|
71 | 50 | if task_type in TABULAR_TASKS:
|
72 | 51 | fit_requirements = TabularRegressionPipeline(
|
73 | 52 | dataset_properties=info,
|
74 | 53 | include=include,
|
75 |
| - exclude=exclude |
| 54 | + exclude=exclude, |
| 55 | + search_space_updates=search_space_updates |
76 | 56 | ).get_dataset_requirements()
|
77 | 57 | return fit_requirements
|
78 | 58 | else:
|
79 | 59 | raise ValueError("Task_type not supported")
|
80 | 60 |
|
81 | 61 |
|
82 |
| -def _get_classification_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], |
83 |
| - exclude: Dict[str, List[str]]) -> List[FitRequirement]: |
| 62 | +def _get_classification_dataset_requirements(info: Dict[str, Any], |
| 63 | + include: Optional[Dict] = None, |
| 64 | + exclude: Optional[Dict] = None, |
| 65 | + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None |
| 66 | + ) -> List[FitRequirement]: |
84 | 67 | task_type = STRING_TO_TASK_TYPES[info['task_type']]
|
85 | 68 |
|
86 | 69 | if task_type in TABULAR_TASKS:
|
87 | 70 | return TabularClassificationPipeline(
|
88 | 71 | dataset_properties=info,
|
89 |
| - include=include, exclude=exclude).\ |
| 72 | + include=include, exclude=exclude, |
| 73 | + search_space_updates=search_space_updates). \ |
90 | 74 | get_dataset_requirements()
|
91 | 75 | elif task_type in IMAGE_TASKS:
|
92 | 76 | return ImageClassificationPipeline(
|
93 | 77 | dataset_properties=info,
|
94 |
| - include=include, exclude=exclude).\ |
| 78 | + include=include, exclude=exclude, |
| 79 | + search_space_updates=search_space_updates). \ |
95 | 80 | get_dataset_requirements()
|
96 | 81 | else:
|
97 | 82 | raise ValueError("Task_type not supported")
|
@@ -147,7 +132,7 @@ def _get_classification_configuration_space(info: Dict[str, Any], include: Dict[
|
147 | 132 | return ImageClassificationPipeline(
|
148 | 133 | dataset_properties=info,
|
149 | 134 | include=include, exclude=exclude,
|
150 |
| - search_space_updates=search_space_updates).\ |
| 135 | + search_space_updates=search_space_updates). \ |
151 | 136 | get_hyperparameter_search_space()
|
152 | 137 | else:
|
153 | 138 | raise ValueError("Task_type not supported")
|
0 commit comments