
[FIX] Passing checks #298


Merged
merged 27 commits on Dec 7, 2021

Changes from all commits

Commits (27)
16cfb55
Initial fix for all tests passing locally py=3.8
ravinkohli Oct 22, 2021
90ebe72
fix bug in tests
ravinkohli Oct 22, 2021
f965094
fix bug in test for data
ravinkohli Oct 22, 2021
f7b4d70
debugging error in dummy forward pass
ravinkohli Oct 22, 2021
3181ed1
debug try -2
ravinkohli Oct 22, 2021
4d69ff5
catch runtime error in ci
ravinkohli Oct 25, 2021
e8d1eb3
catch runtime error in ci
ravinkohli Oct 25, 2021
243573b
add better debug test setup
ravinkohli Oct 25, 2021
dad3a4b
debug some more
ravinkohli Oct 25, 2021
6d94893
run this test only
ravinkohli Oct 25, 2021
e049177
remove sum backward
ravinkohli Oct 25, 2021
d079e04
remove inplace in inception block
ravinkohli Oct 25, 2021
245af31
undo silly change
ravinkohli Oct 25, 2021
0a18ab8
Enable all tests
ravinkohli Oct 25, 2021
ec62e2e
fix flake
ravinkohli Oct 25, 2021
f528698
fix bug in test setup
ravinkohli Oct 25, 2021
b399dac
remove anamoly detection
ravinkohli Oct 25, 2021
42ca211
change in trainer choice
ravinkohli Oct 25, 2021
296cc16
minor changes to comments
ravinkohli Oct 26, 2021
b4314f9
Apply suggestions from code review
ravinkohli Nov 3, 2021
fefbdcf
Address comments from Shuhei
ravinkohli Nov 8, 2021
042f478
revert change leading to bug
ravinkohli Nov 8, 2021
aaefc83
fix flake
ravinkohli Nov 8, 2021
8ebbc5e
change comment position in feature validator
ravinkohli Nov 8, 2021
e3c43ef
Add documentation for _is_datasets_consistent
ravinkohli Nov 9, 2021
3564fa1
address comments from arlind
ravinkohli Dec 6, 2021
10aea66
case when all nans in test
ravinkohli Dec 6, 2021
5 changes: 3 additions & 2 deletions autoPyTorch/api/base_task.py
@@ -1386,7 +1386,7 @@ def fit_ensemble(
Args:
optimize_metric (str): name of the metric that is used to
evaluate a pipeline. if not specified, value passed to search will be used
precision (int), (default=32): Numeric precision used when loading
precision (Optional[int]): Numeric precision used when loading
ensemble data. Can be either 16, 32 or 64.
ensemble_nbest (Optional[int]):
only consider the ensemble_nbest models to build the ensemble.
@@ -1429,6 +1429,7 @@ def fit_ensemble(
"Please call the `search()` method of {} prior to "
"fit_ensemble().".format(self.__class__.__name__))

precision = precision if precision is not None else self.precision
if precision not in [16, 32, 64]:
raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision))

@@ -1479,7 +1480,7 @@ def fit_ensemble(
manager = self._init_ensemble_builder(
time_left_for_ensembles=time_left_for_ensemble,
optimize_metric=self.opt_metric if optimize_metric is None else optimize_metric,
precision=self.precision if precision is None else precision,
precision=precision,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest,
)
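A brief aside on the hunks above: the Optional precision is now resolved against self.precision up front (before the membership check) rather than inside the _init_ensemble_builder call. A minimal sketch of the intended behaviour follows; the standalone helper is illustrative only and not part of the autoPyTorch API.

from typing import Optional

def resolve_precision(requested: Optional[int], search_precision: int) -> int:
    # Mirrors the fallback added in fit_ensemble: None means
    # "reuse the precision that was used during search()".
    precision = requested if requested is not None else search_precision
    if precision not in (16, 32, 64):
        raise ValueError(f"precision must be one of 16, 32, 64 but got {precision}")
    return precision

assert resolve_precision(None, 32) == 32  # falls back to the search precision
assert resolve_precision(16, 32) == 16    # an explicit value wins
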
51 changes: 36 additions & 15 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -1,5 +1,5 @@
import functools
from typing import Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, Union, cast

import numpy as np

@@ -100,6 +100,7 @@ def _comparator(cmp1: str, cmp2: str) -> int:
if cmp1 not in choices or cmp2 not in choices:
raise ValueError('The comparator for the column order only accepts {}, '
'but got {} and {}'.format(choices, cmp1, cmp2))

idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
return idx1 - idx2

@@ -246,13 +247,12 @@ def transform(
# having a value for a categorical column.
# We need to convert the column in test data to
# object otherwise the test column is interpreted as float
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')

if self.column_transformer is not None:
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')
X = self.column_transformer.transform(X)

# Sparse related transformations
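
A hedged illustration of the all-NaN cast a few lines up (the toy column name is made up): a test-set column that contains only NaN is parsed by pandas as float64, so it is cast back to object before the fitted categorical pipeline in the column transformer sees it.

import numpy as np
import pandas as pd

X_test = pd.DataFrame({'colour': [np.nan, np.nan]})    # all-NaN categorical column
print(X_test['colour'].dtype)                           # float64 by default
X_test['colour'] = X_test['colour'].astype('object')    # the cast applied in transform()
print(X_test['colour'].dtype)                           # object
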
@@ -337,16 +337,10 @@ def _check_data(

dtypes = [dtype.name for dtype in X.dtypes]

dtypes_diff = [s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
if len(self.dtypes) == 0:
self.dtypes = dtypes
elif (
any(dtypes_diff) # the dtypes of some columns are different in train and test dataset
and self.all_nan_columns is not None # Ignore all_nan_columns is None
and len(set(X.columns[dtypes_diff]).difference(self.all_nan_columns)) != 0
):
# The dtypes can be different if and only if the column belongs
# to all_nan_columns as these columns would be imputed.
elif not self._is_datasets_consistent(diff_cols, X):
raise ValueError("The dtype of the features must not be changed after fit(), but"
" the dtypes of some columns are different between training ({}) and"
" test ({}) datasets.".format(self.dtypes, dtypes))
@@ -508,6 +502,33 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:

return X

def _is_datasets_consistent(self, diff_cols: List[Union[int, str]], X: pd.DataFrame) -> bool:
"""
Check the consistency of dtypes between training and test datasets.
The dtypes can be different if the column belongs to `self.all_nan_columns`
(list of column names with all nans in training data) or if the column is
all nan as these columns would be imputed.

Args:
diff_cols (List[Union[int, str]]):
The column labels that have different dtypes.
X (pd.DataFrame):
A validation or test dataset to be compared with the training dataset
Returns:
_ (bool): Whether the training and test datasets are consistent.
"""
if self.all_nan_columns is None:
if len(diff_cols) == 0:
return True
else:
return all(X[diff_cols].isna().all())

# dtype is different ==> the column in at least one of the train or test datasets must be all NaN
# inconsistent <==> dtype is different and the column is all NaN in neither train nor test
inconsistent_cols = list(set(diff_cols) - self.all_nan_columns)

return len(inconsistent_cols) == 0 or all(X[inconsistent_cols].isna().all())


def has_object_columns(
feature_types: pd.Series,
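A self-contained sketch of the rule that the new _is_datasets_consistent method encodes: a dtype mismatch between fit and transform is tolerated only for columns that are all NaN (and therefore imputed). The free-standing function and toy frames below are illustrative and not the validator's actual interface.

from typing import List, Optional, Set, Union

import numpy as np
import pandas as pd

def is_datasets_consistent(diff_cols: List[Union[int, str]],
                           X: pd.DataFrame,
                           all_nan_columns: Optional[Set[Union[int, str]]]) -> bool:
    # A column whose dtype changed is acceptable if it was all NaN during fit
    # or is all NaN in the data currently being transformed.
    if all_nan_columns is None:
        return len(diff_cols) == 0 or all(X[diff_cols].isna().all())
    inconsistent_cols = list(set(diff_cols) - all_nan_columns)
    return len(inconsistent_cols) == 0 or all(X[inconsistent_cols].isna().all())

X_test = pd.DataFrame({'a': [1.0, 2.0], 'b': [np.nan, np.nan]})
print(is_datasets_consistent(['b'], X_test, all_nan_columns={'b'}))  # True: 'b' was all NaN at fit time
print(is_datasets_consistent(['a'], X_test, all_nan_columns={'b'}))  # False: 'a' genuinely changed dtype
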
(changes to another file)
@@ -39,7 +39,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
(Dict[str, Any]): the updated 'X' dictionary
"""
X.update({'encoder': self.preprocessor})
# X.update({'encoder': self.preprocessor})
return X

@staticmethod
(changes to another file)
@@ -42,7 +42,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
np.ndarray: Transformed features
"""
X.update({'scaler': self.preprocessor})
# X.update({'scaler': self.preprocessor})
return X

@staticmethod
(changes to another file)
@@ -78,7 +78,7 @@ def __init__(self, n_res_inputs: int, n_outputs: int):
def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(res)
shortcut = self.bn(shortcut)
x += shortcut
x = x + shortcut
return torch.relu(x)


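The one-line change above replaces an in-place addition with an out-of-place one. As a hedged illustration of why that matters (and of the kind of RuntimeError the debugging commits in this PR appear to be chasing), autograd fails when a tensor it saved for the backward pass is modified in place; ReLU, for example, saves its output.

import torch

x = torch.ones(3, requires_grad=True)
y = torch.relu(x)        # ReLU's backward needs its (unmodified) output
y += 1                   # the in-place update bumps the tensor's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print('in-place op broke autograd:', err)

x2 = torch.ones(3, requires_grad=True)
y2 = torch.relu(x2)
y2 = y2 + 1              # out-of-place form, as in the patched shortcut
y2.sum().backward()      # works
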
(changes to another file)
@@ -21,7 +21,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:

self.embedding = self.build_embedding(
num_input_features=num_input_features,
num_numerical_features=num_numerical_columns)
num_numerical_features=num_numerical_columns) # type: ignore[arg-type]
return self

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
(changes to another file)
@@ -109,11 +109,7 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc
loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
loss.backward()
self.optimizer.step()
if self.scheduler:
if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__:
self.scheduler.step(loss)
else:
self.scheduler.step()

# only passing the original outputs since we do not care about
# the adversarial performance.
return loss.item(), original_outputs
(changes to another file)
@@ -280,6 +280,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
y=y,
**kwargs
)

# Add snapshots to base network to enable
# predicting with snapshot ensemble
self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)
(changes to another file)
@@ -59,7 +59,7 @@ def get_search_space_updates():
value_range=['shake-shake'],
default_value='shake-shake')
updates.append(node_name='network_backbone',
hyperparameter='ResNetBackbone:shake_shake_method',
hyperparameter='ResNetBackbone:shake_shake_update_func',
value_range=['M3'],
default_value='M3'
)
77 changes: 31 additions & 46 deletions test/test_data/test_feature_validator.py
@@ -204,7 +204,6 @@ def test_featurevalidator_supported_types(input_data_featuretest):
assert sparse.issparse(transformed_X)
else:
assert isinstance(transformed_X, np.ndarray)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted

@@ -237,9 +236,10 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
validator.fit(input_data_featuretest)
transformed_X = validator.transform(input_data_featuretest)
assert any(pd.isna(input_data_featuretest))
assert any((-1 in categories) or ('-1' in categories) or ('Missing!' in categories) for categories in
validator.encoder.named_transformers_['encoder'].categories_)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
categories_ = validator.column_transformer.\
named_transformers_['categorical_pipeline'].named_steps['onehotencoder'].categories_
assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in
categories_)

Review thread on the assertion above:

ArlindKadra (Dec 6, 2021):
These are the unique category values for the one hot encoder, right? What would be the case where they are 0? Is it not 'missing_value' for categorical columns?

ravinkohli (Contributor Author):
It will be 0 when the column is categorical but the dtype of the column is int.

Reply:
According to this

    if column_dtype in ['category', 'bool']:
        categorical_columns.append(column)

it will be a categorical column only when it has a string or bool dtype. And if there are differences between test and train, it will either throw an error or convert a null column to object for test if it is not empty in train.

Reply:
Does it fail if you run it without the 0 checks?

ravinkohli (Contributor Author):
"It will be a categorical column only when it has a string or bool dtype" -- not string, but category. And category can be string or int.

ravinkohli (Contributor Author):
It works fine without the '0' check, but removing 0 gives an error.

assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted
assert isinstance(transformed_X, np.ndarray)
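
A hedged aside on the review thread above (toy frame, not the actual test fixture): pandas reports a column as 'category' only once it has been cast explicitly, and the underlying categories can be ints as well as strings, which is the distinction the thread settles on.

import pandas as pd

df = pd.DataFrame({'a': [1, 2, 1], 'b': ['x', 'y', 'x']})
print(df.dtypes.tolist())               # [dtype('int64'), dtype('O')], neither is 'category' yet
df['a'] = df['a'].astype('category')
print(df['a'].dtype)                    # category
print(df['a'].cat.categories.dtype)     # int64, so a categorical column can hold ints
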
@@ -292,7 +292,6 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
else:
raise ValueError(type(input_data_featuretest))
transformed_X = validator.transform(complementary_type)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted

@@ -436,36 +435,29 @@ def test_features_unsupported_calls_are_raised():
expected
"""
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"):
with pytest.raises(TypeError, match=r".*?Convert the time information to a numerical value"):
validator.fit(
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
)
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
validator.fit({'input1': 1, 'input2': 2})
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
validator = TabularFeatureValidator()
with pytest.raises(TypeError, match=r".*?but input column A has an invalid type `string`.*"):
validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string'))
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"):
validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]),
X_test=np.array([[1, 2, 3, 4], [4, 5, 6, 7]]),
)
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"):
validator.transform(np.array([[1, 2, 3], [4, 5, 6]]))


@pytest.mark.parametrize(
'input_data_featuretest',
(
'numpy_numericalonly_nonan',
'numpy_numericalonly_nan',
'pandas_numericalonly_nonan',
'pandas_numericalonly_nan',
'list_numericalonly_nonan',
'list_numericalonly_nan',
# Category in numpy is handled via feat_type
'numpy_categoricalonly_nonan',
'numpy_mixed_nonan',
'numpy_categoricalonly_nan',
'numpy_mixed_nan',
'sparse_bsr_nonan',
'sparse_bsr_nan',
'sparse_coo_nonan',
@@ -483,14 +475,14 @@ def test_features_unsupported_calls_are_raised():
),
indirect=True
)
def test_no_encoder_created(input_data_featuretest):
def test_no_column_transformer_created(input_data_featuretest):
"""
Makes sure that for numerical only features, no encoder is created
"""
validator = TabularFeatureValidator()
validator.fit(input_data_featuretest)
validator.transform(input_data_featuretest)
assert validator.encoder is None
assert validator.column_transformer is None


@pytest.mark.parametrize(
@@ -501,18 +493,18 @@ def test_no_encoder_created(input_data_featuretest):
),
indirect=True
)
def test_encoder_created(input_data_featuretest):
def test_column_transformer_created(input_data_featuretest):
"""
This test ensures an encoder is created if categorical data is provided
This test ensures a column transformer is created if categorical data is provided
"""
validator = TabularFeatureValidator()
validator.fit(input_data_featuretest)
transformed_X = validator.transform(input_data_featuretest)
assert validator.encoder is not None
assert validator.column_transformer is not None

# Make sure that the encoded features are actually encoded. Categorical columns are at
# the start after transformation. In our fixtures, this is also honored prior encode
enc_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest)
cat_columns, _, feature_types = validator._get_columns_info(input_data_featuretest)

# At least one categorical
assert 'categorical' in validator.feat_type
@@ -521,20 +513,13 @@ def test_encoder_created(input_data_featuretest):
if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col]
) for col in input_data_featuretest.columns]):
assert 'numerical' in validator.feat_type
for i, feat_type in enumerate(feature_types):
if 'numerical' in feat_type:
np.testing.assert_array_equal(
transformed_X[:, i],
input_data_featuretest[input_data_featuretest.columns[i]].to_numpy()
)
elif 'categorical' in feat_type:
np.testing.assert_array_equal(
transformed_X[:, i],
# Expect always 0, 1... because we use a ordinal encoder
np.array([0, 1])
)
else:
raise ValueError(feat_type)
# we expect this input to be the fixture 'pandas_mixed_nan'
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., -1.], [0., 1., 1.]]))
else:
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]]))

if not all([feat_type in ['numerical', 'categorical'] for feat_type in feature_types]):
raise ValueError("Expected only numerical and categorical feature types")


def test_no_new_category_after_fit():
@@ -566,13 +551,12 @@ def test_unknown_encode_value():
x['c'].cat.add_categories(['NA'], inplace=True)
x.loc[0, 'c'] = 'NA' # unknown value
x_t = validator.transform(x)
# The first row should have a -1 as we added a new categorical there
expected_row = [-1, -41, -3, -987.2]
# The first row should have a 0, 0 as we added a
# new categorical there and one hot encoder marks
# it as all zeros for the transformed column
expected_row = [0.0, 0.0, -0.5584294383572701, 0.5000000000000004, -1.5136598016833485]
assert expected_row == x_t[0].tolist()

# Notice how there is only one column 'c' to encode
assert validator.categories == [list(range(2)) for i in range(1)]
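
For context on the new expected row (a hedged sketch with made-up categories, not the fixture used in the test): a scikit-learn OneHotEncoder configured to ignore unknown categories encodes a value unseen during fit as all zeros, which matches the 0.0, 0.0 prefix asserted above. Whether the pipeline sets handle_unknown='ignore' exactly as below is an assumption.

import numpy as np
from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder(handle_unknown='ignore')            # assumed encoder configuration
enc.fit(np.array([['red'], ['blue']]))
print(enc.transform(np.array([['green']])).toarray())   # [[0., 0.]]: unseen category becomes all zeros
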


# Actual checks for the features
@pytest.mark.parametrize(
@@ -624,19 +608,20 @@ def test_feature_validator_new_data_after_fit(
assert sparse.issparse(transformed_X)
else:
assert isinstance(transformed_X, np.ndarray)
assert np.shape(X_test) == np.shape(transformed_X)

# And then check proper error messages
if train_data_type == 'pandas':
old_dtypes = copy.deepcopy(validator.dtypes)
validator.dtypes = ['dummy' for dtype in X_train.dtypes]
with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"):
with pytest.raises(ValueError,
match=r"The dtype of the features must not be changed after fit"):
transformed_X = validator.transform(X_test)
validator.dtypes = old_dtypes
if test_data_type == 'pandas':
columns = X_test.columns.tolist()
X_test = X_test[reversed(columns)]
with pytest.raises(ValueError, match=r"Changing the column order of the features"):
with pytest.raises(ValueError,
match=r"The column order of the features must not be changed after fit"):
transformed_X = validator.transform(X_test)

