From f0accc232115c0447572ac76700ed53f9535bfb5 Mon Sep 17 00:00:00 2001
From: chico
Date: Sat, 23 Jan 2021 04:10:32 +0100
Subject: [PATCH 1/5] Allow specifying the network type in include

---
 .../normalise/base_normalizer.py              |   4 -
 .../base_tabular_preprocessing.py             |   8 -
 .../augmentation/image/ImageAugmenter.py      |  10 -
 .../early_preprocessor/EarlyPreprocessing.py  |   4 -
 .../setup/lr_scheduler/base_scheduler.py      |   5 -
 .../setup/network/BackboneHeadNet.py          | 112 ----
 .../setup/network/backbone/__init__.py        |  32 -
 .../setup/network/backbone/tabular.py         | 620 ------------------
 .../components/setup/network/base_network.py  |  57 +-
 .../components/setup/network/head/__init__.py |  26 -
 .../setup/network_backbone/MLPBackbone.py     | 136 ++++
 .../setup/network_backbone/ResNetBackbone.py  | 263 ++++++++
 .../network_backbone/ShapedMLPBackbone.py     | 127 ++++
 .../network_backbone/ShapedResNetBackbone.py  | 156 +++++
 .../setup/network_backbone/__init__.py        |   0
 .../base_network_backbone.py}                 |  20 +-
 .../base_network_backbone_choice.py           | 185 ++++++
 .../backbone => network_backbone}/image.py    |  12 +-
 .../time_series.py                            |  14 +-
 .../{network => network_backbone}/utils.py    |  22 +
 .../components/setup/network_head/__init__.py |   0
 .../base_network_head.py}                     |  24 +-
 .../base_network_head_choice.py}              | 108 +--
 .../head => network_head}/fully_connected.py  |   9 +-
 .../fully_convolutional.py                    |   9 +-
 .../base_network_initializer.py               |   5 -
 .../setup/optimizer/base_optimizer.py         |   3 -
 .../setup/traditional_ml/base_model.py        |   6 -
 .../traditional_ml/tabular_classifier.py      |   7 +-
 .../training/data_loader/base_data_loader.py  |   8 -
 .../training/trainer/base_trainer_choice.py   |   1 -
 .../pipeline/tabular_classification.py        |   8 +-
 autoPyTorch/pipeline/tabular_regression.py    |   8 +-
 test/test_pipeline/components/test_setup.py   |  99 +--
 .../test_tabular_classification.py            |   7 +-
 35 files changed, 1096 insertions(+), 1019 deletions(-)
 delete mode 100644 autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py
 delete mode 100644 autoPyTorch/pipeline/components/setup/network/backbone/__init__.py
 delete mode 100644 autoPyTorch/pipeline/components/setup/network/backbone/tabular.py
 delete mode 100644 autoPyTorch/pipeline/components/setup/network/head/__init__.py
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/__init__.py
 rename autoPyTorch/pipeline/components/setup/{network/backbone/base_backbone.py => network_backbone/base_network_backbone.py} (71%)
 create mode 100644 autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py
 rename autoPyTorch/pipeline/components/setup/{network/backbone => network_backbone}/image.py (96%)
 rename autoPyTorch/pipeline/components/setup/{network/backbone => network_backbone}/time_series.py (96%)
 rename autoPyTorch/pipeline/components/setup/{network => network_backbone}/utils.py (92%)
 create mode 100644 autoPyTorch/pipeline/components/setup/network_head/__init__.py
 rename autoPyTorch/pipeline/components/setup/{network/head/base_head.py => network_head/base_network_head.py} (53%)
 rename autoPyTorch/pipeline/components/setup/{network/base_network_choice.py => network_head/base_network_head_choice.py} (60%)
rename autoPyTorch/pipeline/components/setup/{network/head => network_head}/fully_connected.py (91%) rename autoPyTorch/pipeline/components/setup/{network/head => network_head}/fully_convolutional.py (94%) diff --git a/autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/base_normalizer.py b/autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/base_normalizer.py index 2ea12fae8..9c468d7d7 100644 --- a/autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/base_normalizer.py +++ b/autoPyTorch/pipeline/components/preprocessing/image_preprocessing/normalise/base_normalizer.py @@ -39,8 +39,4 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None: def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('random_state', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py index 6e7c2f8f1..aefe9ddf8 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py @@ -34,12 +34,4 @@ def get_preprocessor_dict(self) -> Dict[str, BaseEstimator]: def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('early_preprocessor', None) - info.pop('column_transformer', None) - info.pop('random_state', None) - info.pop('_fit_requirements', None) - if len(info.keys()) != 0: - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/ImageAugmenter.py b/autoPyTorch/pipeline/components/setup/augmentation/image/ImageAugmenter.py index a718dec26..40946d371 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/ImageAugmenter.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/ImageAugmenter.py @@ -131,14 +131,4 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.__class__.__name__ - info = vars(self) - augmenters = list() - for augmenter in info['augmenter']: - augmenters.append(augmenter.name) - info['augmenters'] = augmenters - # Remove unwanted info - info.pop('random_state', None) - info.pop('available_augmenters', None) - info.pop('augmenter', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py index ad6dfabd6..cd97a9ba1 100644 --- a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py +++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py @@ -61,8 +61,4 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('random_state', None) - string += " (" + str(info) + ")" return string diff 
--git a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py index 221e4e9a5..c541507b8 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py @@ -40,9 +40,4 @@ def get_scheduler(self) -> _LRScheduler: def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.scheduler.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('scheduler', None) - info.pop('random_state', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py b/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py deleted file mode 100644 index 3a3773b25..000000000 --- a/autoPyTorch/pipeline/components/setup/network/BackboneHeadNet.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Any, Dict, Optional, Tuple, Type - -from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter -) - -import numpy as np - -from torch import nn - -from autoPyTorch.pipeline.components.setup.network.backbone import BaseBackbone, get_available_backbones -from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent -from autoPyTorch.pipeline.components.setup.network.head import BaseHead, get_available_heads -from autoPyTorch.utils import common - - -class BackboneHeadNet(BaseNetworkComponent): - """ - Implementation of a dynamic network, that consists of a backbone and a head - """ - - def __init__( - self, - network: Optional[BaseNetworkComponent] = None, - random_state: Optional[np.random.RandomState] = None, - **kwargs: Any - ): - super().__init__( - network=network, - random_state=random_state, - ) - self.config = kwargs - self._backbones = get_available_backbones() - self._heads = get_available_heads() - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - "shortname": "BackboneHeadNet", - "name": "BackboneHeadNet", - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - **kwargs: Any) -> ConfigurationSpace: - cs = ConfigurationSpace() - backbones: Dict[str, Type[BaseBackbone]] = get_available_backbones() - heads: Dict[str, Type[BaseHead]] = get_available_heads() - - # filter backbones and heads for those who support the current task type - if dataset_properties is not None and "task_type" in dataset_properties: - task = dataset_properties["task_type"] - backbones = {name: backbone for name, backbone in backbones.items() if task in backbone.supported_tasks} - heads = {name: head for name, head in heads.items() if task in head.supported_tasks} - - backbone_defaults = [ - 'ShapedMLPBackbone', - 'MLPBackbone', - 'ConvNetImageBackbone', - 'InceptionTimeBackbone', - ] - for default_ in backbone_defaults: - if default_ in backbones.keys(): - backbone_default = default_ - break - - backbone_hp = CategoricalHyperparameter("backbone", choices=backbones.keys(), default_value=backbone_default) - head_hp = CategoricalHyperparameter("head", choices=heads.keys()) - cs.add_hyperparameters([backbone_hp, head_hp]) - - # for each backbone and head, add a conditional search space if this backbone or head is chosen - for backbone_name in backbones.keys(): - backbone_cs = 
backbones[backbone_name].get_hyperparameter_search_space(dataset_properties) - cs.add_configuration_space(backbone_name, - backbone_cs, - parent_hyperparameter={"parent": backbone_hp, "value": backbone_name}) - - for head_name in heads.keys(): - head_cs: ConfigurationSpace = heads[head_name].get_hyperparameter_search_space(dataset_properties) - cs.add_configuration_space(head_name, - head_cs, - parent_hyperparameter={"parent": head_hp, "value": head_name}) - return cs - - def build_network(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: - """ - This method returns a pytorch network, that is dynamically built using - a self.config that is network specific, and contains the additional - configuration hyperparameters to build a domain specific network - """ - backbone_name = self.config["backbone"] - head_name = self.config["head"] - Backbone = self._backbones[backbone_name] - Head = self._heads[head_name] - - backbone = Backbone(**common.replace_prefix_in_config_dict(self.config, backbone_name)) - backbone_module = backbone.build_backbone(input_shape=input_shape) - backbone_output_shape = backbone.get_output_shape(input_shape=input_shape) - - head = Head(**common.replace_prefix_in_config_dict(self.config, head_name)) - head_module = head.build_head(input_shape=backbone_output_shape, output_shape=output_shape) - - return nn.Sequential(backbone_module, head_module) - - def __str__(self) -> str: - """ Allow a nice understanding of what components where used """ - info = vars(self) - # Remove unwanted info - info.pop('network', None) - info.pop('random_state', None) - return f"BackboneHeadNet: {self.config['backbone']} -> {self.config['head']} ({str(info)})" diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py b/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py deleted file mode 100644 index 97b0392ee..000000000 --- a/autoPyTorch/pipeline/components/setup/network/backbone/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Any, Dict, Type, Union - -from autoPyTorch.pipeline.components.base_component import ( - ThirdPartyComponents, -) -from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone -from autoPyTorch.pipeline.components.setup.network.backbone.image import ConvNetImageBackbone, DenseNetBackbone -from autoPyTorch.pipeline.components.setup.network.backbone.tabular import MLPBackbone, ResNetBackbone, \ - ShapedMLPBackbone -from autoPyTorch.pipeline.components.setup.network.backbone.time_series import InceptionTimeBackbone, TCNBackbone - -_backbones = { - ConvNetImageBackbone.get_name(): ConvNetImageBackbone, - DenseNetBackbone.get_name(): DenseNetBackbone, - ResNetBackbone.get_name(): ResNetBackbone, - ShapedMLPBackbone.get_name(): ShapedMLPBackbone, - MLPBackbone.get_name(): MLPBackbone, - TCNBackbone.get_name(): TCNBackbone, - InceptionTimeBackbone.get_name(): InceptionTimeBackbone -} -_addons = ThirdPartyComponents(BaseBackbone) - - -def add_backbone(backbone: BaseBackbone) -> None: - _addons.add_component(backbone) - - -def get_available_backbones() -> Dict[str, Union[Type[BaseBackbone], Any]]: - backbones = dict() - backbones.update(_backbones) - backbones.update(_addons.components) - return backbones diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py b/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py deleted file mode 100644 index a24424a06..000000000 --- 
a/autoPyTorch/pipeline/components/setup/network/backbone/tabular.py +++ /dev/null @@ -1,620 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple - -import ConfigSpace as CS -from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter, - UniformFloatHyperparameter, - UniformIntegerHyperparameter -) - -import torch -from torch import nn - -from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone -from autoPyTorch.pipeline.components.setup.network.utils import ( - get_shaped_neuron_counts, - shake_drop, - shake_drop_get_bl, - shake_get_alpha_beta, - shake_shake -) - -_activations = { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "sigmoid": nn.Sigmoid -} - - -class MLPBackbone(BaseBackbone): - """ - This component automatically creates a Multi Layer Perceptron based on a given config. - - This MLP allows for: - - Different number of layers - - Specifying the activation. But this activation is shared among layers - - Using or not dropout - - Specifying the number of units per layers - """ - supported_tasks = {"tabular_classification", "tabular_regression"} - - def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - layers = list() # type: List[nn.Module] - in_features = input_shape[0] - - self._add_layer(layers, in_features, self.config['num_units_1'], 1) - - for i in range(2, self.config['num_groups'] + 1): - self._add_layer(layers, self.config["num_units_%d" % (i - 1)], - self.config["num_units_%d" % i], i) - backbone = nn.Sequential(*layers) - self.backbone = backbone - return backbone - - def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: - return (self.config["num_units_%d" % self.config["num_groups"]],) - - def _add_layer(self, layers: List[nn.Module], in_features: int, out_features: int, - layer_id: int) -> None: - """ - Dynamically add a layer given the in->out specification - - Args: - layers (List[nn.Module]): The list where all modules are added - in_features (int): input dimensionality of the new layer - out_features (int): output dimensionality of the new layer - - """ - layers.append(nn.Linear(in_features, out_features)) - layers.append(_activations[self.config["activation"]]()) - if self.config['use_dropout']: - layers.append(nn.Dropout(self.config["dropout_%d" % layer_id])) - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: - return { - 'shortname': 'MLPBackbone', - 'name': 'MLPBackbone', - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_mlp_layers: int = 1, - max_mlp_layers: int = 15, - dropout: bool = True, - min_num_units: int = 10, - max_num_units: int = 1024, - ) -> ConfigurationSpace: - - cs = ConfigurationSpace() - - # The number of hidden layers the network will have. 
- # Layer blocks are meant to have the same architecture, differing only - # by the number of units - num_groups = UniformIntegerHyperparameter( - "num_groups", min_mlp_layers, max_mlp_layers, default_value=5) - - activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) - ) - cs.add_hyperparameters([num_groups, activation]) - - # We can have dropout in the network for - # better generalization - if dropout: - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False], default_value=False) - cs.add_hyperparameters([use_dropout]) - - for i in range(1, max_mlp_layers + 1): - n_units_hp = UniformIntegerHyperparameter("num_units_%d" % i, - lower=min_num_units, - upper=max_num_units, - default_value=200) - cs.add_hyperparameter(n_units_hp) - - if i > min_mlp_layers: - # The units of layer i should only exist - # if there are at least i layers - cs.add_condition( - CS.GreaterThanCondition( - n_units_hp, num_groups, i - 1 - ) - ) - - if dropout: - dropout_hp = UniformFloatHyperparameter( - "dropout_%d" % i, - lower=0.0, - upper=0.8, - default_value=0.5 - ) - cs.add_hyperparameter(dropout_hp) - dropout_condition_1 = CS.EqualsCondition(dropout_hp, use_dropout, True) - - if i > min_mlp_layers: - dropout_condition_2 = CS.GreaterThanCondition(dropout_hp, num_groups, i - 1) - cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2)) - else: - cs.add_condition(dropout_condition_1) - - return cs - - -class ShapedMLPBackbone(BaseBackbone): - """ - Implementation of a Shaped MLP -- an MLP with the number of units - arranged so that a given shape is honored - """ - supported_tasks = {"tabular_classification", "tabular_regression"} - - def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - layers = list() # type: List[nn.Module] - in_features = input_shape[0] - out_features = self.config["output_dim"] - neuron_counts = get_shaped_neuron_counts(self.config['mlp_shape'], - in_features, - out_features, - self.config['max_units'], - self.config['num_groups']) - if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: - dropout_shape = get_shaped_neuron_counts( - self.config['mlp_shape'], 0, 0, 1000, self.config['num_groups'] - ) - - previous = in_features - for i in range(self.config['num_groups'] - 1): - if i >= len(neuron_counts): - break - if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: - dropout = dropout_shape[i] / 1000 * self.config["max_dropout"] - else: - dropout = 0.0 - self._add_layer(layers, previous, neuron_counts[i], dropout) - previous = neuron_counts[i] - layers.append(nn.Linear(previous, out_features)) - - backbone = nn.Sequential(*layers) - self.backbone = backbone - return backbone - - def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: - return (self.config["output_dim"],) - - def _add_layer(self, layers: List[nn.Module], - in_features: int, out_features: int, dropout: float - ) -> None: - layers.append(nn.Linear(in_features, out_features)) - layers.append(_activations[self.config["activation"]]()) - if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: - layers.append(nn.Dropout(dropout)) - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: - return { - 'shortname': 'ShapedMLPBackbone', - 'name': 'ShapedMLPBackbone', - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 15, - 
min_num_units: int = 10, - max_num_units: int = 1024, - ) -> ConfigurationSpace: - - cs = ConfigurationSpace() - - # The number of groups that will compose the resnet. That is, - # a group can have N Resblock. The M number of this N resblock - # repetitions is num_groups - num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) - - mlp_shape = CategoricalHyperparameter('mlp_shape', choices=[ - 'funnel', 'long_funnel', 'diamond', 'hexagon', 'brick', 'triangle', 'stairs' - ]) - - activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) - ) - - max_units = UniformIntegerHyperparameter( - "max_units", - lower=min_num_units, - upper=max_num_units, - default_value=200, - ) - - output_dim = UniformIntegerHyperparameter( - "output_dim", - lower=min_num_units, - upper=max_num_units - ) - - cs.add_hyperparameters([num_groups, activation, mlp_shape, max_units, output_dim]) - - # We can have dropout in the network for - # better generalization - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) - max_dropout = UniformFloatHyperparameter("max_dropout", lower=0.0, upper=1.0) - cs.add_hyperparameters([use_dropout, max_dropout]) - cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True)) - - return cs - - -class ResNetBackbone(BaseBackbone): - """ - Implementation of a Residual Network backbone - - """ - supported_tasks = {"tabular_classification", "tabular_regression"} - - def build_backbone(self, input_shape: Tuple[int, ...]) -> None: - layers = list() # type: List[nn.Module] - in_features = input_shape[0] - layers.append(nn.Linear(in_features, self.config["num_units_0"])) - - # build num_groups-1 groups each consisting of blocks_per_group ResBlocks - # the output features of each group is defined by num_units_i - for i in range(1, self.config['num_groups'] + 1): - layers.append( - self._add_group( - in_features=self.config["num_units_%d" % (i - 1)], - out_features=self.config["num_units_%d" % i], - blocks_per_group=self.config["blocks_per_group_%d" % i], - last_block_index=(i - 1) * self.config["blocks_per_group_%d" % i], - dropout=self.config['use_dropout'] - ) - ) - - layers.append(nn.BatchNorm1d(self.config["num_units_%i" % self.config['num_groups']])) - layers.append(_activations[self.config["activation"]]()) - backbone = nn.Sequential(*layers) - self.backbone = backbone - return backbone - - def _add_group(self, in_features: int, out_features: int, - blocks_per_group: int, last_block_index: int, dropout: bool - ) -> nn.Module: - """ - Adds a group into the main backbone. 
- In the case of ResNet a group is a set of blocks_per_group - ResBlocks - - Args: - in_features (int): number of inputs for the current block - out_features (int): output dimensionality for the current block - blocks_per_group (int): Number of ResNet per group - last_block_index (int): block index for shake regularization - dropout (bool): whether or not use dropout - """ - blocks = list() - for i in range(blocks_per_group): - blocks.append( - ResBlock( - config=self.config, - in_features=in_features, - out_features=out_features, - blocks_per_group=blocks_per_group, - block_index=last_block_index + i, - dropout=dropout, - activation=_activations[self.config["activation"]] - ) - ) - in_features = out_features - return nn.Sequential(*blocks) - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: - return { - 'shortname': 'ResNetBackbone', - 'name': 'ResidualNetworkBackbone', - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 9, - min_blocks_per_groups: int = 1, - max_blocks_per_groups: int = 4, - min_num_units: int = 10, - max_num_units: int = 1024, - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - # The number of groups that will compose the resnet. That is, - # a group can have N Resblock. The M number of this N resblock - # repetitions is num_groups - num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) - - activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) - ) - cs.add_hyperparameters([num_groups, activation]) - - # We can have dropout in the network for - # better generalization - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) - cs.add_hyperparameters([use_dropout]) - - use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False]) - use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False]) - shake_drop_prob = UniformFloatHyperparameter( - "max_shake_drop_probability", lower=0.0, upper=1.0) - cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob]) - cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True)) - - # It is the upper bound of the nr of groups, - # since the configuration will actually be sampled. 
- for i in range(0, max_num_groups + 1): - - n_units = UniformIntegerHyperparameter( - "num_units_%d" % i, - lower=min_num_units, - upper=max_num_units, - ) - blocks_per_group = UniformIntegerHyperparameter( - "blocks_per_group_%d" % i, lower=min_blocks_per_groups, - upper=max_blocks_per_groups) - - cs.add_hyperparameters([n_units, blocks_per_group]) - - if i > 1: - cs.add_condition(CS.GreaterThanCondition(n_units, num_groups, i - 1)) - cs.add_condition(CS.GreaterThanCondition(blocks_per_group, num_groups, i - 1)) - - this_dropout = UniformFloatHyperparameter( - "dropout_%d" % i, lower=0.0, upper=1.0 - ) - cs.add_hyperparameters([this_dropout]) - - dropout_condition_1 = CS.EqualsCondition(this_dropout, use_dropout, True) - - if i > 1: - - dropout_condition_2 = CS.GreaterThanCondition(this_dropout, num_groups, i - 1) - - cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2)) - else: - cs.add_condition(dropout_condition_1) - return cs - - -class ResBlock(nn.Module): - """ - __author__ = "Max Dippel, Michael Burkart and Matthias Urban" - """ - - def __init__( - self, - config: Dict[str, Any], - in_features: int, - out_features: int, - blocks_per_group: int, - block_index: int, - dropout: bool, - activation: nn.Module - ): - super(ResBlock, self).__init__() - self.config = config - self.dropout = dropout - self.activation = activation - - self.shortcut = None - self.start_norm = None # type: Optional[Callable] - - # if in != out the shortcut needs a linear layer to match the result dimensions - # if the shortcut needs a layer we apply batchnorm and activation to the shortcut - # as well (start_norm) - if in_features != out_features: - self.shortcut = nn.Linear(in_features, out_features) - self.start_norm = nn.Sequential( - nn.BatchNorm1d(in_features), - self.activation() - ) - - self.block_index = block_index - self.num_blocks = blocks_per_group * self.config["num_groups"] - self.layers = self._build_block(in_features, out_features) - - if config["use_shake_shake"]: - self.shake_shake_layers = self._build_block(in_features, out_features) - - # each bloack consists of two linear layers with batch norm and activation - def _build_block(self, in_features: int, out_features: int) -> nn.Module: - layers = list() - - if self.start_norm is None: - layers.append(nn.BatchNorm1d(in_features)) - layers.append(self.activation()) - layers.append(nn.Linear(in_features, out_features)) - - layers.append(nn.BatchNorm1d(out_features)) - layers.append(self.activation()) - - if self.config["use_dropout"]: - layers.append(nn.Dropout(self.dropout)) - layers.append(nn.Linear(out_features, out_features)) - - return nn.Sequential(*layers) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - residual = x - - # if shortcut is not none we need a layer such that x matches the output dimension - if self.shortcut is not None and self.start_norm is not None: - # in this case self.start_norm is also != none - # apply start_norm to x in order to have batchnorm+activation - # in front of shortcut and layers. Note that in this case layers - # does not start with batchnorm+activation but with the first linear layer - # (see _build_block). 
As a result if in_features == out_features - # -> result = x + W(~D(A(BN(W(A(BN(x)))))) - # if in_features != out_features - # -> result = W_shortcut(A(BN(x))) + W_2(~D(A(BN(W_1(A(BN(x)))))) - x = self.start_norm(x) - residual = self.shortcut(x) - - if self.config["use_shake_shake"]: - x1 = self.layers(x) - x2 = self.shake_shake_layers(x) - alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda) - x = shake_shake(x1, x2, alpha, beta) - else: - x = self.layers(x) - - if self.config["use_shake_drop"]: - alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda) - bl = shake_drop_get_bl( - self.block_index, - 1 - self.config["max_shake_drop_probability"], - self.num_blocks, - self.training, - x.is_cuda - ) - x = shake_drop(x, alpha, beta, bl) - - x = x + residual - return x - - -class ShapedResNetBackbone(ResNetBackbone): - """ - Implementation of a Residual Network builder with support - for shaped number of units per group. - - """ - - def build_backbone(self, input_shape: Tuple[int, ...]) -> None: - layers = list() # type: List[nn.Module] - in_features = input_shape[0] - out_features = self.config["output_dim"] - - # use the get_shaped_neuron_counts to update the number of units - neuron_counts = get_shaped_neuron_counts(self.config['resnet_shape'], - in_features, - out_features, - self.config['max_units'], - self.config['num_groups'] + 2)[:-1] - self.config.update( - {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)} - ) - if self.config['use_dropout'] and self.config["max_dropout"] > 0.05: - dropout_shape = get_shaped_neuron_counts( - self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] - ) - - dropout_shape = [ - dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape - ] - - self.config.update( - {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)} - ) - layers.append(nn.Linear(in_features, self.config["num_units_0"])) - - # build num_groups-1 groups each consisting of blocks_per_group ResBlocks - # the output features of each group is defined by num_units_i - for i in range(1, self.config['num_groups'] + 1): - layers.append( - self._add_group( - in_features=self.config["num_units_%d" % (i - 1)], - out_features=self.config["num_units_%d" % i], - blocks_per_group=self.config["blocks_per_group"], - last_block_index=(i - 1) * self.config["blocks_per_group"], - dropout=self.config['use_dropout'] - ) - ) - - layers.append(nn.BatchNorm1d(self.config["num_units_%i" % self.config['num_groups']])) - backbone = nn.Sequential(*layers) - self.backbone = backbone - return backbone - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: - return { - 'shortname': 'ShapedResNetBackbone', - 'name': 'ShapedResidualNetworkBackbone', - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 9, - min_blocks_per_groups: int = 1, - max_blocks_per_groups: int = 4, - min_num_units: int = 10, - max_num_units: int = 1024, - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - # Support for different shapes - resnet_shape = CategoricalHyperparameter( - 'resnet_shape', - choices=[ - 'funnel', - 'long_funnel', - 'diamond', - 'hexagon', - 'brick', - 'triangle', - 'stairs' - ] - ) - cs.add_hyperparameter(resnet_shape) - - # The number of groups that will compose the resnet. That is, - # a group can have N Resblock. 
The M number of this N resblock - # repetitions is num_groups - num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) - - blocks_per_group = UniformIntegerHyperparameter( - "blocks_per_group", lower=min_blocks_per_groups, upper=max_blocks_per_groups) - - activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) - ) - - output_dim = UniformIntegerHyperparameter( - "output_dim", - lower=min_num_units, - upper=max_num_units - ) - - cs.add_hyperparameters([num_groups, blocks_per_group, activation, output_dim]) - - # We can have dropout in the network for - # better generalization - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) - cs.add_hyperparameters([use_dropout]) - - use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False]) - use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False]) - shake_drop_prob = UniformFloatHyperparameter( - "max_shake_drop_probability", lower=0.0, upper=1.0) - cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob]) - cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True)) - - max_units = UniformIntegerHyperparameter( - "max_units", - lower=min_num_units, - upper=max_num_units, - ) - cs.add_hyperparameters([max_units]) - - max_dropout = UniformFloatHyperparameter( - "max_dropout", lower=0.0, upper=1.0 - ) - cs.add_hyperparameters([max_dropout]) - cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True)) - - return cs diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index 83872b55f..2b8c27d1c 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -1,5 +1,6 @@ -from abc import abstractmethod -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional + +from ConfigSpace.configuration_space import ConfigurationSpace import numpy as np @@ -7,11 +8,11 @@ from torch import nn from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES -from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent +from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent from autoPyTorch.utils.common import FitRequirement -class BaseNetworkComponent(autoPyTorchSetupComponent): +class NetworkComponent(autoPyTorchTrainingComponent): """ Provide an abstract interface for networks in Auto-Pytorch @@ -23,15 +24,18 @@ def __init__( random_state: Optional[np.random.RandomState] = None, device: Optional[torch.device] = None ) -> None: - super(BaseNetworkComponent, self).__init__() + super(NetworkComponent, self).__init__() self.network = network self.random_state = random_state - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device - self.add_fit_requirements([FitRequirement('task_type', (str,), user_defined=True, dataset_property=True), - FitRequirement('input_shape', (tuple,), user_defined=True, dataset_property=True), - ]) - - def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.add_fit_requirements([ + FitRequirement("network_head", (torch.nn.Module,), user_defined=False, 
dataset_property=False), + FitRequirement("network_backbone", (torch.nn.Module,), user_defined=False, dataset_property=False), + ]) + self.final_activation = None + + def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent: """ Fits a component by using an input dictionary with pre-requisites @@ -46,26 +50,17 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: # information to fit this stage self.check_requirements(X, y) - output_shape = (X['dataset_properties']['num_classes'],) if \ - STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in \ - CLASSIFICATION_TASKS else X['dataset_properties']['output_shape'] - input_shape = X['X_train'].shape[1:] - self.network = self.build_network(input_shape=input_shape, - output_shape=output_shape) + self.network = torch.nn.Sequential(X['network_backbone'], X['network_head']) # Properly set the network training device self.to(self.device) - return self + if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS: + self.final_activation = nn.Softmax(dim=1) - @abstractmethod - def build_network(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> torch.nn.Module: - """ - This method returns a pytorch network, that is dynamically built using - a self.config that is network specific, and contains the additional - configuration hyperparameters to build a domain specific network - """ - raise NotImplementedError() + self.is_fitted_ = True + + return self def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ @@ -121,16 +116,22 @@ def predict(self, loader: torch.utils.data.DataLoader) -> torch.Tensor: X_batch = torch.autograd.Variable(X_batch).float().to(self.device) Y_batch_pred = self.network(X_batch).detach().cpu() + if self.final_activation is not None: + Y_batch_pred = self.final_activation(Y_batch_pred) Y_batch_preds.append(Y_batch_pred) return torch.cat(Y_batch_preds, 0).cpu().numpy() + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + return cs + def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.network.__class__.__name__ info = vars(self) # Remove unwanted info - info.pop('network', None) - info.pop('random_state', None) string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/network/head/__init__.py b/autoPyTorch/pipeline/components/setup/network/head/__init__.py deleted file mode 100644 index dc07a268a..000000000 --- a/autoPyTorch/pipeline/components/setup/network/head/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from collections import OrderedDict -from typing import Dict, Type - -from autoPyTorch.pipeline.components.base_component import ( - ThirdPartyComponents -) -from autoPyTorch.pipeline.components.setup.network.head.base_head import BaseHead -from autoPyTorch.pipeline.components.setup.network.head.fully_connected import FullyConnectedHead -from autoPyTorch.pipeline.components.setup.network.head.fully_convolutional import FullyConvolutional2DHead - -_heads = { - FullyConnectedHead.get_name(): FullyConnectedHead, - FullyConvolutional2DHead.get_name(): FullyConvolutional2DHead -} -_addons = ThirdPartyComponents(BaseHead) - - -def add_head(head: BaseHead) -> None: - _addons.add_component(head) - - -def get_available_heads() -> Dict[str, Type[BaseHead]]: - heads = OrderedDict() - heads.update(_heads) - heads.update(_addons.components) - return 
heads diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py new file mode 100644 index 000000000..2c30f7992 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( + NetworkBackboneComponent, +) +from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( + _activations, +) + + +class MLPBackbone(NetworkBackboneComponent): + """ + This component automatically creates a Multi Layer Perceptron based on a given config. + + This MLP allows for: + - Different number of layers + - Specifying the activation. But this activation is shared among layers + - Using or not dropout + - Specifying the number of units per layers + """ + supported_tasks = {"tabular_classification", "tabular_regression"} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + layers = list() # type: List[nn.Module] + in_features = input_shape[0] + + self._add_layer(layers, in_features, self.config['num_units_1'], 1) + + for i in range(2, self.config['num_groups'] + 1): + self._add_layer(layers, self.config["num_units_%d" % (i - 1)], + self.config["num_units_%d" % i], i) + backbone = nn.Sequential(*layers) + self.backbone = backbone + return backbone + + def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: + return (self.config["num_units_%d" % self.config["num_groups"]],) + + def _add_layer(self, layers: List[nn.Module], in_features: int, out_features: int, + layer_id: int) -> None: + """ + Dynamically add a layer given the in->out specification + + Args: + layers (List[nn.Module]): The list where all modules are added + in_features (int): input dimensionality of the new layer + out_features (int): output dimensionality of the new layer + + """ + layers.append(nn.Linear(in_features, out_features)) + layers.append(_activations[self.config["activation"]]()) + if self.config['use_dropout']: + layers.append(nn.Dropout(self.config["dropout_%d" % layer_id])) + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + return { + 'shortname': 'MLPBackbone', + 'name': 'MLPBackbone', + 'handles_tabular': True, + 'handles_image': False, + 'handles_time_series': False, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + min_mlp_layers: int = 1, + max_mlp_layers: int = 15, + dropout: bool = True, + min_num_units: int = 10, + max_num_units: int = 1024, + ) -> ConfigurationSpace: + + cs = ConfigurationSpace() + + # The number of hidden layers the network will have. 
+ # Layer blocks are meant to have the same architecture, differing only + # by the number of units + num_groups = UniformIntegerHyperparameter( + "num_groups", min_mlp_layers, max_mlp_layers, default_value=5) + + activation = CategoricalHyperparameter( + "activation", choices=list(_activations.keys()) + ) + cs.add_hyperparameters([num_groups, activation]) + + # We can have dropout in the network for + # better generalization + if dropout: + use_dropout = CategoricalHyperparameter( + "use_dropout", choices=[True, False], default_value=False) + cs.add_hyperparameters([use_dropout]) + + for i in range(1, max_mlp_layers + 1): + n_units_hp = UniformIntegerHyperparameter("num_units_%d" % i, + lower=min_num_units, + upper=max_num_units, + default_value=200) + cs.add_hyperparameter(n_units_hp) + + if i > min_mlp_layers: + # The units of layer i should only exist + # if there are at least i layers + cs.add_condition( + CS.GreaterThanCondition( + n_units_hp, num_groups, i - 1 + ) + ) + + if dropout: + dropout_hp = UniformFloatHyperparameter( + "dropout_%d" % i, + lower=0.0, + upper=0.8, + default_value=0.5 + ) + cs.add_hyperparameter(dropout_hp) + dropout_condition_1 = CS.EqualsCondition(dropout_hp, use_dropout, True) + + if i > min_mlp_layers: + dropout_condition_2 = CS.GreaterThanCondition(dropout_hp, num_groups, i - 1) + cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2)) + else: + cs.add_condition(dropout_condition_1) + + return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py new file mode 100644 index 000000000..7c4a5ecc1 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -0,0 +1,263 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +import torch +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( + NetworkBackboneComponent, +) +from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( + _activations, + shake_drop, + shake_drop_get_bl, + shake_get_alpha_beta, + shake_shake +) + + +class ResNetBackbone(NetworkBackboneComponent): + """ + Implementation of a Residual Network backbone + + """ + supported_tasks = {"tabular_classification", "tabular_regression"} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> None: + layers = list() # type: List[nn.Module] + in_features = input_shape[0] + layers.append(nn.Linear(in_features, self.config["num_units_0"])) + + # build num_groups-1 groups each consisting of blocks_per_group ResBlocks + # the output features of each group is defined by num_units_i + for i in range(1, self.config['num_groups'] + 1): + layers.append( + self._add_group( + in_features=self.config["num_units_%d" % (i - 1)], + out_features=self.config["num_units_%d" % i], + blocks_per_group=self.config["blocks_per_group_%d" % i], + last_block_index=(i - 1) * self.config["blocks_per_group_%d" % i], + dropout=self.config['use_dropout'] + ) + ) + + layers.append(nn.BatchNorm1d(self.config["num_units_%i" % self.config['num_groups']])) + layers.append(_activations[self.config["activation"]]()) + backbone = nn.Sequential(*layers) + self.backbone = backbone + return backbone + + def 
_add_group(self, in_features: int, out_features: int,
+                   blocks_per_group: int, last_block_index: int, dropout: bool
+                   ) -> nn.Module:
+        """
+        Adds a group into the main backbone.
+        In the case of ResNet a group is a set of blocks_per_group
+        ResBlocks
+
+        Args:
+            in_features (int): number of inputs for the current block
+            out_features (int): output dimensionality for the current block
+            blocks_per_group (int): number of ResBlocks per group
+            last_block_index (int): block index for shake regularization
+            dropout (bool): whether or not to use dropout
+        """
+        blocks = list()
+        for i in range(blocks_per_group):
+            blocks.append(
+                ResBlock(
+                    config=self.config,
+                    in_features=in_features,
+                    out_features=out_features,
+                    blocks_per_group=blocks_per_group,
+                    block_index=last_block_index + i,
+                    dropout=dropout,
+                    activation=_activations[self.config["activation"]]
+                )
+            )
+            in_features = out_features
+        return nn.Sequential(*blocks)
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
+        return {
+            'shortname': 'ResNetBackbone',
+            'name': 'ResidualNetworkBackbone',
+            'handles_tabular': True,
+            'handles_image': False,
+            'handles_time_series': False,
+        }
+
+    @staticmethod
+    def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None,
+                                        min_num_gropus: int = 1,
+                                        max_num_groups: int = 9,
+                                        min_blocks_per_groups: int = 1,
+                                        max_blocks_per_groups: int = 4,
+                                        min_num_units: int = 10,
+                                        max_num_units: int = 1024,
+                                        ) -> ConfigurationSpace:
+        cs = ConfigurationSpace()
+
+        # The number of groups that will compose the resnet. That is,
+        # a group can have N Resblock. The M number of this N resblock
+        # repetitions is num_groups
+        num_groups = UniformIntegerHyperparameter(
+            "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5)
+
+        activation = CategoricalHyperparameter(
+            "activation", choices=list(_activations.keys())
+        )
+        cs.add_hyperparameters([num_groups, activation])
+
+        # We can have dropout in the network for
+        # better generalization
+        use_dropout = CategoricalHyperparameter(
+            "use_dropout", choices=[True, False])
+        cs.add_hyperparameters([use_dropout])
+
+        use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False])
+        use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False])
+        shake_drop_prob = UniformFloatHyperparameter(
+            "max_shake_drop_probability", lower=0.0, upper=1.0)
+        cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob])
+        cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True))
+
+        # It is the upper bound of the number of groups,
+        # since the configuration will actually be sampled.
+        for i in range(0, max_num_groups + 1):
+
+            n_units = UniformIntegerHyperparameter(
+                "num_units_%d" % i,
+                lower=min_num_units,
+                upper=max_num_units,
+            )
+            blocks_per_group = UniformIntegerHyperparameter(
+                "blocks_per_group_%d" % i, lower=min_blocks_per_groups,
+                upper=max_blocks_per_groups)
+
+            cs.add_hyperparameters([n_units, blocks_per_group])
+
+            if i > 1:
+                cs.add_condition(CS.GreaterThanCondition(n_units, num_groups, i - 1))
+                cs.add_condition(CS.GreaterThanCondition(blocks_per_group, num_groups, i - 1))
+
+            this_dropout = UniformFloatHyperparameter(
+                "dropout_%d" % i, lower=0.0, upper=1.0
+            )
+            cs.add_hyperparameters([this_dropout])
+
+            dropout_condition_1 = CS.EqualsCondition(this_dropout, use_dropout, True)
+
+            if i > 1:
+
+                dropout_condition_2 = CS.GreaterThanCondition(this_dropout, num_groups, i - 1)
+
+                cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2))
+            else:
+                cs.add_condition(dropout_condition_1)
+        return cs
+
+
+class ResBlock(nn.Module):
+    """
+    __author__ = "Max Dippel, Michael Burkart and Matthias Urban"
+    """
+
+    def __init__(
+        self,
+        config: Dict[str, Any],
+        in_features: int,
+        out_features: int,
+        blocks_per_group: int,
+        block_index: int,
+        dropout: bool,
+        activation: nn.Module
+    ):
+        super(ResBlock, self).__init__()
+        self.config = config
+        self.dropout = dropout
+        self.activation = activation
+
+        self.shortcut = None
+        self.start_norm = None  # type: Optional[Callable]
+
+        # if in != out the shortcut needs a linear layer to match the result dimensions
+        # if the shortcut needs a layer we apply batchnorm and activation to the shortcut
+        # as well (start_norm)
+        if in_features != out_features:
+            self.shortcut = nn.Linear(in_features, out_features)
+            self.start_norm = nn.Sequential(
+                nn.BatchNorm1d(in_features),
+                self.activation()
+            )
+
+        self.block_index = block_index
+        self.num_blocks = blocks_per_group * self.config["num_groups"]
+        self.layers = self._build_block(in_features, out_features)
+
+        if config["use_shake_shake"]:
+            self.shake_shake_layers = self._build_block(in_features, out_features)
+
+    # each block consists of two linear layers with batch norm and activation
+    def _build_block(self, in_features: int, out_features: int) -> nn.Module:
+        layers = list()
+
+        if self.start_norm is None:
+            layers.append(nn.BatchNorm1d(in_features))
+            layers.append(self.activation())
+        layers.append(nn.Linear(in_features, out_features))
+
+        layers.append(nn.BatchNorm1d(out_features))
+        layers.append(self.activation())
+
+        if self.config["use_dropout"]:
+            layers.append(nn.Dropout(self.dropout))
+        layers.append(nn.Linear(out_features, out_features))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        residual = x
+
+        # if shortcut is not none we need a layer such that x matches the output dimension
+        if self.shortcut is not None and self.start_norm is not None:
+            # in this case self.start_norm is also != none
+            # apply start_norm to x in order to have batchnorm+activation
+            # in front of shortcut and layers. Note that in this case layers
+            # does not start with batchnorm+activation but with the first linear layer
+            # (see _build_block).
As a result if in_features == out_features + # -> result = x + W(~D(A(BN(W(A(BN(x)))))) + # if in_features != out_features + # -> result = W_shortcut(A(BN(x))) + W_2(~D(A(BN(W_1(A(BN(x)))))) + x = self.start_norm(x) + residual = self.shortcut(x) + + if self.config["use_shake_shake"]: + x1 = self.layers(x) + x2 = self.shake_shake_layers(x) + alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda) + x = shake_shake(x1, x2, alpha, beta) + else: + x = self.layers(x) + + if self.config["use_shake_drop"]: + alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda) + bl = shake_drop_get_bl( + self.block_index, + 1 - self.config["max_shake_drop_probability"], + self.num_blocks, + self.training, + x.is_cuda + ) + x = shake_drop(x, alpha, beta, bl) + + x = x + residual + return x diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py new file mode 100644 index 000000000..41e69f37a --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( + NetworkBackboneComponent, +) +from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( + _activations, + get_shaped_neuron_counts, +) + + +class ShapedMLPBackbone(NetworkBackboneComponent): + """ + Implementation of a Shaped MLP -- an MLP with the number of units + arranged so that a given shape is honored + """ + supported_tasks = {"tabular_classification", "tabular_regression"} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + layers = list() # type: List[nn.Module] + in_features = input_shape[0] + out_features = self.config["output_dim"] + neuron_counts = get_shaped_neuron_counts(self.config['mlp_shape'], + in_features, + out_features, + self.config['max_units'], + self.config['num_groups']) + if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: + dropout_shape = get_shaped_neuron_counts( + self.config['mlp_shape'], 0, 0, 1000, self.config['num_groups'] + ) + + previous = in_features + for i in range(self.config['num_groups'] - 1): + if i >= len(neuron_counts): + break + if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: + dropout = dropout_shape[i] / 1000 * self.config["max_dropout"] + else: + dropout = 0.0 + self._add_layer(layers, previous, neuron_counts[i], dropout) + previous = neuron_counts[i] + layers.append(nn.Linear(previous, out_features)) + + backbone = nn.Sequential(*layers) + self.backbone = backbone + return backbone + + def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: + return (self.config["output_dim"],) + + def _add_layer(self, layers: List[nn.Module], + in_features: int, out_features: int, dropout: float + ) -> None: + layers.append(nn.Linear(in_features, out_features)) + layers.append(_activations[self.config["activation"]]()) + if self.config["use_dropout"] and self.config["max_dropout"] > 0.05: + layers.append(nn.Dropout(dropout)) + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + return { + 'shortname': 
'ShapedMLPBackbone', + 'name': 'ShapedMLPBackbone', + 'handles_tabular': True, + 'handles_image': False, + 'handles_time_series': False, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + min_num_gropus: int = 1, + max_num_groups: int = 15, + min_num_units: int = 10, + max_num_units: int = 1024, + ) -> ConfigurationSpace: + + cs = ConfigurationSpace() + + # The number of groups that will compose the resnet. That is, + # a group can have N Resblock. The M number of this N resblock + # repetitions is num_groups + num_groups = UniformIntegerHyperparameter( + "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) + + mlp_shape = CategoricalHyperparameter('mlp_shape', choices=[ + 'funnel', 'long_funnel', 'diamond', 'hexagon', 'brick', 'triangle', 'stairs' + ]) + + activation = CategoricalHyperparameter( + "activation", choices=list(_activations.keys()) + ) + + max_units = UniformIntegerHyperparameter( + "max_units", + lower=min_num_units, + upper=max_num_units, + default_value=200, + ) + + output_dim = UniformIntegerHyperparameter( + "output_dim", + lower=min_num_units, + upper=max_num_units + ) + + cs.add_hyperparameters([num_groups, activation, mlp_shape, max_units, output_dim]) + + # We can have dropout in the network for + # better generalization + use_dropout = CategoricalHyperparameter( + "use_dropout", choices=[True, False]) + max_dropout = UniformFloatHyperparameter("max_dropout", lower=0.0, upper=1.0) + cs.add_hyperparameters([use_dropout, max_dropout]) + cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True)) + + return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py new file mode 100644 index 000000000..10f34de8d --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +import torch + +from autoPyTorch.pipeline.components.setup.network_backbone.ResNetBackbone import ResNetBackbone +from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( + _activations, + get_shaped_neuron_counts, +) + + +class ShapedResNetBackbone(ResNetBackbone): + """ + Implementation of a Residual Network builder with support + for shaped number of units per group. 
+ + """ + + def build_backbone(self, input_shape: Tuple[int, ...]) -> None: + layers = list() # type: List[torch.nn.Module] + in_features = input_shape[0] + out_features = self.config["output_dim"] + + # use the get_shaped_neuron_counts to update the number of units + neuron_counts = get_shaped_neuron_counts(self.config['resnet_shape'], + in_features, + out_features, + self.config['max_units'], + self.config['num_groups'] + 2)[:-1] + self.config.update( + {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)} + ) + if self.config['use_dropout'] and self.config["max_dropout"] > 0.05: + dropout_shape = get_shaped_neuron_counts( + self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] + ) + + dropout_shape = [ + dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape + ] + + self.config.update( + {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)} + ) + layers.append(torch.nn.Linear(in_features, self.config["num_units_0"])) + + # build num_groups-1 groups each consisting of blocks_per_group ResBlocks + # the output features of each group is defined by num_units_i + for i in range(1, self.config['num_groups'] + 1): + layers.append( + self._add_group( + in_features=self.config["num_units_%d" % (i - 1)], + out_features=self.config["num_units_%d" % i], + blocks_per_group=self.config["blocks_per_group"], + last_block_index=(i - 1) * self.config["blocks_per_group"], + dropout=self.config['use_dropout'] + ) + ) + + layers.append(torch.nn.BatchNorm1d(self.config["num_units_%i" % self.config['num_groups']])) + backbone = torch.nn.Sequential(*layers) + self.backbone = backbone + return backbone + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + return { + 'shortname': 'ShapedResNetBackbone', + 'name': 'ShapedResidualNetworkBackbone', + 'handles_tabular': True, + 'handles_image': False, + 'handles_time_series': False, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + min_num_gropus: int = 1, + max_num_groups: int = 9, + min_blocks_per_groups: int = 1, + max_blocks_per_groups: int = 4, + min_num_units: int = 10, + max_num_units: int = 1024, + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + # Support for different shapes + resnet_shape = CategoricalHyperparameter( + 'resnet_shape', + choices=[ + 'funnel', + 'long_funnel', + 'diamond', + 'hexagon', + 'brick', + 'triangle', + 'stairs' + ] + ) + cs.add_hyperparameter(resnet_shape) + + # The number of groups that will compose the resnet. That is, + # a group can have N Resblock. 
The M number of this N resblock
+ # repetitions is num_groups
+ num_groups = UniformIntegerHyperparameter(
+ "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5)
+
+ blocks_per_group = UniformIntegerHyperparameter(
+ "blocks_per_group", lower=min_blocks_per_groups, upper=max_blocks_per_groups)
+
+ activation = CategoricalHyperparameter(
+ "activation", choices=list(_activations.keys())
+ )
+
+ output_dim = UniformIntegerHyperparameter(
+ "output_dim",
+ lower=min_num_units,
+ upper=max_num_units
+ )
+
+ cs.add_hyperparameters([num_groups, blocks_per_group, activation, output_dim])
+
+ # We can have dropout in the network for
+ # better generalization
+ use_dropout = CategoricalHyperparameter(
+ "use_dropout", choices=[True, False])
+ cs.add_hyperparameters([use_dropout])
+
+ use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False])
+ use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False])
+ shake_drop_prob = UniformFloatHyperparameter(
+ "max_shake_drop_probability", lower=0.0, upper=1.0)
+ cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob])
+ cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True))
+
+ max_units = UniformIntegerHyperparameter(
+ "max_units",
+ lower=min_num_units,
+ upper=max_num_units,
+ )
+ cs.add_hyperparameters([max_units])
+
+ max_dropout = UniformFloatHyperparameter(
+ "max_dropout", lower=0.0, upper=1.0
+ )
+ cs.add_hyperparameters([max_dropout])
+ cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
+
+ return cs
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/__init__.py b/autoPyTorch/pipeline/components/setup/network_backbone/__init__.py new file mode 100644 index 000000000..e69de29bb
diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py similarity index 71% rename from autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py rename to autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py index 62089d892..aad220db6 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/base_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py
@@ -4,13 +4,15 @@ import torch from torch import nn
+from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.base_component import BaseEstimator
+from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import ( autoPyTorchComponent, )
-class BaseBackbone(autoPyTorchComponent):
+class NetworkBackboneComponent(autoPyTorchComponent):
 """ Backbone base class """
@@ -26,8 +28,24 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ Not used. Just for API compatibility. """
+ input_shape = X['X_train'].shape[1:]
+
+ self.backbone = self.build_backbone(
+ input_shape=input_shape,
+ ) return self
+
+ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Adds the network backbone into the fit dictionary 'X' and returns it.
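+ The backbone itself is built in fit() from the shape of the training data; downstream components such as the network head read it from this dictionary.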
+ Args:
+ X (Dict[str, Any]): 'X' dictionary
+ Returns:
+ (Dict[str, Any]): the updated 'X' dictionary
+ """
+ X.update({'network_backbone': self.backbone})
+ return X
+
 @abstractmethod def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: """
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py new file mode 100644 index 000000000..278979d60 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py
@@ -0,0 +1,185 @@
+import os
+from collections import OrderedDict
+from typing import Dict, List, Optional
+
+import ConfigSpace.hyperparameters as CSH
+from ConfigSpace.configuration_space import ConfigurationSpace
+
+import numpy as np
+
+from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
+from autoPyTorch.pipeline.components.base_component import (
+ ThirdPartyComponents,
+ autoPyTorchComponent,
+ find_components,
+)
+from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import (
+ NetworkBackboneComponent,
+)
+
+
+directory = os.path.split(__file__)[0]
+_backbones = find_components(__package__,
+ directory,
+ NetworkBackboneComponent)
+_addons = ThirdPartyComponents(NetworkBackboneComponent)
+
+
+def add_backbone(backbone: NetworkBackboneComponent) -> None:
+ _addons.add_component(backbone)
+
+
+class NetworkBackboneChoice(autoPyTorchChoice):
+
+ def get_components(self) -> Dict[str, autoPyTorchComponent]:
+ """Returns the available backbone components
+
+ Args:
+ None
+
+ Returns:
+ Dict[str, autoPyTorchComponent]: all backbone components available
+ as choices
+ """
+ components = OrderedDict()
+ components.update(_backbones)
+ components.update(_addons.components)
+ return components
+
+ def get_available_components(
+ self,
+ dataset_properties: Optional[Dict[str, str]] = None,
+ include: List[str] = None,
+ exclude: List[str] = None,
+ ) -> Dict[str, autoPyTorchComponent]:
+ """Filters out components based on user provided
+ include/exclude directives, as well as the dataset properties
+
+ Args:
+ include (Optional[Dict[str, Any]]): what hyper-parameter configurations
+ to honor when creating the configuration space
+ exclude (Optional[Dict[str, Any]]): what hyper-parameter configurations
+ to remove from the configuration space
+ dataset_properties (Optional[Dict[str, Union[str, int]]]): Characteristics
+ of the dataset to guide the pipeline choices of components
+
+ Returns:
+ Dict[str, autoPyTorchComponent]: A filtered dict of backbone
+ components
+
+ """
+ if dataset_properties is None:
+ dataset_properties = {}
+
+ if include is not None and exclude is not None:
+ raise ValueError(
+ "The argument include and exclude cannot be used together.")
+
+ available_comp = self.get_components()
+
+ if include is not None:
+ for incl in include:
+ if incl not in available_comp:
+ raise ValueError("Trying to include unknown component: "
+ "%s" % incl)
+
+ components_dict = OrderedDict()
+ for name in available_comp:
+ if include is not None and name not in include:
+ continue
+ elif exclude is not None and name in exclude:
+ continue
+
+ entry = available_comp[name]
+
+ # Exclude itself to avoid infinite loop
+ if entry == NetworkBackboneChoice or hasattr(entry, 'get_components'):
+ continue
+
+ task_type = dataset_properties['task_type']
+ properties = entry.get_properties()
+ if 'tabular' in task_type and
not properties['handles_tabular']: + continue + elif 'image' in task_type and not properties['handles_image']: + continue + elif 'time_series' in task_type and not properties['handles_time_series']: + continue + + # target_type = dataset_properties['target_type'] + # Apply some automatic filtering here for + # backbones based on the dataset! + # TODO: Think if there is any case where a backbone + # is not recommended for a certain dataset + + components_dict[name] = entry + + return components_dict + + def get_hyperparameter_search_space( + self, + dataset_properties: Optional[Dict[str, str]] = None, + default: Optional[str] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> ConfigurationSpace: + """Returns the configuration space of the current chosen components + + Args: + dataset_properties (Optional[Dict[str, str]]): Describes the dataset to work on + default (Optional[str]): Default backbone to use + include: Optional[Dict[str, Any]]: what components to include. It is an exhaustive + list, and will exclusively use this components. + exclude: Optional[Dict[str, Any]]: which components to skip + + Returns: + ConfigurationSpace: the configuration space of the hyper-parameters of the + chosen component + """ + cs = ConfigurationSpace() + + if dataset_properties is None: + dataset_properties = {} + + # Compile a list of legal preprocessors for this problem + available_backbones = self.get_available_components( + dataset_properties=dataset_properties, + include=include, exclude=exclude) + + if len(available_backbones) == 0: + raise ValueError("No backbone found") + + if default is None: + defaults = [ + 'ShapedMLPBackbone', + 'MLPBackbone', + 'ConvNetImageBackbone', + 'InceptionTimeBackbone', + ] + for default_ in defaults: + if default_ in available_backbones: + default = default_ + break + + backbone = CSH.CategoricalHyperparameter( + '__choice__', + list(available_backbones.keys()), + default_value=default + ) + cs.add_hyperparameter(backbone) + for name in available_backbones: + backbone_configuration_space = available_backbones[name]. 
\ + get_hyperparameter_search_space(dataset_properties) + parent_hyperparameter = {'parent': backbone, 'value': name} + cs.add_configuration_space( + name, + backbone_configuration_space, + parent_hyperparameter=parent_hyperparameter + ) + + self.configuration_space_ = cs + self.dataset_properties_ = dataset_properties + return cs + + def transform(self, X: np.ndarray) -> np.ndarray: + assert self.choice is not None, "Cannot call transform before the object is initialized" + return self.choice.transform(X) diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/image.py b/autoPyTorch/pipeline/components/setup/network_backbone/image.py similarity index 96% rename from autoPyTorch/pipeline/components/setup/network/backbone/image.py rename to autoPyTorch/pipeline/components/setup/network_backbone/image.py index b980bc1bb..bdf6acb68 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/image.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/image.py @@ -15,7 +15,7 @@ from torch import nn from torch.nn import functional as F -from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent _activations: Dict[str, nn.Module] = { "relu": nn.ReLU, @@ -24,7 +24,7 @@ } -class ConvNetImageBackbone(BaseBackbone): +class ConvNetImageBackbone(NetworkBackboneComponent): supported_tasks = {"image_classification", "image_regression"} def __init__(self, **kwargs: Any): @@ -71,6 +71,9 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[ return { 'shortname': 'ConvNetImageBackbone', 'name': 'ConvNetImageBackbone', + 'handles_tabular': False, + 'handles_image': True, + 'handles_time_series': False, } @staticmethod @@ -173,7 +176,7 @@ def __init__(self, self.add_module('pool', nn.AvgPool2d(kernel_size=pool_size, stride=pool_size)) -class DenseNetBackbone(BaseBackbone): +class DenseNetBackbone(NetworkBackboneComponent): supported_tasks = {"image_classification", "image_regression"} def __init__(self, **kwargs: Any): @@ -238,6 +241,9 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[ return { 'shortname': 'DenseNetBackbone', 'name': 'DenseNetBackbone', + 'handles_tabular': False, + 'handles_image': True, + 'handles_time_series': False, } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network/backbone/time_series.py b/autoPyTorch/pipeline/components/setup/network_backbone/time_series.py similarity index 96% rename from autoPyTorch/pipeline/components/setup/network/backbone/time_series.py rename to autoPyTorch/pipeline/components/setup/network_backbone/time_series.py index 5ecb5f94c..6663a3565 100644 --- a/autoPyTorch/pipeline/components/setup/network/backbone/time_series.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/time_series.py @@ -12,7 +12,9 @@ from torch import nn from torch.nn.utils import weight_norm -from autoPyTorch.pipeline.components.setup.network.backbone.base_backbone import BaseBackbone +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( + NetworkBackboneComponent, +) # Code inspired by https://github.com/hfawaz/InceptionTime @@ -124,7 +126,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class InceptionTimeBackbone(BaseBackbone): +class InceptionTimeBackbone(NetworkBackboneComponent): supported_tasks = {"time_series_classification", 
"time_series_regression"} def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: @@ -138,6 +140,9 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[ return { 'shortname': 'InceptionTimeBackbone', 'name': 'InceptionTimeBackbone', + 'handles_tabular': False, + 'handles_image': False, + 'handles_time_series': True, } @staticmethod @@ -254,7 +259,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class TCNBackbone(BaseBackbone): +class TCNBackbone(NetworkBackboneComponent): supported_tasks = {"time_series_classification", "time_series_regression"} def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: @@ -274,6 +279,9 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[ return { "shortname": "TCNBackbone", "name": "TCNBackbone", + 'handles_tabular': False, + 'handles_image': False, + 'handles_time_series': True, } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network/utils.py b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py similarity index 92% rename from autoPyTorch/pipeline/components/setup/network/utils.py rename to autoPyTorch/pipeline/components/setup/network_backbone/utils.py index 8ae1f0f39..45a96a362 100644 --- a/autoPyTorch/pipeline/components/setup/network/utils.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py @@ -10,6 +10,28 @@ __license__ = "BSD" +_activations = { + "relu": torch.nn.ReLU, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid +} + + +def get_output_shape(network: torch.nn.Module, input_shape: typing.Tuple[int, ...] + ) -> typing.Tuple[int, ...]: + """ + Run a dummy forward pass to get the output shape of the backbone. + Can and should be overridden by subclasses that know the output shape + without running a dummy forward pass. + :param input_shape: shape of the input + :return: output_shape + """ + placeholder = torch.randn((2, *input_shape), dtype=torch.float) + with torch.no_grad(): + output = network(placeholder) + return tuple(output.shape[1:]) + + class ShakeShakeFunction(Function): @staticmethod def forward( diff --git a/autoPyTorch/pipeline/components/setup/network_head/__init__.py b/autoPyTorch/pipeline/components/setup/network_head/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/autoPyTorch/pipeline/components/setup/network/head/base_head.py b/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py similarity index 53% rename from autoPyTorch/pipeline/components/setup/network/head/base_head.py rename to autoPyTorch/pipeline/components/setup/network_head/base_network_head.py index c4d17fd5f..72a34fefe 100644 --- a/autoPyTorch/pipeline/components/setup/network/head/base_head.py +++ b/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py @@ -3,10 +3,12 @@ import torch.nn as nn +from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES from autoPyTorch.pipeline.components.base_component import BaseEstimator, autoPyTorchComponent +from autoPyTorch.pipeline.components.setup.network_backbone.utils import get_output_shape -class BaseHead(autoPyTorchComponent): +class NetworkHeadComponent(autoPyTorchComponent): """ Head base class """ @@ -22,8 +24,28 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ Not used. Just for API compatibility. 
""" + input_shape = X['X_train'].shape[1:] + output_shape = (X['dataset_properties']['num_classes'],) if \ + STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in \ + CLASSIFICATION_TASKS else X['dataset_properties']['output_shape'] + + self.head = self.build_head( + input_shape=get_output_shape(X['network_backbone'], input_shape=input_shape), + output_shape=output_shape, + ) return self + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: + """ + Adds the scheduler into the fit dictionary 'X' and returns it. + Args: + X (Dict[str, Any]): 'X' dictionary + Returns: + (Dict[str, Any]): the updated 'X' dictionary + """ + X.update({'network_head': self.head}) + return X + @abstractmethod def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: """ diff --git a/autoPyTorch/pipeline/components/setup/network/base_network_choice.py b/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py similarity index 60% rename from autoPyTorch/pipeline/components/setup/network/base_network_choice.py rename to autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py index 2f840a508..116a707f3 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network_choice.py +++ b/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py @@ -13,42 +13,48 @@ autoPyTorchComponent, find_components, ) -from autoPyTorch.pipeline.components.setup.network.base_network import BaseNetworkComponent +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import ( + NetworkHeadComponent, +) + directory = os.path.split(__file__)[0] -_networks = find_components(__package__, - directory, - BaseNetworkComponent) -_addons = ThirdPartyComponents(BaseNetworkComponent) +_heads = find_components(__package__, + directory, + NetworkHeadComponent) +_addons = ThirdPartyComponents(NetworkHeadComponent) -def add_network(network: BaseNetworkComponent) -> None: - _addons.add_component(network) +def add_head(head: NetworkHeadComponent) -> None: + _addons.add_component(head) -class NetworkChoice(autoPyTorchChoice): +class NetworkHeadChoice(autoPyTorchChoice): def get_components(self) -> Dict[str, autoPyTorchComponent]: - """Returns the available network components + """Returns the available head components + Args: None + Returns: - Dict[str, autoPyTorchComponent]: all baseNetwork components available - as choices + Dict[str, autoPyTorchComponent]: all basehead components available + as choices for learning rate scheduling """ components = OrderedDict() - components.update(_networks) + components.update(_heads) components.update(_addons.components) return components def get_available_components( - self, - dataset_properties: Optional[Dict[str, str]] = None, - include: List[str] = None, - exclude: List[str] = None, + self, + dataset_properties: Optional[Dict[str, str]] = None, + include: List[str] = None, + exclude: List[str] = None, ) -> Dict[str, autoPyTorchComponent]: """Filters out components based on user provided include/exclude directives, as well as the dataset properties + Args: include (Optional[Dict[str, Any]]): what hyper-parameter configurations to honor when creating the configuration space @@ -56,9 +62,11 @@ def get_available_components( to remove from the configuration space dataset_properties (Optional[Dict[str, Union[str, int]]]): Caracteristics of the dataset to guide the pipeline choices of components + Returns: - Dict[str, autoPyTorchComponent]: A filtered dict of Network - components + 
+ Dict[str, autoPyTorchComponent]: A filtered dict of network
+ head components
+
 """
 if dataset_properties is None:
 dataset_properties = {}
@@ -72,11 +80,8 @@ def get_available_components(
 if include is not None:
 for incl in include:
 if incl not in available_comp:
- raise ValueError(
- "Trying to include unknown component: {} from {}".format(
- incl,
- available_comp,
- ))
+ raise ValueError("Trying to include unknown component: "
+ "%s" % incl)
 components_dict = OrderedDict()
 for name in available_comp:
@@ -88,30 +93,44 @@ def get_available_components(
 entry = available_comp[name]
 # Exclude itself to avoid infinite loop
- if entry == NetworkChoice or hasattr(entry, 'get_components'):
+ if entry == NetworkHeadChoice or hasattr(entry, 'get_components'):
+ continue
+
+ task_type = dataset_properties['task_type']
+ properties = entry.get_properties()
+ if 'tabular' in task_type and not properties['handles_tabular']:
+ continue
+ elif 'image' in task_type and not properties['handles_image']:
+ continue
+ elif 'time_series' in task_type and not properties['handles_time_series']:
 continue
 # target_type = dataset_properties['target_type']
- # Apply some automatic filtering here based on dataset
+ # Apply some automatic filtering here for
+ # heads based on the dataset!
+ # TODO: Think if there is any case where a head
+ # is not recommended for a certain dataset
 components_dict[name] = entry
 return components_dict
 def get_hyperparameter_search_space(
- self,
- dataset_properties: Optional[Dict[str, str]] = None,
- default: Optional[str] = None,
- include: Optional[List[str]] = None,
- exclude: Optional[List[str]] = None,
+ self,
+ dataset_properties: Optional[Dict[str, str]] = None,
+ default: Optional[str] = None,
+ include: Optional[List[str]] = None,
+ exclude: Optional[List[str]] = None,
 ) -> ConfigurationSpace:
 """Returns the configuration space of the current chosen components
+
 Args:
 dataset_properties (Optional[Dict[str, str]]): Describes the dataset to work on
- default (Optional[str]): Default component to use
+ default (Optional[str]): Default head to use
 include: Optional[Dict[str, Any]]: what components to include. It is an exhaustive
 list, and will exclusively use this components.
 exclude: Optional[Dict[str, Any]]: which components to skip
+
 Returns:
 ConfigurationSpace: the configuration space of the hyper-parameters of the
 chosen component
@@ -122,33 +141,36 @@ def get_hyperparameter_search_space(
 dataset_properties = {}
 # Compile a list of legal preprocessors for this problem
- available_networks = self.get_available_components(
+ available_heads = self.get_available_components(
 dataset_properties=dataset_properties,
 include=include, exclude=exclude)
- if len(available_networks) == 0:
- raise ValueError("No Network found")
+ if len(available_heads) == 0:
+ raise ValueError("No head found")
 if default is None:
- defaults = ['BackboneHeadNet']
+ defaults = [
+ 'FullyConnectedHead',
+ 'FullyConvolutional2DHead',
+ ]
 for default_ in defaults:
- if default_ in available_networks:
+ if default_ in available_heads:
 default = default_
 break
- network = CSH.CategoricalHyperparameter(
+ head = CSH.CategoricalHyperparameter(
 '__choice__',
- list(available_networks.keys()),
+ list(available_heads.keys()),
 default_value=default
 )
- cs.add_hyperparameter(network)
- for name in available_networks:
- network_configuration_space = available_networks[name]. \
+ cs.add_hyperparameter(head)
+ for name in available_heads:
+ head_configuration_space = available_heads[name].
\ get_hyperparameter_search_space(dataset_properties) - parent_hyperparameter = {'parent': network, 'value': name} + parent_hyperparameter = {'parent': head, 'value': name} cs.add_configuration_space( name, - network_configuration_space, + head_configuration_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/setup/network/head/fully_connected.py b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py similarity index 91% rename from autoPyTorch/pipeline/components/setup/network/head/fully_connected.py rename to autoPyTorch/pipeline/components/setup/network_head/fully_connected.py index b17f390a0..97bb86789 100644 --- a/autoPyTorch/pipeline/components/setup/network/head/fully_connected.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py @@ -8,7 +8,9 @@ from torch import nn -from autoPyTorch.pipeline.components.setup.network.head.base_head import BaseHead +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import ( + NetworkHeadComponent, +) _activations: Dict[str, nn.Module] = { "relu": nn.ReLU, @@ -17,7 +19,7 @@ } -class FullyConnectedHead(BaseHead): +class FullyConnectedHead(NetworkHeadComponent): """ Standard head consisting of a number of fully connected layers. Flattens any input in a array of shape [B, prod(input_shape)]. @@ -44,6 +46,9 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ return { 'shortname': 'FullyConnectedHead', 'name': 'FullyConnectedHead', + 'handles_tabular': True, + 'handles_image': False, + 'handles_time_series': False, } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network/head/fully_convolutional.py b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py similarity index 94% rename from autoPyTorch/pipeline/components/setup/network/head/fully_convolutional.py rename to autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py index 54ea887ec..7c0f85a2c 100644 --- a/autoPyTorch/pipeline/components/setup/network/head/fully_convolutional.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py @@ -7,7 +7,9 @@ import torch from torch import nn -from autoPyTorch.pipeline.components.setup.network.head.base_head import BaseHead +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import ( + NetworkHeadComponent, +) _activations: Dict[str, nn.Module] = { "relu": nn.ReLU, @@ -49,7 +51,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.head(x).view(B, -1) -class FullyConvolutional2DHead(BaseHead): +class FullyConvolutional2DHead(NetworkHeadComponent): """ Head consisting of a number of 2d convolutional connected layers. Applies a global pooling operation in the end. 
@@ -70,6 +72,9 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ return { 'shortname': 'FullyConvolutionalHead', 'name': 'FullyConvolutionalHead', + 'handles_tabular': False, + 'handles_image': True, + 'handles_time_series': False, } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py index cf89ab067..306c798da 100644 --- a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py +++ b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py @@ -91,9 +91,4 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('strategy', None) - info.pop('random_state', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer.py index a831e8db8..527238cb0 100644 --- a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer.py @@ -42,8 +42,5 @@ def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.optimizer.__class__.__name__ info = vars(self) - # Remove unwanted info - info.pop('optimizer', None) - info.pop('random_state', None) string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py index f9c11a32b..620c49c45 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py @@ -141,10 +141,4 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None: def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.model.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('model', None) - info.pop('random_state', None) - info.pop('fit_output', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py b/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py index 6f46e754c..03343d9f3 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py @@ -63,9 +63,4 @@ def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ... 
def __str__(self) -> str: """ Allow a nice understanding of what components where used """ - info = vars(self) - # Remove unwanted info - info.pop('random_state', None) - info.pop('fit_output', None) - info.pop('config', None) - return f"TabularClassifier: {self.model.name if self.model is not None else None} ({str(info)})" + return f"TabularClassifier: {self.model.name if self.model is not None else None}" diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py index 0049d8e38..bfcffe593 100644 --- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py +++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py @@ -259,12 +259,4 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None def __str__(self) -> str: """ Allow a nice understanding of what components where used """ string = self.train_data_loader.__class__.__name__ - info = vars(self) - # Remove unwanted info - info.pop('train_data_loader', None) - info.pop('val_data_loader', None) - info.pop('test_data_loader', None) - info.pop('vision_datasets', None) - info.pop('random_state', None) - string += " (" + str(info) + ")" return string diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py index e65086cb6..971eed3f5 100755 --- a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py +++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py @@ -67,7 +67,6 @@ def __init__(self, self._fit_requirements: Optional[List[FitRequirement]] = [ FitRequirement("lr_scheduler", (_LRScheduler,), user_defined=False, dataset_property=False), FitRequirement("job_id", (str,), user_defined=False, dataset_property=False), - FitRequirement("network", (torch.nn.Sequential,), user_defined=False, dataset_property=False), FitRequirement( "optimizer", (Optimizer,), user_defined=False, dataset_property=False), FitRequirement("train_data_loader", diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py index fa8947f62..4b0a743be 100644 --- a/autoPyTorch/pipeline/tabular_classification.py +++ b/autoPyTorch/pipeline/tabular_classification.py @@ -20,7 +20,9 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler_choice import ScalerChoice from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import SchedulerChoice -from autoPyTorch.pipeline.components.setup.network.base_network_choice import NetworkChoice +from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone_choice import NetworkBackboneChoice +from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( NetworkInitializerChoice ) @@ -237,7 +239,9 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]], ("scaler", ScalerChoice(default_dataset_properties)), ("tabular_transformer", TabularColumnTransformer()), ("preprocessing", EarlyPreprocessing()), - ("network", NetworkChoice(default_dataset_properties)), + 
("network_backbone", NetworkBackboneChoice(default_dataset_properties)), + ("network_head", NetworkHeadChoice(default_dataset_properties)), + ("network", NetworkComponent(default_dataset_properties)), ("network_init", NetworkInitializerChoice(default_dataset_properties)), ("optimizer", OptimizerChoice(default_dataset_properties)), ("lr_scheduler", SchedulerChoice(default_dataset_properties)), diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py index 3909e3484..2e3efad85 100644 --- a/autoPyTorch/pipeline/tabular_regression.py +++ b/autoPyTorch/pipeline/tabular_regression.py @@ -20,7 +20,9 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler_choice import ScalerChoice from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import SchedulerChoice -from autoPyTorch.pipeline.components.setup.network.base_network_choice import NetworkChoice +from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone_choice import NetworkBackboneChoice +from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( NetworkInitializerChoice ) @@ -189,7 +191,9 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]], ("scaler", ScalerChoice(default_dataset_properties)), ("tabular_transformer", TabularColumnTransformer()), ("preprocessing", EarlyPreprocessing()), - ("network", NetworkChoice(default_dataset_properties)), + ("network_backbone", NetworkBackboneChoice(default_dataset_properties)), + ("network_head", NetworkHeadChoice(default_dataset_properties)), + ("network", NetworkComponent(default_dataset_properties)), ("network_init", NetworkInitializerChoice(default_dataset_properties)), ("optimizer", OptimizerChoice(default_dataset_properties)), ("lr_scheduler", SchedulerChoice(default_dataset_properties)), diff --git a/test/test_pipeline/components/test_setup.py b/test/test_pipeline/components/test_setup.py index 706d0718c..9ab961ddc 100644 --- a/test/test_pipeline/components/test_setup.py +++ b/test/test_pipeline/components/test_setup.py @@ -3,14 +3,9 @@ from ConfigSpace.configuration_space import ConfigurationSpace -import numpy as np - from sklearn.base import clone -import torch - import autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice as lr_components -import autoPyTorch.pipeline.components.setup.network.base_network_choice as network_components import \ autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice as network_initializer_components # noqa: E501 import autoPyTorch.pipeline.components.setup.optimizer.base_optimizer_choice as optimizer_components @@ -18,9 +13,8 @@ BaseLRComponent, SchedulerChoice ) -from autoPyTorch.pipeline.components.setup.network.base_network_choice import ( - BaseNetworkComponent, - NetworkChoice +from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import ( + NetworkHeadChoice, ) from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( BaseNetworkInitializerComponent, @@ -64,22 +58,6 @@ def get_properties(dataset_properties=None): } -class DummyNet(BaseNetworkComponent): - def 
__init__(self, random_state=None): - pass - - @staticmethod - def get_hyperparameter_search_space(dataset_properties=None): - cs = ConfigurationSpace() - return cs - - def get_properties(dataset_properties=None): - return { - 'shortname': 'Dummy', - 'name': 'Dummy', - } - - class DummyNetworkInitializer(BaseNetworkInitializerComponent): def __init__(self, random_state=None): pass @@ -268,8 +246,8 @@ def test_optimizer_add(self): self.assertIn('DummyOptimizer', str(cs)) -class NetworkTest(unittest.TestCase): - def test_every_network_is_valid(self): +class NetworkHeadTest(unittest.TestCase): + def test_every_networkHead_is_valid(self): """ Makes sure that every network is a valid estimator. That is, we can fully create an object via get/set params. @@ -277,18 +255,18 @@ def test_every_network_is_valid(self): This also test that we can properly initialize each one of them """ - network_choice = NetworkChoice(dataset_properties={}) + networkHead_choice = NetworkHeadChoice(dataset_properties={'task_type': 'tabular_classification'}) # Make sure all components are returned - self.assertEqual(len(network_choice.get_components().keys()), 1) + self.assertEqual(len(networkHead_choice.get_components().keys()), 2) # For every network in the components, make sure # that it complies with the scikit learn estimator. # This is important because usually components are forked to workers, # so the set/get params methods should recreate the same object - for name, network in network_choice.get_components().items(): - config = network.get_hyperparameter_search_space().sample_configuration() - estimator = network(**config) + for name, networkHead in networkHead_choice.get_components().items(): + config = networkHead.get_hyperparameter_search_space().sample_configuration() + estimator = networkHead(**config) estimator_clone = clone(estimator) estimator_clone_params = estimator_clone.get_params() @@ -309,45 +287,17 @@ def test_every_network_is_valid(self): param2 = params_set[name] self.assertEqual(param1, param2) - def test_backbone_head_net(self): - network_choice = NetworkChoice(dataset_properties={}) - task_types = {"image_classification": ((1, 3, 64, 64), (5,)), - "image_regression": ((1, 3, 64, 64), (1,)), - "time_series_classification": ((1, 32, 6), (5,)), - "time_series_regression": ((1, 32, 6), (1,)), - "tabular_classification": ((1, 100,), (5,)), - "tabular_regression": ((1, 100), (1,))} - - device = torch.device("cpu") - for task_type, (input_shape, output_shape) in task_types.items(): - cs = network_choice.get_hyperparameter_search_space(dataset_properties={"task_type": task_type}, - include=["BackboneHeadNet"]) - # test 10 random configurations - for i in range(10): - config = cs.sample_configuration() - network_choice.set_hyperparameters(config) - network_choice.fit(X={"X_train": np.zeros(input_shape), - "y_train": np.zeros(output_shape), - 'dataset_properties': {"task_type": task_type, - 'input_shape': input_shape[1:], - "output_shape": output_shape, - "num_classes": output_shape[0]}}, y=None) - self.assertNotEqual(network_choice.choice.network, None) - network_choice.choice.to(device) - dummy_input = torch.randn((2, *input_shape[1:]), dtype=torch.float) - output = network_choice.choice.network(dummy_input) - self.assertEqual(output.shape[1:], output_shape) - def test_get_set_config_space(self): - """Make sure that we can setup a valid choice in the network + """Make sure that we can setup a valid choice in the networkHead choice""" - network_choice = NetworkChoice(dataset_properties={}) - cs = 
network_choice.get_hyperparameter_search_space() + networkHead_choice = NetworkHeadChoice(dataset_properties={'task_type': 'tabular_classification'}) + cs = networkHead_choice.get_hyperparameter_search_space( + dataset_properties={"task_type": 'tabular_classification'}) # Make sure that all hyperparameters are part of the search space self.assertListEqual( sorted(cs.get_hyperparameter('__choice__').choices), - sorted(list(network_choice.get_components().keys())) + ['fully_connected'] ) # Make sure we can properly set some random configs @@ -357,10 +307,10 @@ def test_get_set_config_space(self): for i in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) - network_choice.set_hyperparameters(config) + networkHead_choice.set_hyperparameters(config) - self.assertEqual(network_choice.choice.__class__, - network_choice.get_components()[config_dict['__choice__']]) + self.assertEqual(networkHead_choice.choice.__class__, + networkHead_choice.get_components()[config_dict['__choice__']]) # Then check the choice configuration selected_choice = config_dict.pop('__choice__', None) @@ -371,22 +321,11 @@ def test_get_set_config_space(self): key = key.replace(selected_choice + ':', '') # In the case of MLP, parameters are dynamic, so they exist in config - parameters = vars(network_choice.choice) - parameters.update(vars(network_choice.choice)['config']) + parameters = vars(networkHead_choice.choice) + parameters.update(vars(networkHead_choice.choice)['config']) self.assertIn(key, parameters) self.assertEqual(value, parameters[key]) - def test_network_add(self): - """Makes sure that a component can be added to the CS""" - # No third party components to start with - self.assertEqual(len(network_components._addons.components), 0) - - # Then make sure the scheduler can be added and query'ed - network_components.add_network(DummyNet) - self.assertEqual(len(network_components._addons.components), 1) - cs = NetworkChoice(dataset_properties={}).get_hyperparameter_search_space() - self.assertIn('DummyNet', str(cs)) - class NetworkInitializerTest(unittest.TestCase): def test_every_network_initializer_is_valid(self): diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index d01fcaa08..6d19833ad 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -35,7 +35,7 @@ def test_pipeline_fit(self, fit_dictionary): assert 'accuracy' in run_summary.performance_tracker['train_metrics'][1] # Make sure a network was fit - assert isinstance(pipeline.named_steps['network'].choice.get_network(), torch.nn.Module) + assert isinstance(pipeline.named_steps['network'].get_network(), torch.nn.Module) def test_pipeline_predict(self, fit_dictionary): """This test makes sure that the pipeline is able to fit @@ -144,6 +144,8 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary): # Make sure that fitting a network adds a "network" to X assert 'network' in pipeline.named_steps.keys() + fit_dictionary['network_backbone'] = torch.nn.Linear(3, 4) + fit_dictionary['network_head'] = torch.nn.Linear(4, 1) X = pipeline.named_steps['network'].fit( fit_dictionary, None @@ -175,7 +177,8 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary): assert 'optimizer' in X def test_get_fit_requirements(self, fit_dictionary): - dataset_properties = {'numerical_columns': [], 'categorical_columns': []} + dataset_properties = {'numerical_columns': [], 
'categorical_columns': [], + 'task_type': 'tabular_classification'} pipeline = TabularClassificationPipeline(dataset_properties=dataset_properties) fit_requirements = pipeline.get_fit_requirements() From d2a5878d2e03c2fcb5d3c70f30009fdb0b1c8a3c Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Tue, 26 Jan 2021 15:50:26 +0100 Subject: [PATCH 2/5] Fix test flake 8 --- .../setup/network_backbone/MLPBackbone.py | 4 +- .../setup/network_backbone/ResNetBackbone.py | 4 +- .../network_backbone/ShapedMLPBackbone.py | 4 +- .../network_backbone/ShapedResNetBackbone.py | 4 +- .../network_backbone/base_network_backbone.py | 2 - .../setup/network_head/fully_connected.py | 4 +- .../setup/network_head/fully_convolutional.py | 4 +- .../training/data_loader/base_data_loader.py | 2 +- .../components/test_setup_networks.py | 65 +++++++++++++++++++ 9 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 test/test_pipeline/components/test_setup_networks.py diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py index 2c30f7992..d261841eb 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -63,7 +63,7 @@ def _add_layer(self, layers: List[nn.Module], in_features: int, out_features: in layers.append(nn.Dropout(self.config["dropout_%d" % layer_id])) @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'MLPBackbone', 'name': 'MLPBackbone', diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py index 7c4a5ecc1..9716f8aab 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -86,7 +86,7 @@ def _add_group(self, in_features: int, out_features: int, return nn.Sequential(*blocks) @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'ResNetBackbone', 'name': 'ResidualNetworkBackbone', diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py index 41e69f37a..c7cef2fd6 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -68,7 +68,7 @@ def 
_add_layer(self, layers: List[nn.Module], layers.append(nn.Dropout(dropout)) @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'ShapedMLPBackbone', 'name': 'ShapedMLPBackbone', diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index 10f34de8d..16435d0c9 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -71,7 +71,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: return backbone @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'ShapedResNetBackbone', 'name': 'ShapedResidualNetworkBackbone', diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py index aad220db6..639975c1d 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py @@ -4,9 +4,7 @@ import torch from torch import nn -from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES from autoPyTorch.pipeline.components.base_component import BaseEstimator -from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice from autoPyTorch.pipeline.components.base_component import ( autoPyTorchComponent, ) diff --git a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py index 97bb86789..bada209ec 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -42,7 +42,7 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...] 
return nn.Sequential(*layers) @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'FullyConnectedHead', 'name': 'FullyConnectedHead', diff --git a/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py index 7c0f85a2c..ed83fc32e 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -68,7 +68,7 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...] for i in range(1, self.config["num_layers"])]) @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { 'shortname': 'FullyConvolutionalHead', 'name': 'FullyConvolutionalHead', diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py index bfcffe593..4dd509e17 100644 --- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py +++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py @@ -251,7 +251,7 @@ def get_torchvision_datasets(self) -> Dict[str, torchvision.datasets.VisionDatas def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None ) -> ConfigurationSpace: batch_size = UniformIntegerHyperparameter( - "batch_size", 32, 320, default_value=64) + "batch_size", 16, 512, default_value=64) cs = ConfigurationSpace() cs.add_hyperparameters([batch_size]) return cs diff --git a/test/test_pipeline/components/test_setup_networks.py b/test/test_pipeline/components/test_setup_networks.py new file mode 100644 index 000000000..2d9551e48 --- /dev/null +++ b/test/test_pipeline/components/test_setup_networks.py @@ -0,0 +1,65 @@ +import os +import sys + +import pytest + +import torch + +from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline + + +# Disable +def blockPrint(): + sys.stdout = open(os.devnull, 'w') + + +# Restore +def enablePrint(): + sys.stdout = sys.__stdout__ + + +@pytest.fixture(params=['MLPBackbone', 'ResNetBackbone', 'ShapedMLPBackbone', 'ShapedResNetBackbone']) +def backbone(request): + return request.param + + +@pytest.fixture(params=['fully_connected']) +def head(request): + return request.param + + +@pytest.mark.parametrize("fit_dictionary", ['fit_dictionary_numerical_only', + 'fit_dictionary_categorical_only', + 'fit_dictionary_num_and_categorical'], indirect=True) +class TestNetworks: + def test_pipeline_fit(self, fit_dictionary, backbone, head): + """This test makes sure that the pipeline is able to fit + given random combinations of hyperparameters across the pipeline""" + + pipeline = TabularClassificationPipeline( + dataset_properties=fit_dictionary['dataset_properties'], + include={'network_backbone': [backbone], 'network_head': [head]}) + cs = pipeline.get_hyperparameter_search_space() + config = cs.get_default_configuration() + + assert backbone == 
config.get('network_backbone:__choice__', None) + assert head == config.get('network_head:__choice__', None) + pipeline.set_hyperparameters(config) + pipeline.fit(fit_dictionary) + + # To make sure we fitted the model, there should be a + # run summary object with accuracy + run_summary = pipeline.named_steps['trainer'].run_summary + assert run_summary is not None + + # Make sure that performance was properly captured + assert run_summary.performance_tracker['train_loss'][1] > 0 + assert run_summary.total_parameter_count > 0 + assert 'accuracy' in run_summary.performance_tracker['train_metrics'][1] + + # Commented out the next line as some pipelines are not + # achieving this accuracy with default configuration and 10 epochs + # To be added once we fix the search space + # assert run_summary.performance_tracker['val_metrics'][fit_dictionary['epochs']]['accuracy'] >= 0.8 + # Make sure a network was fit + assert isinstance(pipeline.named_steps['network'].get_network(), torch.nn.Module) From 5914b8f0b84af72fe0ab2a1cb17ebc977f4e842e Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Wed, 27 Jan 2021 14:40:01 +0100 Subject: [PATCH 3/5] fix test api --- test/test_api/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 60973c722..56ef760a0 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -107,7 +107,7 @@ def test_classification(openml_id, resampling_strategy, backend): assert os.path.exists(model_file), model_file model = estimator._backend.load_model_by_seed_and_id_and_budget( estimator.seed, run_key.config_id, run_key.budget) - assert isinstance(model.named_steps['network'].choice.get_network(), torch.nn.Module) + assert isinstance(model.named_steps['network'].get_network(), torch.nn.Module) elif resampling_strategy == CrossValTypes.k_fold_cross_validation: model_file = os.path.join( run_key_model_run_dir, @@ -118,7 +118,7 @@ def test_classification(openml_id, resampling_strategy, backend): estimator.seed, run_key.config_id, run_key.budget) assert isinstance(model, VotingClassifier) assert len(model.estimators_) == 3 - assert isinstance(model.estimators_[0].named_steps['network'].choice.get_network(), + assert isinstance(model.estimators_[0].named_steps['network'].get_network(), torch.nn.Module) else: pytest.fail(resampling_strategy) From 3d1c9c49976ba910370c0c46cfec9dc9b2ea624f Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Wed, 27 Jan 2021 19:00:48 +0100 Subject: [PATCH 4/5] increased time for func eval in cros validation --- test/test_api/test_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 56ef760a0..306209bb8 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -27,8 +27,8 @@ # Test # ======== @pytest.mark.parametrize('openml_id', (40981, )) -@pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, - CrossValTypes.k_fold_cross_validation, )) +@pytest.mark.parametrize('resampling_strategy', (CrossValTypes.k_fold_cross_validation, + )) def test_classification(openml_id, resampling_strategy, backend): # Get the data and check that contents of data-manager make sense @@ -54,7 +54,7 @@ def test_classification(openml_id, resampling_strategy, backend): dataset=datamanager, optimize_metric='accuracy', total_walltime_limit=150, - func_eval_time_limit=30, + func_eval_time_limit=50, traditional_per_total_budget=0 ) From 
1201a7c9984ca794ae1c1d11c351418e8f7448fd Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Wed, 27 Jan 2021 23:21:20 +0100 Subject: [PATCH 5/5] Addressed comments --- test/test_api/test_api.py | 3 ++- .../test_pipeline/components/test_setup_networks.py | 13 ------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 306209bb8..ce9a88e2e 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -27,7 +27,8 @@ # Test # ======== @pytest.mark.parametrize('openml_id', (40981, )) -@pytest.mark.parametrize('resampling_strategy', (CrossValTypes.k_fold_cross_validation, +@pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, + CrossValTypes.k_fold_cross_validation, )) def test_classification(openml_id, resampling_strategy, backend): diff --git a/test/test_pipeline/components/test_setup_networks.py b/test/test_pipeline/components/test_setup_networks.py index 2d9551e48..a445d0771 100644 --- a/test/test_pipeline/components/test_setup_networks.py +++ b/test/test_pipeline/components/test_setup_networks.py @@ -1,6 +1,3 @@ -import os -import sys - import pytest import torch @@ -8,16 +5,6 @@ from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline -# Disable -def blockPrint(): - sys.stdout = open(os.devnull, 'w') - - -# Restore -def enablePrint(): - sys.stdout = sys.__stdout__ - - @pytest.fixture(params=['MLPBackbone', 'ResNetBackbone', 'ShapedMLPBackbone', 'ShapedResNetBackbone']) def backbone(request): return request.param