Fix bugs in cutout training #233

Merged
@@ -35,11 +35,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
if beta <= 0 or r > self.alpha:
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}

- # The mixup component mixes up also on the batch dimension
- # It is unlikely that the batch size is lower than the number of features, but
- # be safe
- size = min(X.shape[0], X.shape[1])
- indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
+ size = X.shape[1]
+ indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
+                                                 replace=False))

X[:, indices] = X[index, :][:, indices]
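
For context, a minimal standalone sketch of what the fixed sampling does (the batch `X`, permutation `index`, and `lam` value below are made up for illustration). Sampling with `replace=False` yields distinct column indices, and sizing by `X.shape[1]` ties the number of mixed columns to the feature count rather than to `min(batch, features)`:

```python
import numpy as np
import torch

random_state = np.random.RandomState(42)
X = torch.arange(40, dtype=torch.float32).reshape(4, 10)  # toy batch: 4 rows, 10 features
index = torch.randperm(X.shape[0])  # shuffled row order used for mixing
lam = 0.6

size = X.shape[1]  # number of features, not min(batch size, features)
# distinct feature indices; max(1, ...) guarantees at least one column is mixed
indices = torch.tensor(random_state.choice(range(1, size), max(1, int(size * lam)),
                                           replace=False))
X[:, indices] = X[index, :][:, indices]  # swap the selected columns between rows
```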

@@ -9,6 +9,8 @@


class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+ NUMERICAL_VALUE = 0
+ CATEGORICAL_VALUE = -1

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
@@ -34,17 +36,26 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
lam = 1
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

- # The mixup component mixes up also on the batch dimension
- # It is unlikely that the batch size is lower than the number of features, but
- # be safe
- size = min(X.shape[0], X.shape[1])
- indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
+ size = X.shape[1]
+ indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
+                                    replace=False)

+ # We use an ordinal encoder on the tabular data
+ if not isinstance(self.numerical_columns, typing.Iterable):

What if the numerical columns are None? Should we still continue with only categorical imputing in this case, or not?

Also, if there are only numerical columns, there should not be a conversion for the categorical ones.

Contributor Author

Actually, when there are no numerical columns, it is not None but an empty list, and indexing with an empty list does not affect the tensor, so this should work.
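
That behavior is easy to verify with a standalone snippet (the tensor and its shape are made up for illustration):

```python
import torch

X = torch.ones(3, 4)
empty = torch.tensor([], dtype=torch.long)  # stands in for an empty numerical_columns list
X[:, empty] = 0.0  # assigning through an empty index tensor is a no-op
assert torch.equal(X, torch.ones(3, 4))
```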

@ArlindKadra May 21, 2021

numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
'dataset_properties'] else None

Is numerical_columns always in dataset_properties?

Contributor Author

When it's tabular data, then yes.

+ raise ValueError("{} requires numerical columns information of {} "
+                  "to prepare data, got {}.".format(self.__class__.__name__,
+                                                    typing.Iterable,
+                                                    self.numerical_columns))
+ numerical_indices = torch.tensor(self.numerical_columns)
+ categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])

# We use an ordinal encoder on the categorical columns of tabular data
# -1 is the conceptual equivalent to 0 in an image, which does not
# have color as a feature and hence the network has to learn to deal
- # without this data
- X[:, indices.long()] = -1
+ # without this data. For numerical columns we use 0 to cut out the features,
+ # similar to the effect of setting 0 as a pixel value in an image.
+ X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
+ X[:, numerical_indices.long()] = self.NUMERICAL_VALUE

lam = 1
y_a = y
y_b = y
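
To make the combined effect concrete, here is a minimal sketch of the new cutout on a toy batch (the feature layout, `numerical_columns`, and `patch_ratio` are made up for illustration; note that an empty `categorical_indices` is harmless, per the discussion above):

```python
import numpy as np
import torch

NUMERICAL_VALUE = 0
CATEGORICAL_VALUE = -1

random_state = np.random.RandomState(0)
X = torch.arange(1, 13, dtype=torch.float32).reshape(3, 4)  # 3 rows, 4 features
numerical_columns = [0, 1]  # assume the first two features are numerical
patch_ratio = 0.5

size = X.shape[1]
indices = random_state.choice(range(1, size), max(1, int(size * patch_ratio)),
                              replace=False)

numerical_indices = torch.tensor(numerical_columns)
categorical_indices = torch.tensor([i for i in indices if i not in numerical_columns])

X[:, categorical_indices.long()] = CATEGORICAL_VALUE  # -1 marks a cut-out categorical feature
X[:, numerical_indices.long()] = NUMERICAL_VALUE      # 0 zeroes out the numerical features
```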
@@ -233,7 +233,8 @@ def prepare(
metrics_during_training: bool,
scheduler: _LRScheduler,
task_type: int,
- labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
+ labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
+ numerical_columns: Optional[List[int]] = None
) -> None:

# Save the device to be used
@@ -289,6 +290,9 @@ def prepare(
# task type (used for calculating metrics)
self.task_type = task_type

+ # for the cutout trainer, we need the list of numerical columns
+ self.numerical_columns = numerical_columns

def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
"""
Optional place holder for AutoPytorch Extensions.
@@ -336,7 +336,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
metrics_during_training=X['metrics_during_training'],
scheduler=X['lr_scheduler'],
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
- labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
+ labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
+ numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
+     'dataset_properties'] else None
)
total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
self.run_summary = RunSummary(
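
For reference, the guarded lookup that feeds `numerical_columns` into `prepare` can be exercised on its own (the fit dictionary below is a made-up stand-in; real pipelines carry many more keys):

```python
from typing import Any, Dict, List, Optional

X: Dict[str, Any] = {'dataset_properties': {'task_type': 'tabular_classification',
                                            'numerical_columns': [0, 3, 7]}}

# Same pattern as in _fit above: fall back to None when the key is absent
# (e.g. for non-tabular data).
numerical_columns: Optional[List[int]] = (
    X['dataset_properties']['numerical_columns']
    if 'numerical_columns' in X['dataset_properties'] else None
)
print(numerical_columns)  # [0, 3, 7]
```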