Skip to content

Commit 98c93c4

Browse files
[FIX] Passing checks (#298)
* Initial fix for all tests passing locally py=3.8 * fix bug in tests * fix bug in test for data * debugging error in dummy forward pass * debug try -2 * catch runtime error in ci * catch runtime error in ci * add better debug test setup * debug some more * run this test only * remove sum backward * remove inplace in inception block * undo silly change * Enable all tests * fix flake * fix bug in test setup * remove anomaly detection * minor changes to comments * Apply suggestions from code review Co-authored-by: nabenabe0928 <[email protected]> * Address comments from Shuhei * revert change leading to bug * fix flake * change comment position in feature validator * Add documentation for _is_datasets_consistent * address comments from arlind * case when all nans in test Co-authored-by: nabenabe0928 <[email protected]>
1 parent 5971f22 commit 98c93c4

File tree

18 files changed

+118
-119
lines changed

18 files changed

+118
-119
lines changed

autoPyTorch/api/base_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,7 +1698,7 @@ def fit_ensemble(
16981698
Args:
16991699
optimize_metric (str): name of the metric that is used to
17001700
evaluate a pipeline. if not specified, value passed to search will be used
1701-
precision (int), (default=32): Numeric precision used when loading
1701+
precision (Optional[int]): Numeric precision used when loading
17021702
ensemble data. Can be either 16, 32 or 64.
17031703
ensemble_nbest (Optional[int]):
17041704
only consider the ensemble_nbest models to build the ensemble.
@@ -1741,6 +1741,7 @@ def fit_ensemble(
17411741
"Please call the `search()` method of {} prior to "
17421742
"fit_ensemble().".format(self.__class__.__name__))
17431743

1744+
precision = precision if precision is not None else self.precision
17441745
if precision not in [16, 32, 64]:
17451746
raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision))
17461747

@@ -1791,7 +1792,7 @@ def fit_ensemble(
17911792
manager = self._init_ensemble_builder(
17921793
time_left_for_ensembles=time_left_for_ensemble,
17931794
optimize_metric=self.opt_metric if optimize_metric is None else optimize_metric,
1794-
precision=self.precision if precision is None else precision,
1795+
precision=precision,
17951796
ensemble_size=ensemble_size,
17961797
ensemble_nbest=ensemble_nbest,
17971798
)

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _comparator(cmp1: str, cmp2: str) -> int:
140140
if cmp1 not in choices or cmp2 not in choices:
141141
raise ValueError('The comparator for the column order only accepts {}, '
142142
'but got {} and {}'.format(choices, cmp1, cmp2))
143+
143144
idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
144145
return idx1 - idx2
145146

@@ -307,13 +308,12 @@ def transform(
307308
# having a value for a categorical column.
308309
# We need to convert the column in test data to
309310
# object otherwise the test column is interpreted as float
310-
if len(self.categorical_columns) > 0:
311-
categorical_columns = self.column_transformer.transformers_[0][-1]
312-
for column in categorical_columns:
313-
if X[column].isna().all():
314-
X[column] = X[column].astype('object')
315-
316311
if self.column_transformer is not None:
312+
if len(self.categorical_columns) > 0:
313+
categorical_columns = self.column_transformer.transformers_[0][-1]
314+
for column in categorical_columns:
315+
if X[column].isna().all():
316+
X[column] = X[column].astype('object')
317317
X = self.column_transformer.transform(X)
318318

319319
# Sparse related transformations
@@ -429,16 +429,10 @@ def _check_data(
429429

430430
dtypes = [dtype.name for dtype in X.dtypes]
431431

432-
dtypes_diff = [s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]
432+
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
433433
if len(self.dtypes) == 0:
434434
self.dtypes = dtypes
435-
elif (
436-
any(dtypes_diff) # the dtypes of some columns are different in train and test dataset
437-
and self.all_nan_columns is not None # Ignore all_nan_columns is None
438-
and len(set(X.columns[dtypes_diff]).difference(self.all_nan_columns)) != 0
439-
):
440-
# The dtypes can be different if and only if the column belongs
441-
# to all_nan_columns as these columns would be imputed.
435+
elif not self._is_datasets_consistent(diff_cols, X):
442436
raise ValueError("The dtype of the features must not be changed after fit(), but"
443437
" the dtypes of some columns are different between training ({}) and"
444438
" test ({}) datasets.".format(self.dtypes, dtypes))
@@ -606,6 +600,33 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
606600

607601
return X
608602

603+
def _is_datasets_consistent(self, diff_cols: List[Union[int, str]], X: pd.DataFrame) -> bool:
604+
"""
605+
Check the consistency of dtypes between training and test datasets.
606+
The dtypes can be different if the column belongs to `self.all_nan_columns`
607+
(list of column names with all nans in training data) or if the column is
608+
all nan as these columns would be imputed.
609+
610+
Args:
611+
diff_cols (List[Union[int, str]]):
612+
The column labels that have different dtypes.
613+
X (pd.DataFrame):
614+
A validation or test dataset to be compared with the training dataset
615+
Returns:
616+
_ (bool): Whether the training and test datasets are consistent.
617+
"""
618+
if self.all_nan_columns is None:
619+
if len(diff_cols) == 0:
620+
return True
621+
else:
622+
return all(X[diff_cols].isna().all())
623+
624+
# dtype is different ==> the column must be all NaN in at least one of the train or test datasets
625+
# inconsistent <==> dtype is different and the column is all NaN in neither the train nor the test dataset
626+
inconsistent_cols = list(set(diff_cols) - self.all_nan_columns)
627+
628+
return len(inconsistent_cols) == 0 or all(X[inconsistent_cols].isna().all())
629+
609630

610631
def has_object_columns(
611632
feature_types: pd.Series,

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
4040
Returns:
4141
(Dict[str, Any]): the updated 'X' dictionary
4242
"""
43-
X.update({'encoder': self.preprocessor})
43+
# X.update({'encoder': self.preprocessor})
4444
return X
4545

4646
@staticmethod

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
4343
Returns:
4444
np.ndarray: Transformed features
4545
"""
46-
X.update({'scaler': self.preprocessor})
46+
# X.update({'scaler': self.preprocessor})
4747
return X
4848

4949
@staticmethod

autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
2121

2222
self.embedding = self.build_embedding(
2323
num_input_features=num_input_features,
24-
num_numerical_features=num_numerical_columns)
24+
num_numerical_features=num_numerical_columns) # type: ignore[arg-type]
2525
return self
2626

2727
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,7 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc
109109
loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
110110
loss.backward()
111111
self.optimizer.step()
112-
if self.scheduler:
113-
if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__:
114-
self.scheduler.step(loss)
115-
else:
116-
self.scheduler.step()
112+
117113
# only passing the original outputs since we do not care about
118114
# the adversarial performance.
119115
return loss.item(), original_outputs

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
283283
y=y,
284284
**kwargs
285285
)
286+
286287
# Add snapshots to base network to enable
287288
# predicting with snapshot ensemble
288289
self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)

examples/40_advanced/40_advanced/example_custom_configuration_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_search_space_updates():
5959
value_range=['shake-shake'],
6060
default_value='shake-shake')
6161
updates.append(node_name='network_backbone',
62-
hyperparameter='ResNetBackbone:shake_shake_method',
62+
hyperparameter='ResNetBackbone:shake_shake_update_func',
6363
value_range=['M3'],
6464
default_value='M3'
6565
)

test/test_data/test_feature_validator.py

Lines changed: 25 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ def test_featurevalidator_supported_types(input_data_featuretest):
205205
assert sparse.issparse(transformed_X)
206206
else:
207207
assert isinstance(transformed_X, np.ndarray)
208-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
209208
assert np.issubdtype(transformed_X.dtype, np.number)
210209
assert validator._is_fitted
211210

@@ -238,11 +237,10 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
238237
validator.fit(input_data_featuretest)
239238
transformed_X = validator.transform(input_data_featuretest)
240239
assert any(pd.isna(input_data_featuretest))
241-
categories_ = validator.column_transformer.named_transformers_['categorical_pipeline'].\
242-
named_steps['ordinalencoder'].categories_
240+
categories_ = validator.column_transformer.\
241+
named_transformers_['categorical_pipeline'].named_steps['onehotencoder'].categories_
243242
assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in
244243
categories_)
245-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
246244
assert np.issubdtype(transformed_X.dtype, np.number)
247245
assert validator._is_fitted
248246
assert isinstance(transformed_X, np.ndarray)
@@ -295,7 +293,6 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
295293
else:
296294
raise ValueError(type(input_data_featuretest))
297295
transformed_X = validator.transform(complementary_type)
298-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
299296
assert np.issubdtype(transformed_X.dtype, np.number)
300297
assert validator._is_fitted
301298

@@ -315,12 +312,6 @@ def test_featurevalidator_get_columns_to_encode():
315312
for col in df.columns:
316313
df[col] = df[col].astype(col)
317314

318-
<<<<<<< HEAD
319-
transformed_columns, feature_types = validator._get_columns_to_encode(df)
320-
321-
assert transformed_columns == ['category', 'bool']
322-
assert feature_types == ['numerical', 'numerical', 'categorical', 'categorical']
323-
=======
324315
validator.fit(df)
325316

326317
categorical_columns, numerical_columns, feat_type = validator._get_columns_info(df)
@@ -436,7 +427,6 @@ def test_feature_validator_remove_nan_catcolumns():
436427
)
437428
ans_test = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float64)
438429
feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)
439-
>>>>>>> Bug fixes (#249)
440430

441431

442432
def test_features_unsupported_calls_are_raised():
@@ -446,36 +436,29 @@ def test_features_unsupported_calls_are_raised():
446436
expected
447437
"""
448438
validator = TabularFeatureValidator()
449-
with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"):
439+
with pytest.raises(TypeError, match=r".*?Convert the time information to a numerical value"):
450440
validator.fit(
451441
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
452442
)
443+
validator = TabularFeatureValidator()
453444
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
454445
validator.fit({'input1': 1, 'input2': 2})
455-
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
446+
validator = TabularFeatureValidator()
447+
with pytest.raises(TypeError, match=r".*?but input column A has an invalid type `string`.*"):
456448
validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string'))
449+
validator = TabularFeatureValidator()
457450
with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"):
458451
validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]),
459452
X_test=np.array([[1, 2, 3, 4], [4, 5, 6, 7]]),
460453
)
454+
validator = TabularFeatureValidator()
461455
with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"):
462456
validator.transform(np.array([[1, 2, 3], [4, 5, 6]]))
463457

464458

465459
@pytest.mark.parametrize(
466460
'input_data_featuretest',
467461
(
468-
'numpy_numericalonly_nonan',
469-
'numpy_numericalonly_nan',
470-
'pandas_numericalonly_nonan',
471-
'pandas_numericalonly_nan',
472-
'list_numericalonly_nonan',
473-
'list_numericalonly_nan',
474-
# Category in numpy is handled via feat_type
475-
'numpy_categoricalonly_nonan',
476-
'numpy_mixed_nonan',
477-
'numpy_categoricalonly_nan',
478-
'numpy_mixed_nan',
479462
'sparse_bsr_nonan',
480463
'sparse_bsr_nan',
481464
'sparse_coo_nonan',
@@ -513,7 +496,7 @@ def test_no_column_transformer_created(input_data_featuretest):
513496
)
514497
def test_column_transformer_created(input_data_featuretest):
515498
"""
516-
This test ensures an encoder is created if categorical data is provided
499+
This test ensures a column transformer is created if categorical data is provided
517500
"""
518501
validator = TabularFeatureValidator()
519502
validator.fit(input_data_featuretest)
@@ -522,7 +505,7 @@ def test_column_transformer_created(input_data_featuretest):
522505

523506
# Make sure that the encoded features are actually encoded. Categorical columns are at
524507
# the start after transformation. In our fixtures, this is also honored prior encode
525-
transformed_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest)
508+
cat_columns, _, feature_types = validator._get_columns_info(input_data_featuretest)
526509

527510
# At least one categorical
528511
assert 'categorical' in validator.feat_type
@@ -531,20 +514,13 @@ def test_column_transformer_created(input_data_featuretest):
531514
if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col]
532515
) for col in input_data_featuretest.columns]):
533516
assert 'numerical' in validator.feat_type
534-
for i, feat_type in enumerate(feature_types):
535-
if 'numerical' in feat_type:
536-
np.testing.assert_array_equal(
537-
transformed_X[:, i],
538-
input_data_featuretest[input_data_featuretest.columns[i]].to_numpy()
539-
)
540-
elif 'categorical' in feat_type:
541-
np.testing.assert_array_equal(
542-
transformed_X[:, i],
543-
# Expect always 0, 1... because we use a ordinal encoder
544-
np.array([0, 1])
545-
)
546-
else:
547-
raise ValueError(feat_type)
517+
# we expect this input to be the fixture 'pandas_mixed_nan'
518+
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., -1.], [0., 1., 1.]]))
519+
else:
520+
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]]))
521+
522+
if not all([feat_type in ['numerical', 'categorical'] for feat_type in feature_types]):
523+
raise ValueError("Expected only numerical and categorical feature types")
548524

549525

550526
def test_no_new_category_after_fit():
@@ -576,13 +552,12 @@ def test_unknown_encode_value():
576552
x['c'].cat.add_categories(['NA'], inplace=True)
577553
x.loc[0, 'c'] = 'NA' # unknown value
578554
x_t = validator.transform(x)
579-
# The first row should have a -1 as we added a new categorical there
580-
expected_row = [-1, -41, -3, -987.2]
555+
# The first row should have a 0, 0 as we added a
556+
# new category there and the one-hot encoder marks
557+
# it as all zeros for the transformed column
558+
expected_row = [0.0, 0.0, -0.5584294383572701, 0.5000000000000004, -1.5136598016833485]
581559
assert expected_row == x_t[0].tolist()
582560

583-
# Notice how there is only one column 'c' to encode
584-
assert validator.categories == [list(range(2)) for i in range(1)]
585-
586561

587562
# Actual checks for the features
588563
@pytest.mark.parametrize(
@@ -634,19 +609,20 @@ def test_feature_validator_new_data_after_fit(
634609
assert sparse.issparse(transformed_X)
635610
else:
636611
assert isinstance(transformed_X, np.ndarray)
637-
assert np.shape(X_test) == np.shape(transformed_X)
638612

639613
# And then check proper error messages
640614
if train_data_type == 'pandas':
641615
old_dtypes = copy.deepcopy(validator.dtypes)
642616
validator.dtypes = ['dummy' for dtype in X_train.dtypes]
643-
with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"):
617+
with pytest.raises(ValueError,
618+
match=r"The dtype of the features must not be changed after fit"):
644619
transformed_X = validator.transform(X_test)
645620
validator.dtypes = old_dtypes
646621
if test_data_type == 'pandas':
647622
columns = X_test.columns.tolist()
648623
X_test = X_test[reversed(columns)]
649-
with pytest.raises(ValueError, match=r"Changing the column order of the features"):
624+
with pytest.raises(ValueError,
625+
match=r"The column order of the features must not be changed after fit"):
650626
transformed_X = validator.transform(X_test)
651627

652628

test/test_data/test_validation.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import numpy as np
22

3-
import pandas as pd
4-
53
import pytest
64

75
from scipy import sparse
@@ -32,14 +30,6 @@ def test_data_validation_for_classification(openmlid, as_frame):
3230

3331
validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
3432
X_train_t, y_train_t = validator.transform(X_train, y_train)
35-
assert np.shape(X_train) == np.shape(X_train_t)
36-
37-
# Leave columns that are complete NaN
38-
# The sklearn pipeline will handle that
39-
if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
40-
assert np.any(pd.isnull(X_train_t).values.all(axis=0))
41-
elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
42-
assert np.any(pd.isnull(X_train_t).all(axis=0))
4333

4434
# make sure everything was encoded to number
4535
assert np.issubdtype(X_train_t.dtype, np.number)
@@ -74,14 +64,6 @@ def test_data_validation_for_regression(openmlid, as_frame):
7464
validator.fit(X_train=X_train, y_train=y_train)
7565

7666
X_train_t, y_train_t = validator.transform(X_train, y_train)
77-
assert np.shape(X_train) == np.shape(X_train_t)
78-
79-
# Leave columns that are complete NaN
80-
# The sklearn pipeline will handle that
81-
if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
82-
assert np.any(pd.isnull(X_train_t).values.all(axis=0))
83-
elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
84-
assert np.any(pd.isnull(X_train_t).all(axis=0))
8567

8668
# make sure everything was encoded to number
8769
assert np.issubdtype(X_train_t.dtype, np.number)
@@ -103,8 +85,6 @@ def test_sparse_data_validation_for_regression():
10385
validator.fit(X_train=X_sp, y_train=y)
10486

10587
X_t, y_t = validator.transform(X, y)
106-
assert np.shape(X) == np.shape(X_t)
107-
10888
# make sure everything was encoded to number
10989
assert np.issubdtype(X_t.dtype, np.number)
11090
assert np.issubdtype(y_t.dtype, np.number)

0 commit comments

Comments
 (0)