diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 56d849205..48edeb9a5 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -196,14 +196,6 @@ def __init__( raise ValueError("Expected search space updates to be of instance" " HyperparameterSearchSpaceUpdates got {}".format(type(self.search_space_updates))) - @abstractmethod - def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]: - """ - given a pipeline type, this function returns the - dataset properties required by the dataset object - """ - raise NotImplementedError - @abstractmethod def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline: """ @@ -267,7 +259,10 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace: return self.search_space elif dataset is not None: dataset_requirements = get_dataset_requirements( - info=self._get_required_dataset_properties(dataset)) + info=dataset.get_required_dataset_info(), + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates) return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements), include=self.include_components, exclude=self.exclude_components, @@ -785,7 +780,10 @@ def _search( # Initialise information needed for the experiment experiment_task_name: str = 'runSearch' dataset_requirements = get_dataset_requirements( - info=self._get_required_dataset_properties(dataset)) + info=dataset.get_required_dataset_info(), + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates) self._dataset_requirements = dataset_requirements dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._stopwatch.start_task(experiment_task_name) @@ -1049,7 +1047,10 @@ def refit( self._logger = self._get_logger(str(self.dataset_name)) dataset_requirements = get_dataset_requirements( - info=self._get_required_dataset_properties(dataset)) + info=dataset.get_required_dataset_info(), + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates) dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._backend.save_datamanager(dataset) @@ -1119,7 +1120,10 @@ def fit(self, # get dataset properties dataset_requirements = get_dataset_requirements( - info=self._get_required_dataset_properties(dataset)) + info=dataset.get_required_dataset_info(), + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates) dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._backend.save_datamanager(dataset) diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index c7b77c4d0..9662afd67 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -13,7 +13,6 @@ TASK_TYPES_TO_STRING, ) from autoPyTorch.data.tabular_validator import TabularInputValidator -from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, @@ -97,17 +96,6 @@ def __init__( task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION], ) - def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]: - if not isinstance(dataset, TabularDataset): - raise ValueError("Dataset is incompatible for the given task,: {}".format( - 
type(dataset) - )) - return {'task_type': dataset.task_type, - 'output_type': dataset.output_type, - 'issparse': dataset.issparse, - 'numerical_columns': dataset.numerical_columns, - 'categorical_columns': dataset.categorical_columns} - def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline: return TabularClassificationPipeline(dataset_properties=dataset_properties) diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index 098f2c506..a6f46d2fe 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -13,7 +13,6 @@ TASK_TYPES_TO_STRING ) from autoPyTorch.data.tabular_validator import TabularInputValidator -from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, @@ -89,17 +88,6 @@ def __init__( task_type=TASK_TYPES_TO_STRING[TABULAR_REGRESSION], ) - def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]: - if not isinstance(dataset, TabularDataset): - raise ValueError("Dataset is incompatible for the given task,: {}".format( - type(dataset) - )) - return {'task_type': dataset.task_type, - 'output_type': dataset.output_type, - 'issparse': dataset.issparse, - 'numerical_columns': dataset.numerical_columns, - 'categorical_columns': dataset.categorical_columns} - def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline: return TabularRegressionPipeline(dataset_properties=dataset_properties) diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 1bd283d7b..9955e706f 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -348,11 +348,17 @@ def replace_data(self, X_train: BaseDatasetInputType, def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]: """ - Gets the dataset properties required in the fit dictionary + Gets the dataset properties required in the fit dictionary. + This depends on the components that are active in the + pipeline and returns the properties they need about the dataset. + Information of the required properties of each component + can be found in their documentation. Args: dataset_requirements (List[FitRequirement]): List of fit requirements that the dataset properties must - contain. + contain. 
This is created using the
+                `get_dataset_requirements` function in
+                `autoPyTorch/utils/pipeline.py`.
 
     Returns:
         dataset_properties (Dict[str, Any]):
@@ -362,19 +368,15 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
         for dataset_requirement in dataset_requirements:
             dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
 
-        # Add task type, output type and issparse to dataset properties as
-        # they are not a dataset requirement in the pipeline
-        dataset_properties.update({'task_type': self.task_type,
-                                   'output_type': self.output_type,
-                                   'issparse': self.issparse,
-                                   'input_shape': self.input_shape,
-                                   'output_shape': self.output_shape
-                                   })
+        # Add the required dataset info to dataset properties as
+        # they might not be a dataset requirement in the pipeline
+        dataset_properties.update(self.get_required_dataset_info())
         return dataset_properties
 
     def get_required_dataset_info(self) -> Dict[str, Any]:
         """
-        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        Returns a dictionary containing required dataset
+        properties to instantiate a pipeline.
         """
         info = {'output_type': self.output_type,
                 'issparse': self.issparse}
diff --git a/autoPyTorch/datasets/tabular_dataset.py b/autoPyTorch/datasets/tabular_dataset.py
index 5087f6886..19e483612 100644
--- a/autoPyTorch/datasets/tabular_dataset.py
+++ b/autoPyTorch/datasets/tabular_dataset.py
@@ -112,7 +112,24 @@ def __init__(self,
 
     def get_required_dataset_info(self) -> Dict[str, Any]:
         """
-        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        Returns a dictionary containing required dataset
+        properties to instantiate a pipeline.
+        For a TabularDataset this includes:
+        1. 'output_type': an enum indicating the type of the output for this problem.
+            We currently use sklearn's `type_of_target` to infer the output type
+            from the data and encode it to an enum defined in `autoPyTorch/constants.py`.
+        2. 'issparse': a flag indicating whether the input is a sparse matrix.
+        3. 'numerical_columns': a list containing the indices of the numerical
+            columns in the input dataset.
+        4. 'categorical_columns': a list containing the indices of the categorical
+            columns in the input dataset.
+        5. 'task_type': an enum indicating the type of task. For tabular datasets
+            we currently support 'tabular_classification' and 'tabular_regression';
+            the encoding is also defined in `autoPyTorch/constants.py`.
         """
         info = super().get_required_dataset_info()
         info.update({
diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index 156289c60..9a4260e05 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -31,7 +30,6 @@
     TABULAR_TASKS,
 )
 from autoPyTorch.datasets.base_dataset import BaseDataset
-from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import (
     VotingRegressorWrapper,
     convert_multioutput_multiclass_to_multilabel
@@ -71,6 +70,7 @@ class MyTraditionalTabularClassificationPipeline(BaseEstimator):
        An optional dictionary that is passed to the pipeline's steps.
It complies a similar function as the kwargs """ + def __init__(self, config: str, dataset_properties: Dict[str, Any], random_state: Optional[Union[int, np.random.RandomState]] = None, @@ -141,6 +141,7 @@ class DummyClassificationPipeline(DummyClassifier): An optional dictionary that is passed to the pipeline's steps. It complies a similar function as the kwargs """ + def __init__(self, config: Configuration, random_state: Optional[Union[int, np.random.RandomState]] = None, init_params: Optional[Dict] = None @@ -208,6 +209,7 @@ class DummyRegressionPipeline(DummyRegressor): An optional dictionary that is passed to the pipeline's steps. It complies a similar function as the kwargs """ + def __init__(self, config: Configuration, random_state: Optional[Union[int, np.random.RandomState]] = None, init_params: Optional[Dict] = None) -> None: @@ -394,12 +396,9 @@ def __init__(self, backend: Backend, raise ValueError('disable_file_output should be either a bool or a list') self.pipeline_class: Optional[Union[BaseEstimator, BasePipeline]] = None - info: Dict[str, Any] = {'task_type': self.datamanager.task_type, - 'output_type': self.datamanager.output_type, - 'issparse': self.issparse} if self.task_type in REGRESSION_TASKS: if isinstance(self.configuration, int): - self.pipeline_class = DummyClassificationPipeline + self.pipeline_class = DummyRegressionPipeline elif isinstance(self.configuration, str): raise ValueError("Only tabular classifications tasks " "are currently supported with traditional methods") @@ -425,11 +424,12 @@ def __init__(self, backend: Backend, else: raise ValueError('task {} not available'.format(self.task_type)) self.predict_function = self._predict_proba - if self.task_type in TABULAR_TASKS: - assert isinstance(self.datamanager, TabularDataset) - info.update({'numerical_columns': self.datamanager.numerical_columns, - 'categorical_columns': self.datamanager.categorical_columns}) - self.dataset_properties = self.datamanager.get_dataset_properties(get_dataset_requirements(info)) + self.dataset_properties = self.datamanager.get_dataset_properties( + get_dataset_requirements(info=self.datamanager.get_required_dataset_info(), + include=self.include, + exclude=self.exclude, + search_space_updates=self.search_space_updates + )) self.additional_metrics: Optional[List[autoPyTorchMetric]] = None if all_supported_metrics: @@ -630,9 +630,9 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float], return None def calculate_auxiliary_losses( - self, - Y_valid_pred: np.ndarray, - Y_test_pred: np.ndarray, + self, + Y_valid_pred: np.ndarray, + Y_test_pred: np.ndarray, ) -> Tuple[Optional[float], Optional[float]]: """ A helper function to calculate the performance estimate of the @@ -670,10 +670,10 @@ def calculate_auxiliary_losses( return validation_loss, test_loss def file_output( - self, - Y_optimization_pred: np.ndarray, - Y_valid_pred: np.ndarray, - Y_test_pred: np.ndarray + self, + Y_optimization_pred: np.ndarray, + Y_valid_pred: np.ndarray, + Y_test_pred: np.ndarray ) -> Tuple[Optional[float], Dict]: """ This method decides what file outputs are written to disk. 
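(Aside, not part of the patch: to make the new dataset contract concrete, the dictionary returned by `TabularDataset.get_required_dataset_info()` as documented above would look roughly like the sketch below. The concrete values are illustrative assumptions for a dense classification dataset; the authoritative value encodings live in `autoPyTorch/constants.py`.)

# Illustrative sketch only -- not part of the diff. Shows the shape of the
# dictionary that get_required_dataset_info() is documented to return for a
# dense tabular classification dataset with three numerical and two
# categorical feature columns. The exact value encoding (plain strings vs.
# the enums in autoPyTorch/constants.py) is an assumption here.
required_info = {
    'output_type': 'multiclass',            # inferred via sklearn's type_of_target
    'issparse': False,                      # input is not a sparse matrix
    'numerical_columns': [0, 1, 2],         # indices of numerical feature columns
    'categorical_columns': [3, 4],          # indices of categorical feature columns
    'task_type': 'tabular_classification',  # or 'tabular_regression'
}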
diff --git a/autoPyTorch/utils/pipeline.py b/autoPyTorch/utils/pipeline.py index 3cd0d528f..8276b1243 100644 --- a/autoPyTorch/utils/pipeline.py +++ b/autoPyTorch/utils/pipeline.py @@ -4,7 +4,6 @@ from ConfigSpace.configuration_space import ConfigurationSpace from autoPyTorch.constants import ( - CLASSIFICATION_TASKS, IMAGE_TASKS, REGRESSION_TASKS, STRING_TO_TASK_TYPES, @@ -23,75 +22,91 @@ def get_dataset_requirements(info: Dict[str, Any], - include_estimators: Optional[List[str]] = None, - exclude_estimators: Optional[List[str]] = None, - include_preprocessors: Optional[List[str]] = None, - exclude_preprocessors: Optional[List[str]] = None + include: Optional[Dict] = None, + exclude: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ) -> List[FitRequirement]: - exclude = dict() - include = dict() - if include_preprocessors is not None and \ - exclude_preprocessors is not None: - raise ValueError('Cannot specify include_preprocessors and ' - 'exclude_preprocessors.') - elif include_preprocessors is not None: - include['feature_preprocessor'] = include_preprocessors - elif exclude_preprocessors is not None: - exclude['feature_preprocessor'] = exclude_preprocessors - + """ + + This function is used to return the dataset + property requirements which are needed to fit + a pipeline created based on the constraints + specified using include, exclude and + search_space_updates + + Args: + info (Dict[str, Any]): + A dictionary that specifies the required information + about the dataset to instantiate a pipeline. For more + info check the get_required_dataset_info of the + appropriate dataset in autoPyTorch/datasets + include (Optional[Dict]), (default=None): + If None, all possible components are used. + Otherwise specifies set of components to use. + exclude (Optional[Dict]), (default=None): + If None, all possible components are used. + Otherwise specifies set of components not to use. + Incompatible with include. + search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): + search space updates that can be used to modify the search + space of particular components or choice modules of the pipeline + + Returns: + List[FitRequirement]: + List of requirements that should be in the fit + dictionary used to fit the pipeline. 
+ """ task_type: int = STRING_TO_TASK_TYPES[info['task_type']] - if include_estimators is not None and \ - exclude_estimators is not None: - raise ValueError('Cannot specify include_estimators and ' - 'exclude_estimators.') - elif include_estimators is not None: - if task_type in CLASSIFICATION_TASKS: - include['classifier'] = include_estimators - elif task_type in REGRESSION_TASKS: - include['regressor'] = include_estimators - else: - raise ValueError(info['task_type']) - elif exclude_estimators is not None: - if task_type in CLASSIFICATION_TASKS: - exclude['classifier'] = exclude_estimators - elif task_type in REGRESSION_TASKS: - exclude['regressor'] = exclude_estimators - else: - raise ValueError(info['task_type']) - if task_type in REGRESSION_TASKS: - return _get_regression_dataset_requirements(info, include, exclude) + return _get_regression_dataset_requirements(info, + include if include is not None else {}, + exclude if exclude is not None else {}, + search_space_updates=search_space_updates + ) else: - return _get_classification_dataset_requirements(info, include, exclude) - - -def _get_regression_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]]) -> List[FitRequirement]: + return _get_classification_dataset_requirements(info, + include if include is not None else {}, + exclude if exclude is not None else {}, + search_space_updates=search_space_updates + ) + + +def _get_regression_dataset_requirements(info: Dict[str, Any], + include: Optional[Dict] = None, + exclude: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> List[FitRequirement]: task_type = STRING_TO_TASK_TYPES[info['task_type']] if task_type in TABULAR_TASKS: fit_requirements = TabularRegressionPipeline( dataset_properties=info, include=include, - exclude=exclude + exclude=exclude, + search_space_updates=search_space_updates ).get_dataset_requirements() return fit_requirements else: raise ValueError("Task_type not supported") -def _get_classification_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]]) -> List[FitRequirement]: +def _get_classification_dataset_requirements(info: Dict[str, Any], + include: Optional[Dict] = None, + exclude: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> List[FitRequirement]: task_type = STRING_TO_TASK_TYPES[info['task_type']] if task_type in TABULAR_TASKS: return TabularClassificationPipeline( dataset_properties=info, - include=include, exclude=exclude).\ + include=include, exclude=exclude, + search_space_updates=search_space_updates). \ get_dataset_requirements() elif task_type in IMAGE_TASKS: return ImageClassificationPipeline( dataset_properties=info, - include=include, exclude=exclude).\ + include=include, exclude=exclude, + search_space_updates=search_space_updates). \ get_dataset_requirements() else: raise ValueError("Task_type not supported") @@ -102,6 +117,33 @@ def get_configuration_space(info: Dict[str, Any], exclude: Optional[Dict] = None, search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ) -> ConfigurationSpace: + """ + + This function is used to return the configuration + space of the pipeline created based on the constraints + specified using include, exclude and search_space_updates + + Args: + info (Dict[str, Any]): + A dictionary that specifies the required information + about the dataset to instantiate a pipeline. 
For more
+            info, check the `get_required_dataset_info` method of the
+            appropriate dataset in `autoPyTorch/datasets`.
+        include (Optional[Dict]), (default=None):
+            If None, all possible components are used.
+            Otherwise, specifies the set of components to use.
+        exclude (Optional[Dict]), (default=None):
+            If None, all possible components are used.
+            Otherwise, specifies the set of components not to use.
+            Incompatible with include.
+        search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
+            Search space updates that can be used to modify the search
+            space of particular components or choice modules of the pipeline.
+
+    Returns:
+        ConfigurationSpace
+
+    """
     task_type: int = STRING_TO_TASK_TYPES[info['task_type']]
 
     if task_type in REGRESSION_TASKS:
@@ -147,7 +189,7 @@ def _get_classification_configuration_space(info: Dict[str, Any], include: Dict[
         return ImageClassificationPipeline(
             dataset_properties=info,
             include=include, exclude=exclude,
-            search_space_updates=search_space_updates).\
+            search_space_updates=search_space_updates). \
             get_hyperparameter_search_space()
     else:
         raise ValueError("Task_type not supported")
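(Aside, not part of the patch: a minimal usage sketch of how the refactored pieces fit together. It assumes an already-constructed `TabularDataset` instance called `dataset`; backend, validator and estimator setup are omitted, and the default `include`/`exclude`/`search_space_updates` of None are used.)

# Minimal sketch of the refactored flow -- assumes `dataset` is an existing
# TabularDataset; the calls mirror the pattern introduced in base_task.py
# and abstract_evaluator.py in this diff.
from autoPyTorch.utils.pipeline import (
    get_configuration_space,
    get_dataset_requirements,
)

# The dataset now describes itself; the per-task
# _get_required_dataset_properties() helpers are gone.
info = dataset.get_required_dataset_info()

# include/exclude/search_space_updates are forwarded so that the fit
# requirements match the pipeline that will actually be built.
dataset_requirements = get_dataset_requirements(
    info=info,
    include=None,
    exclude=None,
    search_space_updates=None,
)
dataset_properties = dataset.get_dataset_properties(dataset_requirements)

# The same constraints are passed when building the configuration space.
search_space = get_configuration_space(
    info=dataset_properties,
    include=None,
    exclude=None,
    search_space_updates=None,
)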