From e9cb8ac06947dd451e60c377f26febe0a0ab3c9c Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Thu, 25 Feb 2021 17:41:39 +0100
Subject: [PATCH 1/5] handle nans in categorical columns

---
 autoPyTorch/data/tabular_feature_validator.py | 28 ++++++++++---------
 test/test_data/test_feature_validator.py      | 10 +++++--
 2 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index fb9a72082..9be8de97c 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -53,10 +53,25 @@ def _fit(
         for column in X.columns:
             if X[column].isna().all():
                 X[column] = pd.to_numeric(X[column])
+                # Also note this change in self.dtypes
+                self.dtypes[list(X.columns).index(column)] = X[column].dtype
         self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
         if len(self.enc_columns) > 0:
+            # impute missing values before encoding,
+            # remove once sklearn natively supports
+            # it in ordinal encoding. Sklearn issue:
+            # "https://github.com/scikit-learn/scikit-learn/issues/17123)"
+            for column in self.enc_columns:
+                if X[column].isna().any():
+                    missing_value = -1
+                    # make sure for a string column we give
+                    # string missing value else we give numeric
+                    if type(X[column][0]) == str:
+                        missing_value = str(missing_value)
+                    X[column] = X[column].cat.add_categories([missing_value])
+                    X[column] = X[column].fillna(missing_value)
             self.encoder = make_column_transformer(
                 (preprocessing.OrdinalEncoder(
                     handle_unknown='use_encoded_value',
                     unknown_value=-1,
                 ), self.enc_columns),
@@ -217,19 +232,6 @@ def _check_data(
             # per estimator
             enc_columns, _ = self._get_columns_to_encode(X)
-            if len(enc_columns) > 0:
-                if np.any(pd.isnull(
-                        X[enc_columns].dropna(  # type: ignore[call-overload]
-                            axis='columns', how='all')
-                )):
-                    # Ignore all NaN columns, and if still a NaN
-                    # Error out
-                    raise ValueError("Categorical features in a dataframe cannot contain "
-                                     "missing/NaN values. The OrdinalEncoder used by "
-                                     "AutoPyTorch cannot handle this yet (due to a "
-                                     "limitation on scikit-learn being addressed via: "
-                                     "https://github.com/scikit-learn/scikit-learn/issues/17123)"
-                                     )
             column_order = [column for column in X.columns]
             if len(self.column_order) > 0:
                 if self.column_order != column_order:
diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py
index afa2b43e1..14a50b3a3 100644
--- a/test/test_data/test_feature_validator.py
+++ b/test/test_data/test_feature_validator.py
@@ -231,10 +231,14 @@ def test_featurevalidator_unsupported_numpy(input_data_featuretest):
     ),
     indirect=True
 )
-def test_featurevalidator_unsupported_pandas(input_data_featuretest):
+def test_featurevalidator_categorical_nan(input_data_featuretest):
     validator = TabularFeatureValidator()
-    with pytest.raises(ValueError, match=r"Categorical features in a dataframe.*missing/NaN"):
-        validator.fit(input_data_featuretest)
+    validator.fit(input_data_featuretest)
+    transformed_X = validator.transform(input_data_featuretest)
+    assert np.shape(input_data_featuretest) == np.shape(transformed_X)
+    assert np.issubdtype(transformed_X.dtype, np.number)
+    assert validator._is_fitted
+    assert isinstance(transformed_X, np.ndarray)

 @pytest.mark.parametrize(

From 600f554e4da57d56922825a62e162edd8d2f8fd1 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Fri, 26 Feb 2021 13:39:52 +0100
Subject: [PATCH 2/5] Fixed error in self dtypes

---
 autoPyTorch/data/tabular_feature_validator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index 9be8de97c..70719c31c 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -54,7 +54,8 @@ def _fit(
             if X[column].isna().all():
                 X[column] = pd.to_numeric(X[column])
                 # Also note this change in self.dtypes
-                self.dtypes[list(X.columns).index(column)] = X[column].dtype
+                if len(self.dtypes) != 0:
+                    self.dtypes[list(X.columns).index(column)] = X[column].dtype
         self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
@@ -65,7 +66,7 @@ def _fit(
             # "https://github.com/scikit-learn/scikit-learn/issues/17123)"
             for column in self.enc_columns:
                 if X[column].isna().any():
-                    missing_value = -1
+                    missing_value: typing.Union[int, str] = -1
                     # make sure for a string column we give
                     # string missing value else we give numeric
                     if type(X[column][0]) == str:

From 735f38e55a40d601a6afe7975393adfb2b73ec6e Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 1 Mar 2021 15:01:18 +0100
Subject: [PATCH 3/5] Addressed comments from francisco

---
 test/test_data/test_feature_validator.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py
index 14a50b3a3..6d90ef2f9 100644
--- a/test/test_data/test_feature_validator.py
+++ b/test/test_data/test_feature_validator.py
@@ -235,6 +235,9 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
     validator = TabularFeatureValidator()
     validator.fit(input_data_featuretest)
     transformed_X = validator.transform(input_data_featuretest)
+    assert any(pd.isna(input_data_featuretest))
+    assert any((-1 in categories) or ('-1' in categories) for categories in
+               validator.encoder.named_transformers_['encoder'].categories_)
     assert np.shape(input_data_featuretest) == np.shape(transformed_X)
     assert np.issubdtype(transformed_X.dtype, np.number)
     assert validator._is_fitted

From 67b20c7a566f2ad737b058484498f996be612dea Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 1 Mar 2021 15:04:06 +0100
Subject: [PATCH 4/5] Forgot to commit

---
 autoPyTorch/data/tabular_feature_validator.py | 42 ++++++++++---------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index 70719c31c..0a5d67475 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -11,7 +11,7 @@
 import sklearn.utils
 from sklearn import preprocessing
 from sklearn.base import BaseEstimator
-from sklearn.compose import make_column_transformer
+from sklearn.compose import ColumnTransformer
 from sklearn.exceptions import NotFittedError
 from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES
@@ -74,11 +74,13 @@ def _fit(
                     X[column] = X[column].cat.add_categories([missing_value])
                     X[column] = X[column].fillna(missing_value)
-            self.encoder = make_column_transformer(
-                (preprocessing.OrdinalEncoder(
-                    handle_unknown='use_encoded_value',
-                    unknown_value=-1,
-                ), self.enc_columns),
+            self.encoder = ColumnTransformer(
+                [
+                    ("encoder",
+                     preprocessing.OrdinalEncoder(
+                         handle_unknown='use_encoded_value',
+                         unknown_value=-1,
+                     ), self.enc_columns)],
                 remainder="passthrough"
             )
@@ -101,6 +103,7 @@ def comparator(cmp1: str, cmp2: str) -> int:
                     return 1
                 else:
                     raise ValueError((cmp1, cmp2))
+
             self.feat_type = sorted(
                 self.feat_type,
                 key=functools.cmp_to_key(comparator)
@@ -199,8 +202,8 @@ def _check_data(
             raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
                              " scipy sparse and Python Lists, yet, the provided input is"
                              " of type {}".format(
-                type(X)
-            ))
+                                 type(X)
+                             ))
         if self.data_type is None:
             self.data_type = type(X)
@@ -239,9 +242,9 @@ def _check_data(
                 raise ValueError("Changing the column order of the features after fit() is "
                                  "not supported. Fit() method was called with "
                                  "{} whereas the new features have {} as type".format(
-                    self.column_order,
-                    column_order,
-                ))
+                                     self.column_order,
+                                     column_order,
+                                 ))
             else:
                 self.column_order = column_order
         dtypes = [dtype.name for dtype in X.dtypes]
@@ -250,9 +253,9 @@ def _check_data(
                 raise ValueError("Changing the dtype of the features after fit() is "
                                  "not supported. Fit() method was called with "
                                  "{} whereas the new features have {} as type".format(
-                    self.dtypes,
-                    dtypes,
-                ))
+                                     self.dtypes,
+                                     dtypes,
+                                 ))
             else:
                 self.dtypes = dtypes
@@ -297,7 +300,8 @@ def _get_columns_to_encode(
                     "pandas.Series.astype ."
                     "If working with string objects, the following "
                     "tutorial illustrates how to work with text data: "
-                    "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format(  # noqa: E501
+                    "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format(
+                        # noqa: E501
                         column,
                     )
                 )
@@ -353,14 +357,14 @@ def list_to_dataframe(
             X_train = pd.DataFrame(data=X_train).infer_objects()
             self.logger.warning("The provided feature types to AutoPyTorch are of type list."
                                 "Features have been interpreted as: {}".format(
-                [(col, t) for col, t in zip(X_train.columns, X_train.dtypes)]
-            ))
+                                    [(col, t) for col, t in zip(X_train.columns, X_train.dtypes)]
+                                ))
         if X_test is not None:
             if not isinstance(X_test, list):
                 self.logger.warning("Train features are a list while the provided test data"
                                     "is {}. X_test will be casted as DataFrame.".format(
-                    type(X_test)
-                ))
+                                        type(X_test)
+                                    ))
                 X_test = pd.DataFrame(data=X_test).infer_objects()
         return X_train, X_test

From 2de6c0a49c871a3dda2770d73a7d30a08fad422d Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 1 Mar 2021 15:11:48 +0100
Subject: [PATCH 5/5] Fix flake

---
 autoPyTorch/data/tabular_feature_validator.py | 30 ++++++++++-------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index 0a5d67475..e73b66bb1 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -201,9 +201,8 @@ def _check_data(
         if not isinstance(X, (np.ndarray, pd.DataFrame)) and not scipy.sparse.issparse(X):
             raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
                              " scipy sparse and Python Lists, yet, the provided input is"
-                             " of type {}".format(
-                                 type(X)
-                             ))
+                             " of type {}".format(type(X))
+                             )
         if self.data_type is None:
             self.data_type = type(X)
@@ -241,10 +240,9 @@ def _check_data(
             if self.column_order != column_order:
                 raise ValueError("Changing the column order of the features after fit() is "
                                  "not supported. Fit() method was called with "
-                                 "{} whereas the new features have {} as type".format(
-                                     self.column_order,
-                                     column_order,
-                                 ))
+                                 "{} whereas the new features have {} as type".format(self.column_order,
+                                                                                      column_order,)
+                                 )
             else:
                 self.column_order = column_order
         dtypes = [dtype.name for dtype in X.dtypes]
@@ -252,10 +250,10 @@ def _check_data(
             if self.dtypes != dtypes:
                 raise ValueError("Changing the dtype of the features after fit() is "
                                  "not supported. Fit() method was called with "
-                                 "{} whereas the new features have {} as type".format(
-                                     self.dtypes,
-                                     dtypes,
-                                 ))
+                                 "{} whereas the new features have {} as type".format(self.dtypes,
+                                                                                      dtypes,
+                                                                                      )
+                                 )
             else:
                 self.dtypes = dtypes
@@ -356,15 +354,13 @@ def list_to_dataframe(
             # If a list was provided, it will be converted to pandas
             X_train = pd.DataFrame(data=X_train).infer_objects()
             self.logger.warning("The provided feature types to AutoPyTorch are of type list."
-                                "Features have been interpreted as: {}".format(
-                                    [(col, t) for col, t in zip(X_train.columns, X_train.dtypes)]
-                                ))
+                                "Features have been interpreted as: {}".format([(col, t) for col, t in
+                                                                                zip(X_train.columns, X_train.dtypes)]))
         if X_test is not None:
             if not isinstance(X_test, list):
                 self.logger.warning("Train features are a list while the provided test data"
-                                    "is {}. X_test will be casted as DataFrame.".format(
-                                        type(X_test)
-                                    ))
+                                    "is {}. X_test will be casted as DataFrame.".format(type(X_test))
+                                    )
                 X_test = pd.DataFrame(data=X_test).infer_objects()
         return X_train, X_test
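
For reference, below is a minimal standalone sketch of the idea the series implements inside
TabularFeatureValidator._fit(): add a sentinel category to categorical columns that contain NaN,
fill the NaN with it, and then ordinal-encode through a named ColumnTransformer so the fitted
categories stay reachable by name. The DataFrame, column names and sentinel choice here are
illustrative assumptions for the sketch, not code from this PR.

    # Illustrative sketch only; the data, column names and sentinel value are made up.
    import numpy as np
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import OrdinalEncoder

    X = pd.DataFrame({
        'colour': pd.Series(['red', 'blue', np.nan, 'red'], dtype='category'),
        'size': [1.0, 2.5, 3.0, 4.0],
    })
    enc_columns = ['colour']

    for column in enc_columns:
        if X[column].isna().any():
            # string sentinel for string categories, numeric sentinel otherwise
            missing_value = '-1' if isinstance(X[column].dropna().iloc[0], str) else -1
            X[column] = X[column].cat.add_categories([missing_value]).fillna(missing_value)

    encoder = ColumnTransformer(
        [('encoder',
          OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1),
          enc_columns)],
        remainder='passthrough',
    )
    X_t = encoder.fit_transform(X)
    # the sentinel shows up among the fitted categories, which is what the new
    # test asserts via named_transformers_['encoder'].categories_
    print(encoder.named_transformers_['encoder'].categories_)
    print(X_t)

The sentinel imputation is an explicit stop-gap until OrdinalEncoder handles missing values
natively (the scikit-learn issue referenced in the diff), and switching from
make_column_transformer to an explicitly named ColumnTransformer is what lets the new test reach
the fitted encoder as validator.encoder.named_transformers_['encoder'].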