Commit 147cf20

Additional metrics during train (#194)
* Added additional metrics to fit dictionary
* Added the metrics in tests as well
* Fix mypy and flake after rebase
* Added random state to mixup and cutout
* Changed no resampling for new code
* Fix bug in setup.py
1 parent 246a34d commit 147cf20

6 files changed: 24 additions, 49 deletions


autoPyTorch/datasets/base_dataset.py

Lines changed: 2 additions & 34 deletions

@@ -112,13 +112,7 @@ def __init__(
         dataset_name: Optional[str] = None,
         val_tensors: Optional[BaseDatasetInputType] = None,
         test_tensors: Optional[BaseDatasetInputType] = None,
-<<<<<<< HEAD
         resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
-=======
-        resampling_strategy: Union[CrossValTypes,
-                                   HoldoutValTypes,
-                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
         seed: Optional[int] = 42,
@@ -135,12 +129,7 @@ def __init__(
                 validation data
             test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                 test data
-<<<<<<< HEAD
             resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
-=======
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
-                (default=HoldoutValTypes.holdout_validation):
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
                 strategy to split the training data.
             resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                 required for the chosen resampling strategy. If None, uses
@@ -162,17 +151,11 @@ def __init__(
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-<<<<<<< HEAD
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
-=======
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.no_resampling_validators: Dict[str, NO_RESAMPLING_FN] = {}
-        self.rng = np.random.RandomState(seed=seed)
->>>>>>> Fix mypy and flake
+        self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -189,11 +172,8 @@ def __init__(
         # Make sure cross validation splits are created once
         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
-<<<<<<< HEAD
+
         self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)
-=======
-        self.no_resampling_validators = get_no_resampling_validators(*NoResamplingStrategyTypes)
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics

         self.splits = self.get_splits_from_resampling_strategy()

@@ -294,12 +274,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[
                 )
             )
         elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
-<<<<<<< HEAD
             splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state,
                                                                                          self._get_indices()), None))
-=======
-            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self._get_indices()), None))
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         else:
             raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
         return splits
@@ -371,11 +347,7 @@ def create_holdout_val_split(
             self.random_state, val_share, self._get_indices(), **kwargs)
         return train, val

-<<<<<<< HEAD
     def get_dataset(self, split_id: int, train: bool) -> Dataset:
-=======
-    def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         """
         The above split methods employ the Subset to internally subsample the whole dataset.

@@ -390,7 +362,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
             Dataset: the reduced dataset to be used for testing
         """
         # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
-<<<<<<< HEAD
         if split_id >= len(self.splits):  # old version: split_id > len(self.splits)
             raise IndexError(f"self.splits index out of range, got split_id={split_id}"
                              f" (>= num_splits={len(self.splits)})")
@@ -399,9 +370,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
             raise ValueError("Specified fold (or subset) does not exist")

         return TransformSubset(self, indices, train=train)
-=======
-        return TransformSubset(self, self.splits[split_id][0], train=train)
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics

     def replace_data(self, X_train: BaseDatasetInputType,
                      X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
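
The resolution keeps the HEAD behavior: a NoResamplingStrategyTypes strategy produces a single (train_indices, None) split, and get_dataset() bounds-checks split_id. A minimal, self-contained sketch of that control flow; TinyDataset and its helpers are illustrative stand-ins, not AutoPyTorch's real classes:

    from typing import List, Optional, Tuple

    import numpy as np


    def no_resampling(random_state: np.random.RandomState,
                      indices: np.ndarray) -> np.ndarray:
        # All indices go to training; nothing is held out for validation.
        return indices


    class TinyDataset:
        def __init__(self, n_samples: int, seed: int = 42) -> None:
            self.random_state = np.random.RandomState(seed=seed)
            indices = np.arange(n_samples)
            # One split of the form (train_indices, val_indices-or-None),
            # mirroring get_splits_from_resampling_strategy().
            self.splits: List[Tuple[np.ndarray, Optional[np.ndarray]]] = [
                (no_resampling(self.random_state, indices), None)
            ]

        def get_dataset(self, split_id: int, train: bool) -> np.ndarray:
            # Same bounds check as the resolved diff: reject split_id >= len(splits).
            if split_id >= len(self.splits):
                raise IndexError(f"self.splits index out of range, got split_id={split_id}"
                                 f" (>= num_splits={len(self.splits)})")
            indices = self.splits[split_id][0] if train else self.splits[split_id][1]
            if indices is None:
                raise ValueError("Specified fold (or subset) does not exist")
            return indices


    subset = TinyDataset(n_samples=10).get_dataset(split_id=0, train=True)
    assert len(subset) == 10  # no resampling: every sample is used for training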

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 4 additions & 11 deletions

@@ -39,8 +39,10 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
         ...


-class NO_RESAMPLING_FN(Protocol):
-    def __call__(self, indices: np.ndarray) -> np.ndarray:
+class NoResamplingFunc(Protocol):
+    def __call__(self,
+                 random_state: np.random.RandomState,
+                 indices: np.ndarray) -> np.ndarray:
         ...


@@ -90,22 +92,13 @@ def is_stratified(self) -> bool:

 class NoResamplingStrategyTypes(IntEnum):
     no_resampling = 8
-<<<<<<< HEAD

     def is_stratified(self) -> bool:
         return False


 # TODO: replace it with another way
 ResamplingStrategies = Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
-=======
-    shuffle_no_resampling = 9
-
-
-# TODO: replace it with another way
-RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
-
-
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics


 DEFAULT_RESAMPLING_PARAMETERS: Dict[
     ResamplingStrategies,
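
The renamed NoResamplingFunc protocol now threads a random_state through every validator, matching the CrossValFunc and HoldOutFunc signatures. A sketch of conforming implementations; shuffle_no_resampling below is purely illustrative (the diff actually removes its enum entry) and only the Protocol shape is taken from the diff:

    from typing import Protocol

    import numpy as np


    class NoResamplingFunc(Protocol):
        def __call__(self,
                     random_state: np.random.RandomState,
                     indices: np.ndarray) -> np.ndarray:
            ...


    def no_resampling(random_state: np.random.RandomState,
                      indices: np.ndarray) -> np.ndarray:
        # Return all indices untouched: train on everything, validate on nothing.
        return indices


    def shuffle_no_resampling(random_state: np.random.RandomState,
                              indices: np.ndarray) -> np.ndarray:
        # Shuffled variant; draws from the passed random_state for
        # reproducibility, which is why the protocol gained that parameter.
        return random_state.permutation(indices)


    validator: NoResamplingFunc = no_resampling  # structural typing: signature matches
    print(shuffle_no_resampling(np.random.RandomState(42), np.arange(5)))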

autoPyTorch/pipeline/components/training/trainer/cutout_utils.py

Lines changed: 8 additions & 1 deletion

@@ -10,6 +10,8 @@

 import numpy as np

+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -35,7 +37,12 @@ def __init__(self, patch_ratio: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # Trainer components need a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
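
check_random_state is sklearn's normalizer for seed-like arguments; the guard above only falls back to it when no random_state was injected. Its behavior, sketched:

    import numpy as np
    from sklearn.utils import check_random_state

    rs_from_int = check_random_state(1)               # seeds a fresh RandomState(1)
    rs_passthrough = check_random_state(rs_from_int)  # an existing RandomState is returned as-is
    assert rs_passthrough is rs_from_int
    rs_default = check_random_state(None)             # falls back to numpy's global RandomState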

autoPyTorch/pipeline/components/training/trainer/mixup_utils.py

Lines changed: 8 additions & 1 deletion

@@ -10,6 +10,8 @@

 import numpy as np

+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -34,7 +36,12 @@ def __init__(self, alpha: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # Trainer components need a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
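
The same default is added for MixUp, where the trainer's random_state drives both the Beta-distributed mixing coefficient and the batch permutation. A hedged sketch of standard MixUp sampling, not the exact AutoPyTorch implementation:

    import numpy as np

    def mixup_batch(X: np.ndarray, y: np.ndarray, alpha: float,
                    random_state: np.random.RandomState):
        # lam ~ Beta(alpha, alpha) controls how strongly two samples are blended.
        lam = random_state.beta(alpha, alpha)
        perm = random_state.permutation(len(X))
        X_mix = lam * X + (1 - lam) * X[perm]
        return X_mix, y, y[perm], lam  # loss = lam * L(y) + (1 - lam) * L(y[perm])

    X_mix, y_a, y_b, lam = mixup_batch(np.random.rand(8, 4), np.arange(8), 0.2,
                                       np.random.RandomState(1))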

setup.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@
     "pre-commit",
     "pytest-cov",
     'pytest-forked',
-    "pytest-mock"
+    "pytest-mock",
     "codecov",
     "pep8",
     "mypy",

test/test_api/test_api.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@
 from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import _traditional_learners
 from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy

-from test.test_api.api_utils import print_debug_information
+from test.test_api.api_utils import print_debug_information  # noqa E402


 CV_NUM_SPLITS = 2
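
The # noqa E402 comment silences flake8's E402 ("module level import not at top of file"), which this deliberately late import would otherwise trigger. A minimal reproduction of the pattern, with illustrative statements:

    # flake8 raises E402 when a module-level import follows other top-level
    # code; the noqa comment suppresses it for imports that must run late.
    import sys

    sys.path.append('.')  # top-level statement executed before the next import

    import os  # noqa E402  (flake8 would flag this line without the noqa)
    print(os.getcwd())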
