From 0900210a9b419d4b6bf9ba131bf3e1a0a0b5253f Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 3 Jun 2021 20:43:49 +0200 Subject: [PATCH 1/5] fix dropout bug --- .../components/setup/network_backbone/ResNetBackbone.py | 6 +++--- .../setup/network_backbone/ShapedResNetBackbone.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py index 6391baa6a..b259786cc 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -41,7 +41,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: out_features=self.config["num_units_%d" % i], blocks_per_group=self.config["blocks_per_group_%d" % i], last_block_index=(i - 1) * self.config["blocks_per_group_%d" % i], - dropout=self.config['use_dropout'] + dropout=self.config[f'dropout_{i}'] if self.config['use_dropout'] else None, ) ) @@ -52,7 +52,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: return backbone def _add_group(self, in_features: int, out_features: int, - blocks_per_group: int, last_block_index: int, dropout: bool + blocks_per_group: int, last_block_index: int, dropout: Optional[float] ) -> nn.Module: """ Adds a group into the main backbone. 
@@ -206,7 +206,7 @@ def __init__( out_features: int, blocks_per_group: int, block_index: int, - dropout: bool, + dropout: Optional[float], activation: nn.Module ): super(ResBlock, self).__init__() diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index c8ffa1b4e..d91c1bf6d 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -38,7 +38,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: self.config.update( {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)} ) - if self.config['use_dropout'] and self.config["max_dropout"] > 0.05: + if self.config['use_dropout']: dropout_shape = get_shaped_neuron_counts( self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] ) @@ -61,7 +61,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: out_features=self.config["num_units_%d" % i], blocks_per_group=self.config["blocks_per_group"], last_block_index=(i - 1) * self.config["blocks_per_group"], - dropout=self.config['use_dropout'] + dropout=self.config[f'dropout_{i}'] if self.config['use_dropout'] else None ) ) From 2e0d35c49486cd747e811033c3c35236163b9ff9 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 3 Jun 2021 20:45:02 +0200 Subject: [PATCH 2/5] fix dropout shape discrepancy --- .../components/setup/network_backbone/ShapedResNetBackbone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index d91c1bf6d..69447fbfe 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -41,7 
+41,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: if self.config['use_dropout']: dropout_shape = get_shaped_neuron_counts( self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] - ) + )[:-1] dropout_shape = [ dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape From 2d0bc1e215e4b42087244e588e887727eb22c407 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Wed, 9 Jun 2021 11:53:22 +0200 Subject: [PATCH 3/5] Fix unit test bug --- .../setup/network_backbone/ShapedResNetBackbone.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index 69447fbfe..1abd84df9 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -38,9 +38,13 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: self.config.update( {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)} ) + # we are skipping the last layer, as the function get_shaped_neuron_counts + # is built for getting neuron counts, so it will add the out_features to + # the last layer. However, in dropout we don't want to have that, we just + # want to use the shape and not worry about the output. 
if self.config['use_dropout']: dropout_shape = get_shaped_neuron_counts( - self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] + self.config['resnet_shape'], 0, 0, 1000, self.config['num_groups'] + 1 )[:-1] dropout_shape = [ From 3bd25e7287ff643d236cbacb6d5c2804bdac0c4b Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 10 Jun 2021 16:10:39 +0200 Subject: [PATCH 4/5] Add tests for dropout shape as per comments from Francisco --- .../components/setup/test_setup.py | 60 ++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/test/test_pipeline/components/setup/test_setup.py b/test/test_pipeline/components/setup/test_setup.py index 9d9b6f7ad..5fc20d56c 100644 --- a/test/test_pipeline/components/setup/test_setup.py +++ b/test/test_pipeline/components/setup/test_setup.py @@ -23,6 +23,9 @@ ) from autoPyTorch.pipeline.components.setup.network_backbone import NetworkBackboneChoice from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent +from autoPyTorch.pipeline.components.setup.network_backbone.ResNetBackbone import ResBlock +from autoPyTorch.pipeline.components.setup.network_backbone.ShapedResNetBackbone import ShapedResNetBackbone +from autoPyTorch.pipeline.components.setup.network_backbone.utils import get_shaped_neuron_counts from autoPyTorch.pipeline.components.setup.network_head import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent from autoPyTorch.pipeline.components.setup.network_initializer import ( @@ -33,7 +36,10 @@ BaseOptimizerComponent, OptimizerChoice ) -from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates +from autoPyTorch.utils.hyperparameter_search_space_update import ( + HyperparameterSearchSpaceUpdates, + HyperparameterSearchSpace +) class DummyLR(BaseLRComponent): @@ -417,6 +423,58 @@ def test_add_network_backbone(self): # clear addons 
base_network_backbone_choice._addons = ThirdPartyComponents(NetworkBackboneComponent) + @pytest.mark.parametrize('resnet_shape', ['funnel', 'long_funnel', + 'diamond', 'hexagon', + 'brick', 'triangle', + 'stairs']) + def test_dropout(self, resnet_shape): + # ensures that dropout is assigned to the resblock as expected + dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[1]} + max_dropout = 0.5 + num_groups = 4 + config_space = ShapedResNetBackbone.get_hyperparameter_search_space(dataset_properties=dataset_properties, + use_dropout=HyperparameterSearchSpace( + hyperparameter='use_dropout', + value_range=[True], + default_value=True), + max_dropout=HyperparameterSearchSpace( + hyperparameter='max_dropout', + value_range=[max_dropout], + default_value=max_dropout), + resnet_shape=HyperparameterSearchSpace( + hyperparameter='resnet_shape', + value_range=[resnet_shape], + default_value=resnet_shape), + num_groups=HyperparameterSearchSpace( + hyperparameter='num_groups', + value_range=[num_groups], + default_value=num_groups), + blocks_per_group=HyperparameterSearchSpace( + hyperparameter='blocks_per_group', + value_range=[1], + default_value=1 + ) + ) + + config = config_space.sample_configuration().get_dictionary() + resnet_backbone = ShapedResNetBackbone(**config) + resnet_backbone.build_backbone((100, 5)) + dropout_probabilites = [resnet_backbone.config[key] for key in resnet_backbone.config if 'dropout_' in key] + dropout_shape = get_shaped_neuron_counts( + resnet_shape, 0, 0, 1000, num_groups + 1 + )[:-1] + + dropout_shape = [ + dropout / 1000 * max_dropout for dropout in dropout_shape + ] + blocks_dropout = [] + for block in resnet_backbone.backbone: + if isinstance(block, torch.nn.Sequential): + for inner_block in block: + if isinstance(inner_block, ResBlock): + blocks_dropout.append(inner_block.dropout) + assert dropout_probabilites == dropout_shape == blocks_dropout + class TestNetworkHead: def test_all_heads_available(self): From 
6686328de2f7ccd38011f138ae427562fdb6fe20 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 10 Jun 2021 16:13:45 +0200 Subject: [PATCH 5/5] Fix flake --- test/test_pipeline/components/setup/test_setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_pipeline/components/setup/test_setup.py b/test/test_pipeline/components/setup/test_setup.py index 5fc20d56c..74ef4f92f 100644 --- a/test/test_pipeline/components/setup/test_setup.py +++ b/test/test_pipeline/components/setup/test_setup.py @@ -22,9 +22,9 @@ SchedulerChoice ) from autoPyTorch.pipeline.components.setup.network_backbone import NetworkBackboneChoice -from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent from autoPyTorch.pipeline.components.setup.network_backbone.ResNetBackbone import ResBlock from autoPyTorch.pipeline.components.setup.network_backbone.ShapedResNetBackbone import ShapedResNetBackbone +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent from autoPyTorch.pipeline.components.setup.network_backbone.utils import get_shaped_neuron_counts from autoPyTorch.pipeline.components.setup.network_head import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent @@ -37,8 +37,8 @@ OptimizerChoice ) from autoPyTorch.utils.hyperparameter_search_space_update import ( - HyperparameterSearchSpaceUpdates, - HyperparameterSearchSpace + HyperparameterSearchSpace, + HyperparameterSearchSpaceUpdates )