
Commit 4af3932

Additional metrics during train (#194)
* Added additional metrics to the fit dictionary
* Added them in the tests as well
* Fixed mypy and flake errors after rebase
* Added a random state to MixUp and Cutout, and changed no-resampling for the new code
* Fixed a bug in setup.py
1 parent f2fb6d4 commit 4af3932

File tree

6 files changed (+71, -66 lines)


autoPyTorch/datasets/base_dataset.py

Lines changed: 10 additions & 14 deletions
```diff
@@ -22,9 +22,9 @@
     HoldOutFunc,
     HoldOutFuncs,
     HoldoutValTypes,
-    get_no_resampling_validators,
-    NoResamplingStrategyTypes,
-    NO_RESAMPLING_FN
+    NoResamplingFunc,
+    NoResamplingFuncs,
+    NoResamplingStrategyTypes
 )
 from autoPyTorch.utils.common import FitRequirement
 
@@ -114,24 +114,19 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        self.dataset_name = dataset_name
 
-        if self.dataset_name is None:
+        if dataset_name is None:
             self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+        else:
+            self.dataset_name = dataset_name
 
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-<<<<<<< HEAD
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
-=======
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.no_resampling_validators: Dict[str, NO_RESAMPLING_FN] = {}
-        self.rng = np.random.RandomState(seed=seed)
->>>>>>> Fix mypy and flake
+        self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -156,7 +151,7 @@ def __init__(
         # Make sure cross validation splits are created once
         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
-        self.no_resampling_validators = get_no_resampling_validators(*NoResamplingStrategyTypes)
+        self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)
 
         self.splits = self.get_splits_from_resampling_strategy()
 
@@ -257,7 +252,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[
                 )
             )
         elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
-            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self._get_indices()), None))
+            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state,
+                                                                                        self._get_indices()), None))
         else:
             raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
         return splits
```
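For orientation, a minimal sketch of the new calling convention (illustrative, not part of the diff; it assumes the enum member is named `no_resampling`, which the `getattr` lookup in `resampling_strategy.py` implies):

```python
import numpy as np

from autoPyTorch.datasets.resampling_strategy import (
    NoResamplingFuncs,
    NoResamplingStrategyTypes,
)

# Build the name -> validator table the same way BaseDataset now does.
validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)

random_state = np.random.RandomState(seed=1)  # BaseDataset stores this as self.random_state
indices = np.arange(10)                       # stand-in for self._get_indices()

# The validator is looked up by strategy name and now takes the random
# state as its first argument; the validation part of the split is None.
strategy = NoResamplingStrategyTypes.no_resampling
split = (validators[strategy.name](random_state, indices), None)
print(split)  # (array([0, ..., 9]), None) -- train on everything, no validation set
```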

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 43 additions & 48 deletions
```diff
@@ -32,8 +32,10 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
         ...
 
 
-class NO_RESAMPLING_FN(Protocol):
-    def __call__(self, indices: np.ndarray) -> np.ndarray:
+class NoResamplingFunc(Protocol):
+    def __call__(self,
+                 random_state: np.random.RandomState,
+                 indices: np.ndarray) -> np.ndarray:
         ...
 
 
@@ -244,53 +246,46 @@ def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, Cros
         return cross_validators
 
 
-def get_no_resampling_validators(*no_resampling: NoResamplingStrategyTypes) -> Dict[str, NO_RESAMPLING_FN]:
-    no_resampling_strategies = {}  # type: Dict[str, NO_RESAMPLING_FN]
-    for strategy in no_resampling:
-        no_resampling_fn = globals()[strategy.name]
-        no_resampling_strategies[strategy.name] = no_resampling_fn
-    return no_resampling_strategies
+class NoResamplingFuncs():
+    @classmethod
+    def get_no_resampling_validators(cls, *no_resampling_types: NoResamplingStrategyTypes
+                                     ) -> Dict[str, NoResamplingFunc]:
+        no_resampling_strategies: Dict[str, NoResamplingFunc] = {
+            no_resampling_type.name: getattr(cls, no_resampling_type.name)
+            for no_resampling_type in no_resampling_types
+        }
+        return no_resampling_strategies
 
+    @staticmethod
+    def no_resampling(random_state: np.random.RandomState,
+                      indices: np.ndarray) -> np.ndarray:
+        """
+        Returns the indices without performing
+        any operation on them. To be used for
+        fitting on the whole dataset.
+        This strategy is not compatible with
+        HPO search.
+        Args:
+            indices: array of indices
 
-def no_resampling(indices: np.ndarray) -> np.ndarray:
-    """
-    Returns the indices without performing
-    any operation on them. To be used for
-    fitting on the whole dataset.
-    This strategy is not compatible with
-    HPO search.
-    Args:
-        indices: array of indices
-
-    Returns:
-        np.ndarray: array of indices
-    """
-    return indices
+        Returns:
+            np.ndarray: array of indices
+        """
+        return indices
 
+    @staticmethod
+    def shuffle_no_resampling(random_state: np.random.RandomState,
+                              indices: np.ndarray) -> np.ndarray:
+        """
+        Returns the indices after shuffling them.
+        To be used for fitting on the whole dataset.
+        This strategy is not compatible with HPO search.
+        Args:
+            random_state: random state
+            indices: array of indices
 
-def shuffle_no_resampling(indices: np.ndarray, **kwargs: Any) -> np.ndarray:
-    """
-    Returns the indices after shuffling them.
-    To be used for fitting on the whole dataset.
-    This strategy is not compatible with HPO search.
-    Args:
-        indices: array of indices
-
-    Returns:
-        np.ndarray: shuffled array of indices
-    """
-    if 'random_state' in kwargs:
-        if isinstance(kwargs['random_state'], np.random.RandomState):
-            kwargs['random_state'].shuffle(indices)
-        elif isinstance(kwargs['random_state'], int):
-            np.random.seed(kwargs['random_state'])
-            np.random.shuffle(indices)
-        else:
-            raise ValueError("Illegal value for 'random_state' entered. "
-                             "Expected it to be {} or {} but got {}".format(int,
-                                                                            np.random.RandomState,
-                                                                            type(kwargs['random_state'])))
-    else:
-        np.random.shuffle(indices)
-
-    return indices
+        Returns:
+            np.ndarray: shuffled array of indices
+        """
+        random_state.shuffle(indices)
+        return indices
```
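A practical consequence of putting `random_state` in the signature: the shuffled variant is now reproducible from a seed instead of silently falling back to the global NumPy RNG, as the old `**kwargs` version could. A minimal sketch:

```python
import numpy as np

from autoPyTorch.datasets.resampling_strategy import NoResamplingFuncs

# Same seed in, same permutation out -- the shuffle no longer depends
# on whatever state the global np.random generator happens to be in.
first = NoResamplingFuncs.shuffle_no_resampling(np.random.RandomState(3), np.arange(5))
second = NoResamplingFuncs.shuffle_no_resampling(np.random.RandomState(3), np.arange(5))
assert np.array_equal(first, second)
```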

autoPyTorch/pipeline/components/training/trainer/cutout_utils.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -10,6 +10,8 @@
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -35,7 +37,12 @@ def __init__(self, patch_ratio: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # A trainer component needs a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
```
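`check_random_state` is scikit-learn's standard seed normalizer: it maps `None`, an `int`, or an existing `RandomState` to a `RandomState` instance, which makes it a safe default here. A quick illustration:

```python
import numpy as np
from sklearn.utils import check_random_state

assert isinstance(check_random_state(1), np.random.RandomState)     # int -> seeded RandomState
assert isinstance(check_random_state(None), np.random.RandomState)  # None -> the global RandomState
rs = np.random.RandomState(7)
assert check_random_state(rs) is rs                                 # an existing one passes through
```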

autoPyTorch/pipeline/components/training/trainer/mixup_utils.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -10,6 +10,8 @@
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -34,7 +36,12 @@ def __init__(self, alpha: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # A trainer component needs a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
```
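Why a trainer component needs its own random state: MixUp draws its mixing coefficient from a Beta(alpha, alpha) distribution, so the draw is only reproducible if the generator is seeded. A generic MixUp-style sketch (illustrative, not autoPyTorch's actual trainer code):

```python
import numpy as np

random_state = np.random.RandomState(1)  # the seeded state the trainer now stores
alpha = 0.2
x = np.random.RandomState(0).rand(4, 3)  # a toy batch

# lam ~ Beta(alpha, alpha) decides how strongly each example is blended
# with a randomly permuted partner from the same batch.
lam = random_state.beta(alpha, alpha)
permutation = random_state.permutation(len(x))
mixed_x = lam * x + (1 - lam) * x[permutation]
```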

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,7 +58,7 @@
         "pre-commit",
         "pytest-cov",
         'pytest-forked',
-        "pytest-mock"
+        "pytest-mock",
         "codecov",
         "pep8",
         "mypy",
```

test/test_api/test_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@
 from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import _traditional_learners
 from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
 
-from test.test_api.api_utils import print_debug_information
+from test.test_api.api_utils import print_debug_information  # noqa E402
 
 
 CV_NUM_SPLITS = 2
```
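Worth noting: flake8 only scopes a suppression to specific codes when a colon is present (`# noqa: E402`); the colon-less `# noqa E402` used here is treated as a blanket `# noqa` that silences every check on that line.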
