Skip to content

Commit 463c166

Browse files
authored
Fix bugs in cutout training (#233)
* Fix bugs in cutout training * Address comments from arlind
1 parent efacc95 commit 463c166

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3535
if beta <= 0 or r > self.alpha:
3636
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
3737

38-
# The mixup component mixes up also on the batch dimension
39-
# It is unlikely that the batch size is lower than the number of features, but
40-
# be safe
41-
size = min(X.shape[0], X.shape[1])
42-
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
38+
size = X.shape[1]
39+
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
40+
replace=False))
4341

4442
X[:, indices] = X[index, :][:, indices]
4543

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010

1111
class RowCutOutTrainer(CutOut, BaseTrainerComponent):
12+
NUMERICAL_VALUE = 0
13+
CATEGORICAL_VALUE = -1
1214

1315
def data_preparation(self, X: np.ndarray, y: np.ndarray,
1416
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
@@ -34,17 +36,26 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3436
lam = 1
3537
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
3638

37-
# The mixup component mixes up also on the batch dimension
38-
# It is unlikely that the batch size is lower than the number of features, but
39-
# be safe
40-
size = min(X.shape[0], X.shape[1])
41-
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
39+
size = X.shape[1]
40+
indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
41+
replace=False)
4242

43-
# We use an ordinal encoder on the tabular data
43+
if not isinstance(self.numerical_columns, typing.Iterable):
44+
raise ValueError("{} requires numerical columns information of {}"
45+
"to prepare data got {}.".format(self.__class__.__name__,
46+
typing.Iterable,
47+
self.numerical_columns))
48+
numerical_indices = torch.tensor(self.numerical_columns)
49+
categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
50+
51+
# We use an ordinal encoder on the categorical columns of tabular data
4452
# -1 is the conceptual equivalent to 0 in a image, that does not
4553
# have color as a feature and hence the network has to learn to deal
46-
# without this data
47-
X[:, indices.long()] = -1
54+
# without this data. For numerical columns we use 0 to cutout the features
55+
# similar to the effect that setting 0 as a pixel value in an image.
56+
X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
57+
X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
58+
4859
lam = 1
4960
y_a = y
5061
y_b = y

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ def prepare(
233233
metrics_during_training: bool,
234234
scheduler: _LRScheduler,
235235
task_type: int,
236-
labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
236+
labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
237+
numerical_columns: Optional[List[int]] = None
237238
) -> None:
238239

239240
# Save the device to be used
@@ -289,6 +290,9 @@ def prepare(
289290
# task type (used for calculating metrics)
290291
self.task_type = task_type
291292

293+
# for cutout trainer, we need the list of numerical columns
294+
self.numerical_columns = numerical_columns
295+
292296
def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
293297
"""
294298
Optional place holder for AutoPytorch Extensions.

autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
336336
metrics_during_training=X['metrics_during_training'],
337337
scheduler=X['lr_scheduler'],
338338
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
339-
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
339+
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
340+
numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
341+
'dataset_properties'] else None
340342
)
341343
total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
342344
self.run_summary = RunSummary(

0 commit comments

Comments
 (0)