diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index b278a8563..c4d2e2396 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -31,4 +31,5 @@ jobs: - name: Run tests run: | python examples/example_tabular_classification.py - python examples/example_image_classification.py + python examples/example_tabular_regression.py + python examples/example_image_classification.py \ No newline at end of file diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index fb9a72082..e73b66bb1 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -11,7 +11,7 @@ import sklearn.utils from sklearn import preprocessing from sklearn.base import BaseEstimator -from sklearn.compose import make_column_transformer +from sklearn.compose import ColumnTransformer from sklearn.exceptions import NotFittedError from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES @@ -53,16 +53,34 @@ def _fit( for column in X.columns: if X[column].isna().all(): X[column] = pd.to_numeric(X[column]) + # Also note this change in self.dtypes + if len(self.dtypes) != 0: + self.dtypes[list(X.columns).index(column)] = X[column].dtype self.enc_columns, self.feat_type = self._get_columns_to_encode(X) if len(self.enc_columns) > 0: - - self.encoder = make_column_transformer( - (preprocessing.OrdinalEncoder( - handle_unknown='use_encoded_value', - unknown_value=-1, - ), self.enc_columns), + # impute missing values before encoding, + # remove once sklearn natively supports + # it in ordinal encoding. Sklearn issue: + # "https://github.com/scikit-learn/scikit-learn/issues/17123)" + for column in self.enc_columns: + if X[column].isna().any(): + missing_value: typing.Union[int, str] = -1 + # make sure for a string column we give + # string missing value else we give numeric + if type(X[column][0]) == str: + missing_value = str(missing_value) + X[column] = X[column].cat.add_categories([missing_value]) + X[column] = X[column].fillna(missing_value) + + self.encoder = ColumnTransformer( + [ + ("encoder", + preprocessing.OrdinalEncoder( + handle_unknown='use_encoded_value', + unknown_value=-1, + ), self.enc_columns)], remainder="passthrough" ) @@ -85,6 +103,7 @@ def comparator(cmp1: str, cmp2: str) -> int: return 1 else: raise ValueError((cmp1, cmp2)) + self.feat_type = sorted( self.feat_type, key=functools.cmp_to_key(comparator) @@ -182,9 +201,8 @@ def _check_data( if not isinstance(X, (np.ndarray, pd.DataFrame)) and not scipy.sparse.issparse(X): raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames," " scipy sparse and Python Lists, yet, the provided input is" - " of type {}".format( - type(X) - )) + " of type {}".format(type(X)) + ) if self.data_type is None: self.data_type = type(X) @@ -217,28 +235,14 @@ def _check_data( # per estimator enc_columns, _ = self._get_columns_to_encode(X) - if len(enc_columns) > 0: - if np.any(pd.isnull( - X[enc_columns].dropna( # type: ignore[call-overload] - axis='columns', how='all') - )): - # Ignore all NaN columns, and if still a NaN - # Error out - raise ValueError("Categorical features in a dataframe cannot contain " - "missing/NaN values. 
The OrdinalEncoder used by " - "AutoPyTorch cannot handle this yet (due to a " - "limitation on scikit-learn being addressed via: " - "https://github.com/scikit-learn/scikit-learn/issues/17123)" - ) column_order = [column for column in X.columns] if len(self.column_order) > 0: if self.column_order != column_order: raise ValueError("Changing the column order of the features after fit() is " "not supported. Fit() method was called with " - "{} whereas the new features have {} as type".format( - self.column_order, - column_order, - )) + "{} whereas the new features have {} as type".format(self.column_order, + column_order,) + ) else: self.column_order = column_order dtypes = [dtype.name for dtype in X.dtypes] @@ -246,10 +250,10 @@ def _check_data( if self.dtypes != dtypes: raise ValueError("Changing the dtype of the features after fit() is " "not supported. Fit() method was called with " - "{} whereas the new features have {} as type".format( - self.dtypes, - dtypes, - )) + "{} whereas the new features have {} as type".format(self.dtypes, + dtypes, + ) + ) else: self.dtypes = dtypes @@ -294,7 +298,8 @@ def _get_columns_to_encode( "pandas.Series.astype ." "If working with string objects, the following " "tutorial illustrates how to work with text data: " - "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( # noqa: E501 + "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( + # noqa: E501 column, ) ) @@ -349,15 +354,13 @@ def list_to_dataframe( # If a list was provided, it will be converted to pandas X_train = pd.DataFrame(data=X_train).infer_objects() self.logger.warning("The provided feature types to AutoPyTorch are of type list." - "Features have been interpreted as: {}".format( - [(col, t) for col, t in zip(X_train.columns, X_train.dtypes)] - )) + "Features have been interpreted as: {}".format([(col, t) for col, t in + zip(X_train.columns, X_train.dtypes)])) if X_test is not None: if not isinstance(X_test, list): self.logger.warning("Train features are a list while the provided test data" - "is {}. X_test will be casted as DataFrame.".format( - type(X_test) - )) + "is {}. 
X_test will be casted as DataFrame.".format(type(X_test)) + ) X_test = pd.DataFrame(data=X_test).infer_objects() return X_train, X_test diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py index c1f7da60d..29075841a 100644 --- a/autoPyTorch/evaluation/abstract_evaluator.py +++ b/autoPyTorch/evaluation/abstract_evaluator.py @@ -331,6 +331,8 @@ def __init__(self, backend: Backend, name=logger_name, port=logger_port, ) + self.backend.setup_logger(name=logger_name, port=logger_port) + self.Y_optimization: Optional[np.ndarray] = None self.Y_actual_train: Optional[np.ndarray] = None self.pipelines: Optional[List[BaseEstimator]] = None @@ -538,6 +540,7 @@ def file_output( else: pipeline = None + self.logger.debug("Saving directory {}, {}, {}".format(self.seed, self.num_run, self.budget)) self.backend.save_numrun_to_dir( seed=int(self.seed), idx=int(self.num_run), diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py index e90f35ed1..e1e08e94e 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py @@ -2,7 +2,7 @@ import numpy as np -from sklearn.compose import ColumnTransformer, make_column_transformer +from sklearn.compose import ColumnTransformer from sklearn.pipeline import make_pipeline import torch @@ -57,9 +57,9 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer": if len(X['dataset_properties']['categorical_columns']): categorical_pipeline = make_pipeline(*preprocessors['categorical']) - self.preprocessor = make_column_transformer( - (numerical_pipeline, X['dataset_properties']['numerical_columns']), - (categorical_pipeline, X['dataset_properties']['categorical_columns']), + self.preprocessor = ColumnTransformer([ + ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']), + ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns'])], remainder='passthrough' ) diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OrdinalEncoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OrdinalEncoder.py deleted file mode 100644 index c65726327..000000000 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OrdinalEncoder.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Dict, Optional, Union - -import numpy as np - -from sklearn.preprocessing import OrdinalEncoder as OE - -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.base_encoder import BaseEncoder - - -class OrdinalEncoder(BaseEncoder): - """ - Encode categorical features as a one-hot numerical array - """ - def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None): - super().__init__() - self.random_state = random_state - - def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder: - - self.check_requirements(X, y) - - self.preprocessor['categorical'] = OE(handle_unknown='use_encoded_value', - unknown_value=-1, - ) - return self - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: - return { - 'shortname': 'OrdinalEncoder', - 'name': 'Ordinal Encoder', - 
'handles_sparse': False - } diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder_choice.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder_choice.py index 7be7c94a2..df71ff209 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder_choice.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder_choice.py @@ -65,7 +65,7 @@ def get_hyperparameter_search_space(self, raise ValueError("no encoders found, please add a encoder") if default is None: - defaults = ['OneHotEncoder', 'OrdinalEncoder', 'NoEncoder'] + defaults = ['OneHotEncoder', 'NoEncoder'] for default_ in defaults: if default_ in available_preprocessors: if include is not None and default_ not in include: diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index 4f7c18b7c..81fd8e5f4 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -29,6 +29,7 @@ def __init__( self.add_fit_requirements([ FitRequirement("network_head", (torch.nn.Module,), user_defined=False, dataset_property=False), FitRequirement("network_backbone", (torch.nn.Module,), user_defined=False, dataset_property=False), + FitRequirement("network_embedding", (torch.nn.Module,), user_defined=False, dataset_property=False), ]) self.final_activation = None @@ -47,7 +48,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent: # information to fit this stage self.check_requirements(X, y) - self.network = torch.nn.Sequential(X['network_backbone'], X['network_head']) + self.network = torch.nn.Sequential(X['network_embedding'], X['network_backbone'], X['network_head']) # Properly set the network training device if self.device is None: diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py index 2557e92b8..241fcb51b 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py @@ -14,6 +14,7 @@ from autoPyTorch.pipeline.components.base_component import ( autoPyTorchComponent, ) +from autoPyTorch.pipeline.components.setup.network_backbone.utils import get_output_shape from autoPyTorch.utils.common import FitRequirement @@ -31,7 +32,9 @@ def __init__(self, FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True, dataset_property=False), FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True), - FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False)]) + FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False), + FitRequirement('network_embedding', (nn.Module,), user_defined=False, dataset_property=False) + ]) self.backbone: nn.Module = None self.config = kwargs self.input_shape: Optional[Iterable] = None @@ -56,6 +59,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: column_transformer = X['tabular_transformer'].preprocessor input_shape = column_transformer.transform(X_train[:1]).shape[1:] + input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape) self.input_shape = input_shape 
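For context on the new `get_output_shape` call above: the backbone input shape must now be taken after the embedding, since an embedding can change the number of flattened feature columns. A minimal sketch of such a shape-probing helper (illustrative name; the real helper lives in `network_backbone/utils.py` and may differ in details):

```python
import torch
from torch import nn


def probe_output_shape(network: nn.Module, input_shape: tuple) -> tuple:
    """Infer a module's per-sample output shape by pushing a dummy batch through it."""
    placeholder = torch.randn((2, *input_shape), dtype=torch.float)
    with torch.no_grad():
        output = network(placeholder)
    return tuple(output.shape[1:])


# e.g. an embedding that widens 10 input columns to 16 output columns:
print(probe_output_shape(nn.Linear(10, 16), (10,)))  # -> (16,)
```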
        self.backbone = self.build_backbone(
diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py
new file mode 100644
index 000000000..3910afc37
--- /dev/null
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py
@@ -0,0 +1,132 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.hyperparameters import (
+    UniformFloatHyperparameter,
+    UniformIntegerHyperparameter
+)
+
+import numpy as np
+
+import torch
+from torch import nn
+
+from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import NetworkEmbeddingComponent
+
+
+class _LearnedEntityEmbedding(nn.Module):
+    """ Learned entity embedding module for categorical features"""
+
+    def __init__(self, config: Dict[str, Any], num_input_features: np.ndarray, num_numerical_features: int):
+        """
+        Args:
+            config (Dict[str, Any]): The configuration sampled by the hyperparameter optimizer
+            num_input_features (np.ndarray): column-wise information on the number of output columns after
+                transformation for each categorical column, and 0 for numerical columns
+            num_numerical_features (int): number of numerical features in X
+        """
+        super().__init__()
+        self.config = config
+
+        self.num_numerical = num_numerical_features
+        # list of number of categories of categorical data
+        # or 0 for numerical data
+        self.num_input_features = num_input_features
+        categorical_features = self.num_input_features > 0
+
+        self.num_categorical_features = self.num_input_features[categorical_features]
+
+        self.embed_features = [num_in >= config["min_unique_values_for_embedding"] for num_in in
+                               self.num_input_features]
+        self.num_output_dimensions = [0] * num_numerical_features
+        self.num_output_dimensions.extend([config["dimension_reduction_" + str(i)] * num_in for i, num_in in
+                                           enumerate(self.num_categorical_features)])
+        self.num_output_dimensions = [int(np.clip(num_out, 1, num_in - 1)) for num_out, num_in in
+                                      zip(self.num_output_dimensions, self.num_input_features)]
+        self.num_output_dimensions = [num_out if embed else num_in for num_out, embed, num_in in
+                                      zip(self.num_output_dimensions, self.embed_features,
+                                          self.num_input_features)]
+        self.num_out_feats = self.num_numerical + sum(self.num_output_dimensions)
+
+        self.ee_layers = self._create_ee_layers()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # pass the columns of each categorical feature through its entity embedding layer
+        # before passing it through the model
+        concat_seq = []
+        last_concat = 0
+        x_pointer = 0
+        layer_pointer = 0
+        for num_in, embed in zip(self.num_input_features, self.embed_features):
+            if not embed:
+                x_pointer += 1
+                continue
+            if x_pointer > last_concat:
+                concat_seq.append(x[:, last_concat: x_pointer])
+            categorical_feature_slice = x[:, x_pointer: x_pointer + num_in]
+            concat_seq.append(self.ee_layers[layer_pointer](categorical_feature_slice))
+            layer_pointer += 1
+            x_pointer += num_in
+            last_concat = x_pointer
+
+        concat_seq.append(x[:, last_concat:])
+        return torch.cat(concat_seq, dim=1)
+
+    def _create_ee_layers(self) -> nn.ModuleList:
+        # entity embedding layers are linear layers
+        layers = nn.ModuleList()
+        for i, (num_in, embed, num_out) in enumerate(zip(self.num_input_features, self.embed_features,
+                                                         self.num_output_dimensions)):
+            if not embed:
+                continue
+            layers.append(nn.Linear(num_in, num_out))
+        return layers
+
+
+class LearnedEntityEmbedding(NetworkEmbeddingComponent):
+    """
+    Class to learn an embedding for categorical features.
+    """
+
+    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None, **kwargs: Any):
+        super().__init__(random_state=random_state)
+        self.config = kwargs
+
+    def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module:
+        return _LearnedEntityEmbedding(config=self.config,
+                                       num_input_features=num_input_features,
+                                       num_numerical_features=num_numerical_features)
+
+    @staticmethod
+    def get_hyperparameter_search_space(
+        dataset_properties: Optional[Dict[str, str]] = None,
+        min_unique_values_for_embedding: Tuple[Tuple, int, bool] = ((3, 7), 5, True),
+        dimension_reduction: Tuple[Tuple, float] = ((0, 1), 0.5),
+    ) -> ConfigurationSpace:
+        cs = ConfigurationSpace()
+        min_hp = UniformIntegerHyperparameter("min_unique_values_for_embedding",
+                                              lower=min_unique_values_for_embedding[0][0],
+                                              upper=min_unique_values_for_embedding[0][1],
+                                              default_value=min_unique_values_for_embedding[1],
+                                              log=min_unique_values_for_embedding[2]
+                                              )
+        cs.add_hyperparameter(min_hp)
+        if dataset_properties is not None:
+            for i in range(len(dataset_properties['categorical_columns'])):
+                ee_dimensions_hp = UniformFloatHyperparameter("dimension_reduction_" + str(i),
+                                                              lower=dimension_reduction[0][0],
+                                                              upper=dimension_reduction[0][1],
+                                                              default_value=dimension_reduction[1]
+                                                              )
+                cs.add_hyperparameter(ee_dimensions_hp)
+        return cs
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
+        return {
+            'shortname': 'embedding',
+            'name': 'LearnedEntityEmbedding',
+            'handles_tabular': True,
+            'handles_image': False,
+            'handles_time_series': False,
+        }
diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py
new file mode 100644
index 000000000..a8b81af2f
--- /dev/null
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py
@@ -0,0 +1,44 @@
+from typing import Any, Dict, Optional, Union
+
+from ConfigSpace.configuration_space import ConfigurationSpace
+
+import numpy as np
+
+import torch
+from torch import nn
+
+from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import NetworkEmbeddingComponent
+
+
+class _NoEmbedding(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+
+class NoEmbedding(NetworkEmbeddingComponent):
+    """
+    Class that passes the input through unchanged; no embedding is learned.
+ """ + + def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None): + super().__init__(random_state=random_state) + + def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module: + return _NoEmbedding() + + @staticmethod + def get_hyperparameter_search_space( + dataset_properties: Optional[Dict[str, str]] = None, + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + return cs + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: + return { + 'shortname': 'no embedding', + 'name': 'NoEmbedding', + 'handles_tabular': True, + 'handles_image': False, + 'handles_time_series': False, + } diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py b/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py new file mode 100644 index 000000000..8652c347c --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py @@ -0,0 +1,52 @@ +import copy +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np + +from sklearn.base import BaseEstimator + +from torch import nn + +from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent + + +class NetworkEmbeddingComponent(autoPyTorchSetupComponent): + def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None): + super().__init__() + self.embedding: Optional[nn.Module] = None + self.random_state = random_state + + def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: + + num_numerical_columns, num_input_features = self._get_args(X) + + self.embedding = self.build_embedding( + num_input_features=num_input_features, + num_numerical_features=num_numerical_columns) + return self + + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: + X.update({'network_embedding': self.embedding}) + return X + + def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module: + raise NotImplementedError + + def _get_args(self, X: Dict[str, Any]) -> Tuple[int, np.ndarray]: + # Feature preprocessors can alter numerical columns + if len(X['dataset_properties']['numerical_columns']) == 0: + num_numerical_columns = 0 + else: + X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2]) + + numerical_column_transformer = X['tabular_transformer'].preprocessor. 
\
+                named_transformers_['numerical_pipeline']
+            num_numerical_columns = numerical_column_transformer.transform(
+                X_train[:, X['dataset_properties']['numerical_columns']]).shape[1]
+        num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])),
+                                      dtype=int)
+        categories = X['dataset_properties']['categories']
+
+        for i, category in enumerate(categories):
+            num_input_features[num_numerical_columns + i] = len(category)
+        return num_numerical_columns, num_input_features
diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding_choice.py b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding_choice.py
new file mode 100644
index 000000000..c08b156ce
--- /dev/null
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding_choice.py
@@ -0,0 +1,188 @@
+import os
+from collections import OrderedDict
+from typing import Dict, List, Optional
+
+import ConfigSpace.hyperparameters as CSH
+from ConfigSpace.configuration_space import ConfigurationSpace
+
+import numpy as np
+
+from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
+from autoPyTorch.pipeline.components.base_component import (
+    ThirdPartyComponents,
+    autoPyTorchComponent,
+    find_components,
+)
+from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import (
+    NetworkEmbeddingComponent,
+)
+
+directory = os.path.split(__file__)[0]
+_embeddings = find_components(__package__,
+                              directory,
+                              NetworkEmbeddingComponent)
+_addons = ThirdPartyComponents(NetworkEmbeddingComponent)
+
+
+def add_embedding(embedding: NetworkEmbeddingComponent) -> None:
+    _addons.add_component(embedding)
+
+
+class NetworkEmbeddingChoice(autoPyTorchChoice):
+
+    def get_components(self) -> Dict[str, autoPyTorchComponent]:
+        """Returns the available embedding components
+
+        Args:
+            None
+
+        Returns:
+            Dict[str, autoPyTorchComponent]: all base embedding components available
+                as choices
+        """
+        components = OrderedDict()
+        components.update(_embeddings)
+        components.update(_addons.components)
+        return components
+
+    def get_available_components(
+        self,
+        dataset_properties: Optional[Dict[str, str]] = None,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+    ) -> Dict[str, autoPyTorchComponent]:
+        """Filters out components based on user provided
+        include/exclude directives, as well as the dataset properties
+
+        Args:
+            include (Optional[List[str]]): what hyper-parameter configurations
+                to honor when creating the configuration space
+            exclude (Optional[List[str]]): what hyper-parameter configurations
+                to remove from the configuration space
+            dataset_properties (Optional[Dict[str, Union[str, int]]]): Characteristics
+                of the dataset to guide the pipeline choices of components
+
+        Returns:
+            Dict[str, autoPyTorchComponent]: A filtered dict of available
+                embedding components
+
+        """
+        if dataset_properties is None:
+            dataset_properties = {}
+
+        if include is not None and exclude is not None:
+            raise ValueError(
+                "The arguments include and exclude cannot be used together.")
+
+        available_comp = self.get_components()
+
+        if include is not None:
+            for incl in include:
+                if incl not in available_comp:
+                    raise ValueError("Trying to include unknown component: "
+                                     "%s" % incl)
+
+        components_dict = OrderedDict()
+        for name in available_comp:
+            if include is not None and name not in include:
+                continue
+            elif exclude is not None and name in exclude:
+                continue
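To make the contract between `_get_args` and `_LearnedEntityEmbedding` above concrete, here is a hedged usage sketch: it assumes two numerical columns plus one categorical column with 4 categories, which `_get_args` would report as `num_numerical_columns=2` and `num_input_features=[0, 0, 4]`, and which the module expects to arrive one-hot encoded:

```python
import numpy as np
import torch

from autoPyTorch.pipeline.components.setup.network_embedding.LearnedEntityEmbedding import (
    _LearnedEntityEmbedding,
)

# dimension_reduction_0 = 0.5 shrinks the 4 one-hot columns to 2 embedded dimensions.
config = {"min_unique_values_for_embedding": 3, "dimension_reduction_0": 0.5}
embedding = _LearnedEntityEmbedding(config=config,
                                    num_input_features=np.array([0, 0, 4]),
                                    num_numerical_features=2)

x = torch.zeros((8, 6))    # batch of 8 rows: 2 numerical + 4 one-hot columns
print(embedding(x).shape)  # -> torch.Size([8, 4]): 2 numerical + 2 embedded
```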
+
+            entry = available_comp[name]
+
+            # Exclude itself to avoid infinite loop
+            if entry == NetworkEmbeddingChoice or hasattr(entry, 'get_components'):
+                continue
+
+            task_type = dataset_properties['task_type']
+            properties = entry.get_properties()
+            if 'tabular' in task_type and not properties['handles_tabular']:
+                continue
+            elif 'image' in task_type and not properties['handles_image']:
+                continue
+            elif 'time_series' in task_type and not properties['handles_time_series']:
+                continue
+
+            components_dict[name] = entry
+
+        return components_dict
+
+    def get_hyperparameter_search_space(
+        self,
+        dataset_properties: Optional[Dict[str, str]] = None,
+        default: Optional[str] = None,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+    ) -> ConfigurationSpace:
+        """Returns the configuration space of the current chosen components
+
+        Args:
+            dataset_properties (Optional[Dict[str, str]]): Describes the dataset to work on
+            default (Optional[str]): Default embedding to use
+            include (Optional[List[str]]): what components to include. It is an exhaustive
+                list, and will exclusively use these components.
+            exclude (Optional[List[str]]): which components to skip
+
+        Returns:
+            ConfigurationSpace: the configuration space of the hyper-parameters of the
+                chosen component
+        """
+        cs = ConfigurationSpace()
+
+        if dataset_properties is None:
+            dataset_properties = {}
+
+        # Compile a list of legal embedding components for this problem
+        available_embedding = self.get_available_components(
+            dataset_properties=dataset_properties,
+            include=include, exclude=exclude)
+
+        if len(available_embedding) == 0 and 'tabular' in dataset_properties['task_type']:
+            raise ValueError("No embedding found")
+
+        if len(available_embedding) == 0:
+            return cs
+
+        if default is None:
+            defaults = [
+                'LearnedEntityEmbedding',
+                'NoEmbedding'
+            ]
+            for default_ in defaults:
+                if default_ in available_embedding:
+                    default = default_
+                    break
+
+        if len(dataset_properties['categorical_columns']) == 0:
+            default = 'NoEmbedding'
+            if include is not None and default not in include:
+                raise ValueError("Provided {} in include, however, the dataset "
+                                 "is incompatible with it".format(include))
+            embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                      ['NoEmbedding'],
+                                                      default_value=default)
+        else:
+            embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                      list(available_embedding.keys()),
+                                                      default_value=default)
+
+        cs.add_hyperparameter(embedding)
+        for name in embedding.choices:
+            updates = self._get_search_space_updates(prefix=name)
+            config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties,  # type: ignore
+                                                                                     **updates)
+            parent_hyperparameter = {'parent': embedding, 'value': name}
+            cs.add_configuration_space(
+                name,
+                config_space,
+                parent_hyperparameter=parent_hyperparameter
+            )
+
+        self.configuration_space_ = cs
+        self.dataset_properties_ = dataset_properties
+        return cs
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        assert self.choice is not None, "Cannot call transform before the object is initialized"
+        return self.choice.transform(X)
diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py
index e3abad9cc..73dca2878 100644
--- a/autoPyTorch/pipeline/tabular_classification.py
+++ b/autoPyTorch/pipeline/tabular_classification.py
@@ -1,7 +1,9 @@
+import copy
 import warnings
 from typing import Any, Dict, List, Optional, Tuple
 
 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
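The `ForbiddenAndConjunction` import added just below powers the constraint installed later in `_get_hyperparameter_search_space`: `LearnedEntityEmbedding` may only be sampled together with `OneHotEncoder`. A toy ConfigSpace sketch of that mechanism (hyperparameter names abbreviated for illustration):

```python
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
from ConfigSpace.hyperparameters import CategoricalHyperparameter

cs = ConfigurationSpace()
embedding = CategoricalHyperparameter('network_embedding:__choice__',
                                      ['LearnedEntityEmbedding', 'NoEmbedding'])
encoder = CategoricalHyperparameter('encoder:__choice__',
                                    ['OneHotEncoder', 'NoEncoder'])
cs.add_hyperparameters([embedding, encoder])

# Forbid the one illegal pairing; sampling then never returns it.
cs.add_forbidden_clause(ForbiddenAndConjunction(
    ForbiddenEqualsClause(embedding, 'LearnedEntityEmbedding'),
    ForbiddenEqualsClause(encoder, 'NoEncoder'),
))
print(cs.sample_configuration(5))
```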
+from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause import numpy as np @@ -25,6 +27,7 @@ from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import SchedulerChoice from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone_choice import NetworkBackboneChoice +from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding_choice import NetworkEmbeddingChoice from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( NetworkInitializerChoice @@ -62,15 +65,15 @@ class TabularClassificationPipeline(ClassifierMixin, BasePipeline): """ def __init__( - self, - config: Optional[Configuration] = None, - steps: Optional[List[Tuple[str, autoPyTorchChoice]]] = None, - dataset_properties: Optional[Dict[str, Any]] = None, - include: Optional[Dict[str, Any]] = None, - exclude: Optional[Dict[str, Any]] = None, - random_state: Optional[np.random.RandomState] = None, - init_params: Optional[Dict[str, Any]] = None, - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + self, + config: Optional[Configuration] = None, + steps: Optional[List[Tuple[str, autoPyTorchChoice]]] = None, + dataset_properties: Optional[Dict[str, Any]] = None, + include: Optional[Dict[str, Any]] = None, + exclude: Optional[Dict[str, Any]] = None, + random_state: Optional[np.random.RandomState] = None, + init_params: Optional[Dict[str, Any]] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ): super().__init__( config, steps, dataset_properties, include, exclude, @@ -188,6 +191,33 @@ def _get_hyperparameter_search_space(self, # Here we add custom code, like this with this # is not a valid configuration + # Learned Entity Embedding is only valid when encoder is one hot encoder + if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys(): + embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices + if 'LearnedEntityEmbedding' in embeddings: + encoders = cs.get_hyperparameter('encoder:__choice__').choices + default = cs.get_hyperparameter('network_embedding:__choice__').default_value + possible_default_embeddings = copy.copy(list(embeddings)) + del possible_default_embeddings[possible_default_embeddings.index(default)] + + for encoder in encoders: + if encoder == 'OneHotEncoder': + continue + while True: + try: + cs.add_forbidden_clause(ForbiddenAndConjunction( + ForbiddenEqualsClause(cs.get_hyperparameter( + 'network_embedding:__choice__'), 'LearnedEntityEmbedding'), + ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder) + )) + break + except ValueError: + # change the default and try again + try: + default = possible_default_embeddings.pop() + except IndexError: + raise ValueError("Cannot find a legal default configuration") + cs.get_hyperparameter('network_embedding:__choice__').default_value = default self.configuration_space = cs self.dataset_properties = dataset_properties @@ -216,6 +246,7 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]], ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties)), ("tabular_transformer", TabularColumnTransformer()), ("preprocessing", EarlyPreprocessing()), + ("network_embedding", 
NetworkEmbeddingChoice(default_dataset_properties)), ("network_backbone", NetworkBackboneChoice(default_dataset_properties)), ("network_head", NetworkHeadChoice(default_dataset_properties)), ("network", NetworkComponent()), diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py index 174e41dee..855a025e8 100644 --- a/autoPyTorch/pipeline/tabular_regression.py +++ b/autoPyTorch/pipeline/tabular_regression.py @@ -1,7 +1,9 @@ +import copy import warnings from typing import Any, Dict, List, Optional, Tuple from ConfigSpace.configuration_space import Configuration, ConfigurationSpace +from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause import numpy as np @@ -23,6 +25,7 @@ from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import SchedulerChoice from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone_choice import NetworkBackboneChoice +from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding_choice import NetworkEmbeddingChoice from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( NetworkInitializerChoice @@ -136,6 +139,33 @@ def _get_hyperparameter_search_space(self, # Here we add custom code, like this with this # is not a valid configuration + # Learned Entity Embedding is only valid when encoder is one hot encoder + if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys(): + embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices + if 'LearnedEntityEmbedding' in embeddings: + encoders = cs.get_hyperparameter('encoder:__choice__').choices + default = cs.get_hyperparameter('network_embedding:__choice__').default_value + possible_default_embeddings = copy.copy(list(embeddings)) + del possible_default_embeddings[possible_default_embeddings.index(default)] + + for encoder in encoders: + if encoder == 'OneHotEncoder': + continue + while True: + try: + cs.add_forbidden_clause(ForbiddenAndConjunction( + ForbiddenEqualsClause(cs.get_hyperparameter( + 'network_embedding:__choice__'), 'LearnedEntityEmbedding'), + ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder) + )) + break + except ValueError: + # change the default and try again + try: + default = possible_default_embeddings.pop() + except IndexError: + raise ValueError("Cannot find a legal default configuration") + cs.get_hyperparameter('network_embedding:__choice__').default_value = default self.configuration_space = cs self.dataset_properties = dataset_properties @@ -162,6 +192,7 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]]) -> L ("scaler", ScalerChoice(default_dataset_properties)), ("tabular_transformer", TabularColumnTransformer()), ("preprocessing", EarlyPreprocessing()), + ("network_embedding", NetworkEmbeddingChoice(default_dataset_properties)), ("network_backbone", NetworkBackboneChoice(default_dataset_properties)), ("network_head", NetworkHeadChoice(default_dataset_properties)), ("network", NetworkComponent()), diff --git a/autoPyTorch/utils/backend.py b/autoPyTorch/utils/backend.py index dd24c2340..5111c116f 100644 --- a/autoPyTorch/utils/backend.py +++ b/autoPyTorch/utils/backend.py @@ -392,6 +392,7 @@ def save_numrun_to_dir( cv_model: 
Optional[BasePipeline],
                           ensemble_predictions: Optional[np.ndarray],
                           valid_predictions: Optional[np.ndarray],
                           test_predictions: Optional[np.ndarray],
                           ) -> None:
+        assert self._logger is not None
         runs_directory = self.get_runs_directory()
         tmpdir = tempfile.mkdtemp(dir=runs_directory)
         if model is not None:
@@ -417,6 +418,8 @@ def save_numrun_to_dir(
             with open(file_path, 'wb') as fh:
                 pickle.dump(preds.astype(np.float32), fh, -1)
         try:
+            self._logger.debug("Renaming {} to {}".format(tmpdir,
+                                                          self.get_numrun_directory(seed, idx, budget)))
             os.rename(tmpdir, self.get_numrun_directory(seed, idx, budget))
         except OSError:
             if os.path.exists(self.get_numrun_directory(seed, idx, budget)):
diff --git a/setup.py b/setup.py
index c496a48c1..1d8e47ba5 100755
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,7 @@
         "codecov",
         "pep8",
         "mypy",
+        "openml"
     ],
     "examples": [
         "matplotlib",
diff --git a/test/conftest.py b/test/conftest.py
index f05f573a7..e658b7e37 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -8,6 +8,8 @@
 
 import numpy as np
 
+import openml
+
 import pandas as pd
 
 import pytest
@@ -23,6 +25,42 @@
 from autoPyTorch.utils.pipeline import get_dataset_requirements
 
 
+@pytest.fixture(scope="session")
+def callattr_ahead_of_alltests(request):
+    """
+    This procedure will run at the start of the pytest session.
+    It will prefetch several tasks that are going to be used by
+    the testing phase, and it does so in a robust way, until the
+    openml API provides the desired resources.
+    """
+    tasks_used = [
+        146818,  # Australian
+        2295,    # cholesterol
+        2075,    # abalone
+        2071,    # adult
+        3,       # kr-vs-kp
+        9981,    # cnae-9
+        146821,  # car
+        146822,  # Segment
+        2,       # anneal
+        53,      # vehicle
+        5136,    # tecator
+        4871,    # sensory
+        4857,    # boston
+        3916,    # kc1
+    ]
+
+    # Populate the cache
+    # This will make the test fail immediately rather than
+    # waiting for an openml fetch timeout
+    openml.populate_cache(task_ids=tasks_used)
+    # Also prefetch the sklearn bunch for each task
+    for task in tasks_used:
+        fetch_openml(data_id=openml.tasks.get_task(task).dataset_id,
+                     return_X_y=True)
+    return
+
+
 def slugify(text):
     return re.sub(r'[\[\]]+', '-', text.lower())
 
@@ -189,7 +227,7 @@ def get_tabular_data(task):
         validator = TabularInputValidator(is_classification=False).fit(X.copy(), y.copy())
 
     elif task == "regression_categorical_only":
-        X, y = fetch_openml("cholesterol", return_X_y=True, as_frame=True)
+        X, y = fetch_openml("boston", return_X_y=True, as_frame=True)
         categorical_columns = [column for column in X.columns if X[column].dtype.name == 'category']
         X = X[categorical_columns]
@@ -207,7 +245,7 @@ def get_tabular_data(task):
         validator = TabularInputValidator(is_classification=False).fit(X.copy(), y.copy())
 
     elif task == "regression_numerical_and_categorical":
-        X, y = fetch_openml("cholesterol", return_X_y=True, as_frame=True)
+        X, y = fetch_openml("boston", return_X_y=True, as_frame=True)
 
         # fill nan values for now since they are not handled properly yet
         for column in X.columns:
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index ea7cccd72..607448de0 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -43,10 +43,16 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
     X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
         X, y, random_state=1)
 
+    include = None
+    # For Python versions below 3.7, the learned entity embedding
+    # cannot be stored on disk (observed only on CI)
+    if sys.version_info < (3, 7):
+        include = {'network_embedding': ['NoEmbedding']}
     # Search for a good configuration
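The `include` guard above is also the workaround a user could apply directly: restricting the embedding step through `include_components` keeps the learned embedding out of the search space entirely. A minimal sketch (constructor arguments other than `include_components` omitted and assumed to take their defaults):

```python
import sys

from autoPyTorch.api.tabular_classification import TabularClassificationTask

# On interpreters that cannot pickle the learned embedding, search only NoEmbedding.
include = {'network_embedding': ['NoEmbedding']} if sys.version_info < (3, 7) else None
estimator = TabularClassificationTask(include_components=include)
```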
estimator = TabularClassificationTask( backend=backend, resampling_strategy=resampling_strategy, + include_components=include ) estimator.search( @@ -121,6 +127,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend): f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.cv_model" ) assert os.path.exists(model_file), model_file + model = estimator._backend.load_cv_model_by_seed_and_id_and_budget( estimator.seed, run_key.config_id, run_key.budget) assert isinstance(model, VotingClassifier) @@ -178,7 +185,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend): restored_estimator.predict(X_test) -@pytest.mark.parametrize('openml_name', ("cholesterol", )) +@pytest.mark.parametrize('openml_name', ("boston", )) @pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, CrossValTypes.k_fold_cross_validation, )) diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index afa2b43e1..6d90ef2f9 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -231,10 +231,17 @@ def test_featurevalidator_unsupported_numpy(input_data_featuretest): ), indirect=True ) -def test_featurevalidator_unsupported_pandas(input_data_featuretest): +def test_featurevalidator_categorical_nan(input_data_featuretest): validator = TabularFeatureValidator() - with pytest.raises(ValueError, match=r"Categorical features in a dataframe.*missing/NaN"): - validator.fit(input_data_featuretest) + validator.fit(input_data_featuretest) + transformed_X = validator.transform(input_data_featuretest) + assert any(pd.isna(input_data_featuretest)) + assert any((-1 in categories) or ('-1' in categories) for categories in + validator.encoder.named_transformers_['encoder'].categories_) + assert np.shape(input_data_featuretest) == np.shape(transformed_X) + assert np.issubdtype(transformed_X.dtype, np.number) + assert validator._is_fitted + assert isinstance(transformed_X, np.ndarray) @pytest.mark.parametrize( diff --git a/test/test_evaluation/test_evaluation.py b/test/test_evaluation/test_evaluation.py index f4345cb40..415dc707f 100644 --- a/test/test_evaluation/test_evaluation.py +++ b/test/test_evaluation/test_evaluation.py @@ -380,6 +380,10 @@ def test_silent_exception_in_target_function(self): """'save_targets_ensemble'",)""", """AttributeError("'BackendMock' object has no attribute """ """'save_targets_ensemble'")""", + """AttributeError("'BackendMock' object has no attribute """ + """'setup_logger'",)""", + """AttributeError("'BackendMock' object has no attribute """ + """'setup_logger'")""", ) ) self.assertNotIn('exitcode', info[1].additional_info) diff --git a/test/test_pipeline/components/preprocessing/base.py b/test/test_pipeline/components/preprocessing/base.py new file mode 100644 index 000000000..875ed399c --- /dev/null +++ b/test/test_pipeline/components/preprocessing/base.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, List, Optional, Tuple + +from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import \ + TabularColumnTransformer +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.base_encoder_choice import \ + EncoderChoice +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer +from 
autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler_choice import ScalerChoice +from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline + + +class TabularPipeline(TabularClassificationPipeline): + def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]], + ) -> List[Tuple[str, autoPyTorchChoice]]: + """ + Defines what steps a pipeline should follow. + The step itself has choices given via autoPyTorchChoice. + + Returns: + List[Tuple[str, autoPyTorchChoice]]: list of steps sequentially exercised + by the pipeline. + """ + steps = [] # type: List[Tuple[str, autoPyTorchChoice]] + + default_dataset_properties = {'target_type': 'tabular_classification'} + if dataset_properties is not None: + default_dataset_properties.update(dataset_properties) + + steps.extend([ + ("imputer", SimpleImputer()), + ("encoder", EncoderChoice(default_dataset_properties)), + ("scaler", ScalerChoice(default_dataset_properties)), + ("tabular_transformer", TabularColumnTransformer()), + ]) + return steps diff --git a/test/test_pipeline/components/preprocessing/test_encoders.py b/test/test_pipeline/components/preprocessing/test_encoders.py index 1f210936f..a901823ba 100644 --- a/test/test_pipeline/components/preprocessing/test_encoders.py +++ b/test/test_pipeline/components/preprocessing/test_encoders.py @@ -8,7 +8,6 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.NoEncoder import NoEncoder from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.OneHotEncoder import OneHotEncoder -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.OrdinalEncoder import OrdinalEncoder class TestEncoders(unittest.TestCase): @@ -53,48 +52,6 @@ def test_one_hot_encoder_no_unknown(self): # check if the transform is correct assert_array_equal(transformed, [['1.0', '0.0', 1], ['1.0', '0.0', 2]]) - def test_ordinal_encoder(self): - - data = np.array([[1, 'male'], - [1, 'female'], - [3, 'male'], - [2, 'female'], - [2, 'male']]) - - categorical_columns = [1] - numerical_columns = [0] - train_indices = np.array([0, 2, 3]) - test_indices = np.array([1, 4]) - - dataset_properties = { - 'categorical_columns': categorical_columns, - 'numerical_columns': numerical_columns, - 'categories': [['female', 'male', 'unknown']] - } - X = { - 'X_train': data[train_indices], - 'dataset_properties': dataset_properties - } - encoder_component = OrdinalEncoder() - encoder_component.fit(X) - X = encoder_component.transform(X) - - encoder = X['encoder']['categorical'] - - # check if the fit dictionary X is modified as expected - self.assertIsInstance(X['encoder'], dict) - self.assertIsInstance(encoder, BaseEstimator) - self.assertIsNone(X['encoder']['numerical']) - - # make column transformer with returned encoder to fit on data - column_transformer = make_column_transformer((encoder, X['dataset_properties']['categorical_columns']), - remainder='passthrough') - column_transformer = column_transformer.fit(X['X_train']) - transformed = column_transformer.transform(data[test_indices]) - - # check if we got the expected transformed array - assert_array_equal(transformed, [['0.0', 1], ['1.0', 2]]) - def test_none_encoder(self): data = np.array([[1, 'male'], diff --git a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py index ef113c5eb..66a96f27f 100644 --- 
a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py +++ b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py @@ -1,4 +1,4 @@ -from test.test_pipeline.components.base import TabularPipeline +from test.test_pipeline.components.preprocessing.base import TabularPipeline import numpy as np @@ -33,7 +33,18 @@ def test_tabular_preprocess(self, fit_dictionary_tabular): data = column_transformer.preprocessor.fit_transform(X['X_train']) assert isinstance(data, np.ndarray) + # Make sure no columns are unintentionally dropped after preprocessing + if len(fit_dictionary_tabular['dataset_properties']["numerical_columns"]) == 0: + categorical_pipeline = column_transformer.preprocessor.named_transformers_['categorical_pipeline'] + categorical_data = categorical_pipeline.transform(X['X_train']) + assert data.shape[1] == categorical_data.shape[1] + elif len(fit_dictionary_tabular['dataset_properties']["categorical_columns"]) == 0: + numerical_pipeline = column_transformer.preprocessor.named_transformers_['numerical_pipeline'] + numerical_data = numerical_pipeline.transform(X['X_train']) + assert data.shape[1] == numerical_data.shape[1] + def test_sparse_data(self, fit_dictionary_tabular): + X = np.random.binomial(1, 0.1, (100, 2000)) sparse_X = csr_matrix(X) numerical_columns = list(range(2000)) diff --git a/test/test_pipeline/components/setup/test_setup_networks.py b/test/test_pipeline/components/setup/test_setup_networks.py index be8af94c5..f6f3decb0 100644 --- a/test/test_pipeline/components/setup/test_setup_networks.py +++ b/test/test_pipeline/components/setup/test_setup_networks.py @@ -17,21 +17,33 @@ def head(request): return request.param +@pytest.fixture(params=['LearnedEntityEmbedding', 'NoEmbedding']) +def embedding(request): + return request.param + + @flaky.flaky(max_runs=3) @pytest.mark.parametrize("fit_dictionary_tabular", ['classification_numerical_only', 'classification_categorical_only', 'classification_numerical_and_categorical'], indirect=True) class TestNetworks: - def test_pipeline_fit(self, fit_dictionary_tabular, backbone, head): + def test_pipeline_fit(self, fit_dictionary_tabular, embedding, backbone, head): """This test makes sure that the pipeline is able to fit - given random combinations of hyperparameters across the pipeline""" + every combination of network embedding, backbone, head""" + include = {'network_backbone': [backbone], 'network_head': [head], 'network_embedding': [embedding]} + + if len(fit_dictionary_tabular['dataset_properties'] + ['categorical_columns']) == 0 and embedding == 'LearnedEntityEmbedding': + pytest.skip("Learned Entity Embedding is not used with numerical only data") pipeline = TabularClassificationPipeline( dataset_properties=fit_dictionary_tabular['dataset_properties'], - include={'network_backbone': [backbone], 'network_head': [head]}) + include=include) + cs = pipeline.get_hyperparameter_search_space() config = cs.get_default_configuration() + assert embedding == config.get('network_embedding:__choice__', None) assert backbone == config.get('network_backbone:__choice__', None) assert head == config.get('network_head:__choice__', None) pipeline.set_hyperparameters(config) diff --git a/test/test_pipeline/components/base.py b/test/test_pipeline/components/training/base.py similarity index 67% rename from test/test_pipeline/components/base.py rename to test/test_pipeline/components/training/base.py index 8211172e7..10d9ea416 100644 --- a/test/test_pipeline/components/base.py +++ 
b/test/test_pipeline/components/training/base.py @@ -1,23 +1,20 @@ import logging import unittest -from typing import Any, Dict, List, Optional, Tuple from sklearn.datasets import make_classification, make_regression import torch -from autoPyTorch.constants import BINARY, CLASSIFICATION_TASKS, CONTINUOUS, OUTPUT_TYPES_TO_STRING, REGRESSION_TASKS, \ +from autoPyTorch.constants import ( + BINARY, + CLASSIFICATION_TASKS, + CONTINUOUS, + OUTPUT_TYPES_TO_STRING, + REGRESSION_TASKS, TASK_TYPES_TO_STRING -from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import \ - TabularColumnTransformer -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.base_encoder_choice import \ - EncoderChoice -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler_choice import ScalerChoice +) from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent, BudgetTracker -from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline class BaseTraining(unittest.TestCase): @@ -121,29 +118,3 @@ def train_model(self, # Backward pass loss.backward() optimizer.step() - - -class TabularPipeline(TabularClassificationPipeline): - def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]], - ) -> List[Tuple[str, autoPyTorchChoice]]: - """ - Defines what steps a pipeline should follow. - The step itself has choices given via autoPyTorchChoice. - - Returns: - List[Tuple[str, autoPyTorchChoice]]: list of steps sequentially exercised - by the pipeline. 
- """ - steps = [] # type: List[Tuple[str, autoPyTorchChoice]] - - default_dataset_properties = {'target_type': 'tabular_classification'} - if dataset_properties is not None: - default_dataset_properties.update(dataset_properties) - - steps.extend([ - ("imputer", SimpleImputer()), - ("encoder", EncoderChoice(default_dataset_properties)), - ("scaler", ScalerChoice(default_dataset_properties)), - ("tabular_transformer", TabularColumnTransformer()), - ]) - return steps diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py index 9005d1ad2..d6964fa14 100644 --- a/test/test_pipeline/components/training/test_training.py +++ b/test/test_pipeline/components/training/test_training.py @@ -27,7 +27,7 @@ ) sys.path.append(os.path.dirname(__file__)) -from test.test_pipeline.components.base import BaseTraining # noqa (E402: module level import not at top of file) +from test.test_pipeline.components.training.base import BaseTraining # noqa (E402: module level import not at top of file) class BaseDataLoaderTest(unittest.TestCase): diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index 260587adb..fc6eea0e4 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -35,7 +35,13 @@ def _assert_pipeline_search_space(self, pipeline, search_space_updates): assert any(update.node_name + ':' + update.hyperparameter in name for name in config_space.get_hyperparameter_names()), \ "Can't find hyperparameter: {}".format(update.hyperparameter) - hyperparameter = config_space.get_hyperparameter(update.node_name + ':' + update.hyperparameter + '_1') + # dimension reduction in embedding starts from 0 + if 'embedding' in update.node_name: + hyperparameter = config_space.get_hyperparameter( + update.node_name + ':' + update.hyperparameter + '_0') + else: + hyperparameter = config_space.get_hyperparameter( + update.node_name + ':' + update.hyperparameter + '_1') assert update.default_value == hyperparameter.default_value if isinstance(hyperparameter, (UniformIntegerHyperparameter, UniformFloatHyperparameter)): assert update.value_range[0] == hyperparameter.lower @@ -208,6 +214,7 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary_tabular): # Make sure that fitting a network adds a "network" to X assert 'network' in pipeline.named_steps.keys() + fit_dictionary_tabular['network_embedding'] = torch.nn.Linear(3, 3) fit_dictionary_tabular['network_backbone'] = torch.nn.Linear(3, 4) fit_dictionary_tabular['network_head'] = torch.nn.Linear(4, 1) X = pipeline.named_steps['network'].fit( diff --git a/test/test_pipeline/test_tabular_regression.py b/test/test_pipeline/test_tabular_regression.py index 15b8351f9..74de19405 100644 --- a/test/test_pipeline/test_tabular_regression.py +++ b/test/test_pipeline/test_tabular_regression.py @@ -39,7 +39,13 @@ def _assert_pipeline_search_space(self, pipeline, search_space_updates): assert any(update.node_name + ':' + update.hyperparameter in name for name in config_space.get_hyperparameter_names()), \ "Can't find hyperparameter: {}".format(update.hyperparameter) - hyperparameter = config_space.get_hyperparameter(update.node_name + ':' + update.hyperparameter + '_1') + # dimension reduction in embedding starts from 0 + if 'embedding' in update.node_name: + hyperparameter = config_space.get_hyperparameter( + update.node_name + ':' + update.hyperparameter + '_0') + else: + hyperparameter = 
config_space.get_hyperparameter( + update.node_name + ':' + update.hyperparameter + '_1') assert update.default_value == hyperparameter.default_value if isinstance(hyperparameter, (UniformIntegerHyperparameter, UniformFloatHyperparameter)): assert update.value_range[0] == hyperparameter.lower @@ -199,6 +205,7 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary_tabular): # Make sure that fitting a network adds a "network" to X assert 'network' in pipeline.named_steps.keys() + fit_dictionary_tabular['network_embedding'] = torch.nn.Linear(3, 3) fit_dictionary_tabular['network_backbone'] = torch.nn.Linear(3, 4) fit_dictionary_tabular['network_head'] = torch.nn.Linear(4, 1) X = pipeline.named_steps['network'].fit(
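For reference, the handshake exercised by this test reduces to the `network` step chaining embedding, backbone, and head into one `nn.Sequential`; a self-contained sketch using the same stand-in modules as the test:

```python
import torch

# Stand-ins from the test: embedding (3 -> 3), backbone (3 -> 4), head (4 -> 1).
network = torch.nn.Sequential(torch.nn.Linear(3, 3),
                              torch.nn.Linear(3, 4),
                              torch.nn.Linear(4, 1))
print(network(torch.zeros(5, 3)).shape)  # -> torch.Size([5, 1])
```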