diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index b765f7c74..c5468eae7 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -307,6 +307,7 @@ def _get_dataset_input_validator( resampling_strategy_args: Optional[Dict[str, Any]] = None, dataset_name: Optional[str] = None, dataset_compression: Optional[DatasetCompressionSpec] = None, + **kwargs: Any ) -> Tuple[BaseDataset, BaseInputValidator]: """ Returns an object of a child class of `BaseDataset` and @@ -353,6 +354,7 @@ def get_dataset( resampling_strategy_args: Optional[Dict[str, Any]] = None, dataset_name: Optional[str] = None, dataset_compression: Optional[DatasetCompressionSpec] = None, + **kwargs: Any ) -> BaseDataset: """ Returns an object of a child class of `BaseDataset` according to the current task. @@ -407,6 +409,10 @@ def get_dataset( Subsampling takes into account classification labels and stratifies accordingly. We guarantee that at least one occurrence of each label is included in the sampled set. + kwargs (Any): + can be used to pass task specific dataset arguments. Currently supports + passing `feat_types` for tabular tasks which specifies whether a feature is + 'numerical' or 'categorical'. 
Returns: BaseDataset: @@ -420,7 +426,8 @@ def get_dataset( resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args, dataset_name=dataset_name, - dataset_compression=dataset_compression) + dataset_compression=dataset_compression, + **kwargs) return dataset diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index 3d80a0338..facb59f99 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -168,6 +168,7 @@ def _get_dataset_input_validator( resampling_strategy_args: Optional[Dict[str, Any]] = None, dataset_name: Optional[str] = None, dataset_compression: Optional[DatasetCompressionSpec] = None, + **kwargs: Any, ) -> Tuple[TabularDataset, TabularInputValidator]: """ Returns an object of `TabularDataset` and an object of @@ -194,6 +195,9 @@ def _get_dataset_input_validator( dataset_compression (Optional[DatasetCompressionSpec]): specifications for dataset compression. For more info check documentation for `BaseTask.get_dataset`. + kwargs (Any): + Currently for tabular tasks, expect `feat_types` (Optional[List[str]]) which + specifies whether a feature is 'numerical' or 'categorical'. 
Returns: TabularDataset: @@ -206,12 +210,14 @@ def _get_dataset_input_validator( resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ self.resampling_strategy_args + feat_types = kwargs.pop('feat_types', None) # Create a validator object to make sure that the data provided by # the user matches the autopytorch requirements input_validator = TabularInputValidator( is_classification=True, logger_port=self._logger_port, - dataset_compression=dataset_compression + dataset_compression=dataset_compression, + feat_types=feat_types ) # Fit a input validator to check the provided data @@ -238,6 +244,7 @@ def search( X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, dataset_name: Optional[str] = None, + feat_types: Optional[List[str]] = None, budget_type: str = 'epochs', min_budget: int = 5, max_budget: int = 50, @@ -266,6 +273,10 @@ def search( A pair of features (X_train) and targets (y_train) used to fit a pipeline. Additionally, a holdout of this pairs (X_test, y_test) can be provided to track the generalization performance of each stage. + feat_types (Optional[List[str]]): + Description about the feature types of the columns. + Accepts `numerical` for integers, float data and `categorical` + for categories, strings and bool. Defaults to None. optimize_metric (str): name of the metric that is used to evaluate a pipeline. 
budget_type (str): @@ -433,7 +444,8 @@ def search( resampling_strategy=self.resampling_strategy, resampling_strategy_args=self.resampling_strategy_args, dataset_name=dataset_name, - dataset_compression=self._dataset_compression) + dataset_compression=self._dataset_compression, + feat_types=feat_types) return self._search( dataset=self.dataset, diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index fa8cf8081..e0c1e4eac 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -169,6 +169,7 @@ def _get_dataset_input_validator( resampling_strategy_args: Optional[Dict[str, Any]] = None, dataset_name: Optional[str] = None, dataset_compression: Optional[DatasetCompressionSpec] = None, + **kwargs: Any ) -> Tuple[TabularDataset, TabularInputValidator]: """ Returns an object of `TabularDataset` and an object of @@ -195,6 +196,9 @@ def _get_dataset_input_validator( dataset_compression (Optional[DatasetCompressionSpec]): specifications for dataset compression. For more info check documentation for `BaseTask.get_dataset`. + kwargs (Any): + Currently for tabular tasks, expect `feat_types` (Optional[List[str]]) which + specifies whether a feature is 'numerical' or 'categorical'. Returns: TabularDataset: the dataset object. 
@@ -206,12 +210,14 @@ def _get_dataset_input_validator( resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ self.resampling_strategy_args + feat_types = kwargs.pop('feat_types', None) # Create a validator object to make sure that the data provided by # the user matches the autopytorch requirements input_validator = TabularInputValidator( is_classification=False, logger_port=self._logger_port, - dataset_compression=dataset_compression + dataset_compression=dataset_compression, + feat_types=feat_types ) # Fit a input validator to check the provided data @@ -238,6 +244,7 @@ def search( X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, dataset_name: Optional[str] = None, + feat_types: Optional[List[str]] = None, budget_type: str = 'epochs', min_budget: int = 5, max_budget: int = 50, @@ -266,6 +273,10 @@ def search( A pair of features (X_train) and targets (y_train) used to fit a pipeline. Additionally, a holdout of this pairs (X_test, y_test) can be provided to track the generalization performance of each stage. + feat_types (Optional[List[str]]): + Description about the feature types of the columns. + Accepts `numerical` for integers, float data and `categorical` + for categories, strings and bool. Defaults to None. optimize_metric (str): Name of the metric that is used to evaluate a pipeline. 
budget_type (str): @@ -434,7 +445,8 @@ def search( resampling_strategy=self.resampling_strategy, resampling_strategy_args=self.resampling_strategy_args, dataset_name=dataset_name, - dataset_compression=self._dataset_compression) + dataset_compression=self._dataset_compression, + feat_types=feat_types) return self._search( dataset=self.dataset, diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index 11c6cf577..2d09c474e 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -35,7 +35,7 @@ def __init__( logger: Optional[Union[PicklableClientLogger, logging.Logger]] = None, ): # Register types to detect unsupported data format changes - self.feat_type: Optional[List[str]] = None + self.feat_types: Optional[List[str]] = None self.data_type: Optional[type] = None self.dtypes: List[str] = [] self.column_order: List[str] = [] diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 8dad37205..fab2471c4 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -94,12 +94,18 @@ class TabularFeatureValidator(BaseFeatureValidator): List of indices of numerical columns categorical_columns (List[int]): List of indices of categorical columns + feat_types (List[str]): + Description about the feature types of the columns. + Accepts `numerical` for integers, float data and `categorical` + for categories, strings and bool. 
""" def __init__( self, logger: Optional[Union[PicklableClientLogger, Logger]] = None, + feat_types: Optional[List[str]] = None, ): super().__init__(logger) + self.feat_types = feat_types @staticmethod def _comparator(cmp1: str, cmp2: str) -> int: @@ -167,9 +173,9 @@ def _fit( if not X.select_dtypes(include='object').empty: X = self.infer_objects(X) - self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) + self.transformed_columns, self.feat_types = self.get_columns_to_encode(X) - assert self.feat_type is not None + assert self.feat_types is not None if len(self.transformed_columns) > 0: @@ -186,8 +192,8 @@ def _fit( # The column transformer reorders the feature types # therefore, we need to change the order of columns as well # This means categorical columns are shifted to the left - self.feat_type = sorted( - self.feat_type, + self.feat_types = sorted( + self.feat_types, key=functools.cmp_to_key(self._comparator) ) @@ -201,7 +207,7 @@ def _fit( for cat in encoded_categories ] - for i, type_ in enumerate(self.feat_type): + for i, type_ in enumerate(self.feat_types): if 'numerical' in type_: self.numerical_columns.append(i) else: @@ -336,7 +342,7 @@ def _check_data( # Define the column to be encoded here as the feature validator is fitted once # per estimator - self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) + self.transformed_columns, self.feat_types = self.get_columns_to_encode(X) column_order = [column for column in X.columns] if len(self.column_order) > 0: @@ -361,12 +367,72 @@ def _check_data( else: self.dtypes = dtypes + def get_columns_to_encode( + self, + X: pd.DataFrame + ) -> Tuple[List[str], List[str]]: + """ + Return the columns to be transformed as well as + the type of feature for each column. + + The returned values are dependent on `feat_types` passed to the `__init__`. 
+ + Args: + X (pd.DataFrame) + A set of features that are going to be validated (type and dimensionality + checks) and an encoder fitted in the case the data needs encoding + + Returns: + transformed_columns (List[str]): + Columns to encode, if any + feat_type: + Type of each column numerical/categorical + """ + transformed_columns, feat_types = self._get_columns_to_encode(X) + if self.feat_types is not None: + self._validate_feat_types(X) + transformed_columns = [X.columns[i] for i, col in enumerate(self.feat_types) + if col.lower() == 'categorical'] + return transformed_columns, self.feat_types + else: + return transformed_columns, feat_types + + def _validate_feat_types(self, X: pd.DataFrame) -> None: + """ + Checks if the passed `feat_types` is compatible with what + AutoPyTorch expects, i.e, it should only contain `numerical` + or `categorical` and the number of feature types is equal to + the number of features. The case does not matter. + + Args: + X (pd.DataFrame): + input features set + + Raises: + ValueError: + if the number of feat_types is not equal to the number of features + if the feature type are not one of "numerical", "categorical" + """ + assert self.feat_types is not None # mypy check + + if len(self.feat_types) != len(X.columns): + raise ValueError(f"Expected number of `feat_types`: {len(self.feat_types)}" + f" to be the same as the number of features {len(X.columns)}") + for feat_type in set(self.feat_types): + if feat_type.lower() not in ['numerical', 'categorical']: + raise ValueError(f"Expected type of features to be in `['numerical', " + f"'categorical']`, but got {feat_type}") + def _get_columns_to_encode( self, X: pd.DataFrame, ) -> Tuple[List[str], List[str]]: """ - Return the columns to be encoded from a pandas dataframe + Return the columns to be transformed as well as + the type of feature for each column from a pandas dataframe. 
+ If `self.feat_types` is not None, it also validates that the + dataframe dtypes don't disagree with the ones passed in `__init__`. Args: X (pd.DataFrame) @@ -380,21 +446,24 @@ def _get_columns_to_encode( Type of each column numerical/categorical """ - if len(self.transformed_columns) > 0 and self.feat_type is not None: - return self.transformed_columns, self.feat_type + if len(self.transformed_columns) > 0 and self.feat_types is not None: + return self.transformed_columns, self.feat_types # Register if a column needs encoding transformed_columns = [] # Also, register the feature types for the estimator - feat_type = [] + feat_types = [] # Make sure each column is a valid type for i, column in enumerate(X.columns): if X[column].dtype.name in ['category', 'bool']: transformed_columns.append(column) - feat_type.append('categorical') + if self.feat_types is not None and self.feat_types[i].lower() == 'numerical': + raise ValueError(f"Passed numerical as the feature type for column: {column} " + f"but the column is categorical") + feat_types.append('categorical') # Move away from np.issubdtype as it causes # TypeError: data type not understood in certain pandas types elif not is_numeric_dtype(X[column]): @@ -434,8 +503,8 @@ def _get_columns_to_encode( ) ) else: - feat_type.append('numerical') - return transformed_columns, feat_type + feat_types.append('numerical') + return transformed_columns, feat_types def list_to_dataframe( self, diff --git a/autoPyTorch/data/tabular_validator.py b/autoPyTorch/data/tabular_validator.py index 492327fbe..0f6f89e1c 100644 --- a/autoPyTorch/data/tabular_validator.py +++ b/autoPyTorch/data/tabular_validator.py @@ -1,6 +1,6 @@ # -*- encoding: utf-8 -*- import logging -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np @@ -41,18 +41,24 @@ class TabularInputValidator(BaseInputValidator): dataset_compression (Optional[DatasetCompressionSpec]): specifications for dataset compression. 
For more info check documentation for `BaseTask.get_dataset`. + feat_types (List[str]): + Description about the feature types of the columns. + Accepts `numerical` for integers, float data and `categorical` + for categories, strings and bool """ def __init__( self, is_classification: bool = False, logger_port: Optional[int] = None, dataset_compression: Optional[DatasetCompressionSpec] = None, + feat_types: Optional[List[str]] = None, seed: int = 42, ): self.dataset_compression = dataset_compression self._reduced_dtype: Optional[DatasetDTypeContainerType] = None self.is_classification = is_classification self.logger_port = logger_port + self.feat_types = feat_types self.seed = seed if self.logger_port is not None: self.logger: Union[logging.Logger, PicklableClientLogger] = get_named_client_logger( @@ -63,7 +69,8 @@ def __init__( self.logger = logging.getLogger('Validation') self.feature_validator = TabularFeatureValidator( - logger=self.logger) + logger=self.logger, + feat_types=self.feat_types) self.target_validator = TabularTargetValidator( is_classification=self.is_classification, logger=self.logger diff --git a/examples/40_advanced/example_pass_feature_types.py b/examples/40_advanced/example_pass_feature_types.py new file mode 100644 index 000000000..658796a28 --- /dev/null +++ b/examples/40_advanced/example_pass_feature_types.py @@ -0,0 +1,93 @@ +""" +===================================================== +Tabular Classification with user passed feature types +===================================================== + +The following example shows how to pass feature types for datasets which are in +numpy format (also works for dataframes and lists) and fit a sample classification +model with AutoPyTorch. + +AutoPyTorch relies on column dtypes for interpreting the feature types. But they +can be misinterpreted; for example, when a dataset is passed as a numpy array, all +the data is interpreted as numerical if its dtype is int or float. 
However, the +categorical values could have been encoded as integers. + +Passing feature types helps AutoPyTorch interpreting them correctly as well as +validates the dataset by checking the dtype of the columns for any incompatibilities. +""" +import os +import tempfile as tmp +import warnings + +os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir() +os.environ['OMP_NUM_THREADS'] = '1' +os.environ['OPENBLAS_NUM_THREADS'] = '1' +os.environ['MKL_NUM_THREADS'] = '1' + +warnings.simplefilter(action='ignore', category=UserWarning) +warnings.simplefilter(action='ignore', category=FutureWarning) + +import openml +import sklearn.model_selection + +from autoPyTorch.api.tabular_classification import TabularClassificationTask + + +############################################################################ +# Data Loading +# ============ +task = openml.tasks.get_task(task_id=146821) +dataset = task.get_dataset() +X, y, categorical_indicator, _ = dataset.get_data( + dataset_format='array', + target=dataset.default_target_attribute, +) +X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, + y, + random_state=1, +) + +feat_types = ["numerical" if not indicator else "categorical" for indicator in categorical_indicator] + +# +############################################################################ +# Build and fit a classifier +# ========================== +api = TabularClassificationTask( + # To maintain logs of the run, you can uncomment the + # Following lines + # temporary_directory='./tmp/autoPyTorch_example_tmp_01', + # output_directory='./tmp/autoPyTorch_example_out_01', + # delete_tmp_folder_after_terminate=False, + # delete_output_folder_after_terminate=False, + seed=42, +) + +############################################################################ +# Search for an ensemble of machine learning algorithms +# ===================================================== +api.search( + X_train=X_train, + y_train=y_train, + X_test=X_test.copy(), + 
y_test=y_test.copy(), + dataset_name='Australian', + optimize_metric='accuracy', + total_walltime_limit=100, + func_eval_time_limit_secs=50, + feat_types=feat_types, + enable_traditional_pipeline=False +) + +############################################################################ +# Print the final ensemble performance +# ==================================== +y_pred = api.predict(X_test) +score = api.score(y_pred, y_test) +print(score) +# Print the final ensemble built by AutoPyTorch +print(api.show_models()) + +# Print statistics from search +print(api.sprint_statistics()) diff --git a/setup.py b/setup.py index 4198ff6b5..f040882ca 100755 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ "jupyter", "notebook", "seaborn", + "openml" ], "docs": ["sphinx", "sphinx-gallery", "sphinx_bootstrap_theme", "numpydoc"], }, diff --git a/test/conftest.py b/test/conftest.py index 2bc292fff..2cf976d7a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -59,6 +59,10 @@ def callattr_ahead_of_alltests(request): 4871, # sensory 4857, # boston 3916, # kc1 + 2295, # cholesterol + 3916, # kc1-binary + 293554, # reuters + 294846 # rf1 ] # Populate the cache diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index 2daa271b7..08da7d7fd 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -259,12 +259,12 @@ def test_column_transformer_created(input_data_featuretest): transformed_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest) # At least one categorical - assert 'categorical' in validator.feat_type + assert 'categorical' in validator.feat_types # Numerical if the original data has numerical only columns if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col] ) for col in input_data_featuretest.columns]): - assert 'numerical' in validator.feat_type + assert 'numerical' in validator.feat_types for i, feat_type in enumerate(feature_types): if 
'numerical' in feat_type: np.testing.assert_array_equal( @@ -406,3 +406,123 @@ def test_comparator(): key=functools.cmp_to_key(validator._comparator) ) assert ans == feat_type + + +@pytest.fixture +def input_data_feature_feat_types(request): + if request.param == 'pandas_categoricalonly': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category'), ['categorical', 'categorical'] + elif request.param == 'pandas_numericalonly': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='float'), ['numerical', 'numerical'] + elif request.param == 'pandas_mixed': + frame = pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category') + frame['B'] = pd.to_numeric(frame['B']) + return frame, ['categorical', 'numerical'] + elif request.param == 'pandas_string_error': + frame = pd.DataFrame([ + {'A': 1, 'B': '2'}, + {'A': 3, 'B': '4'}, + ], dtype='category') + return frame, ['categorical', 'numerical'] + elif request.param == 'pandas_length_error': + frame = pd.DataFrame([ + {'A': 1, 'B': '2'}, + {'A': 3, 'B': '4'}, + ], dtype='category') + return frame, ['categorical', 'categorical', 'numerical'] + elif request.param == 'pandas_feat_type_error': + frame = pd.DataFrame([ + {'A': 1, 'B': '2'}, + {'A': 3, 'B': '4'}, + ], dtype='category') + return frame, ['not_categorical', 'numerical'] + else: + ValueError("Unsupported indirect fixture {}".format(request.param)) + + +@pytest.mark.parametrize( + 'input_data_feature_feat_types', + ( + 'pandas_categoricalonly', + 'pandas_numericalonly', + 'pandas_mixed', + ), + indirect=True +) +def test_feature_validator_get_columns_to_encode(input_data_feature_feat_types): + X, feat_types = input_data_feature_feat_types + validator = TabularFeatureValidator(feat_types=feat_types) + transformed_columns, val_feat_types = validator.get_columns_to_encode(X) + + assert feat_types == val_feat_types + + for feat_type, col in zip(X.columns, val_feat_types): + if feat_type.lower() == 
'categorical': + assert col in transformed_columns + + +@pytest.mark.parametrize( + 'input_data_feature_feat_types', + ( + 'pandas_string_error', + ), + indirect=True +) +def test_feature_validator_get_columns_to_encode_error_string(input_data_feature_feat_types): + """ + Tests the correct error is raised when feat types passed to + the validator disagree with the column dtypes. + + """ + X, feat_types = input_data_feature_feat_types + validator = TabularFeatureValidator(feat_types=feat_types) + with pytest.raises(ValueError, match=r"Passed numerical as the feature type for column: B but " + r"the column is categorical"): + validator.get_columns_to_encode(X) + + +@pytest.mark.parametrize( + 'input_data_feature_feat_types', + ( + 'pandas_length_error', + ), + indirect=True +) +def test_feature_validator_get_columns_to_encode_error_length(input_data_feature_feat_types): + """ + Tests the correct error is raised when the length of feat types passed to + the validator is not the same as the number of features + + """ + X, feat_types = input_data_feature_feat_types + validator = TabularFeatureValidator(feat_types=feat_types) + with pytest.raises(ValueError, match=r"Expected number of `feat_types`: .*"): + validator._validate_feat_types(X) + + +@pytest.mark.parametrize( + 'input_data_feature_feat_types', + ( + 'pandas_feat_type_error', + ), + indirect=True +) +def test_feature_validator_get_columns_to_encode_error_feat_type(input_data_feature_feat_types): + """ + Tests the correct error is raised when a feat type passed to + the validator is not one of 'numerical' or 'categorical' + + """ + X, feat_types = input_data_feature_feat_types + validator = TabularFeatureValidator(feat_types=feat_types) + with pytest.raises(ValueError, match=r"Expected type of features to be in .*"): + validator._validate_feat_types(X) diff --git a/test/test_data/test_target_validator.py b/test/test_data/test_target_validator.py index 8fd4527d9..3866bfb79 100644 ---
a/test/test_data/test_target_validator.py +++ b/test/test_data/test_target_validator.py @@ -126,7 +126,7 @@ def input_data_targettest(request): 'sparse_csc_nonan', 'sparse_csr_nonan', 'sparse_lil_nonan', - 'openml_204', + 'openml_204', # openml cholesterol dataset ), indirect=True ) @@ -182,7 +182,7 @@ def test_targetvalidator_supported_types_noclassification(input_data_targettest) 'sparse_csc_nonan', 'sparse_csr_nonan', 'sparse_lil_nonan', - 'openml_2', + 'openml_2', # anneal dataset ), indirect=True ) @@ -246,7 +246,7 @@ def test_targetvalidator_supported_types_classification(input_data_targettest): 'pandas_binary', 'numpy_binary', 'list_binary', - 'openml_1066', + 'openml_1066', # kc1-binary dataset ), indirect=True ) @@ -266,7 +266,7 @@ def test_targetvalidator_binary(input_data_targettest): 'pandas_multiclass', 'numpy_multiclass', 'list_multiclass', - 'openml_54', + 'openml_54', # vehicle dataset ), indirect=True ) @@ -285,7 +285,7 @@ def test_targetvalidator_multiclass(input_data_targettest): 'pandas_multilabel', 'numpy_multilabel', 'list_multilabel', - 'openml_40594', + 'openml_40594', # reuters dataset ), indirect=True ) @@ -305,7 +305,7 @@ def test_targetvalidator_multilabel(input_data_targettest): 'pandas_continuous', 'numpy_continuous', 'list_continuous', - 'openml_531', + 'openml_531', # boston dataset ), indirect=True ) @@ -324,7 +324,7 @@ def test_targetvalidator_continuous(input_data_targettest): 'pandas_continuous-multioutput', 'numpy_continuous-multioutput', 'list_continuous-multioutput', - 'openml_41483', + 'openml_41483', # rf1 dataset ), indirect=True ) diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index f7755e35e..ba60a1760 100644 --- a/test/test_data/test_validation.py +++ b/test/test_data/test_validation.py @@ -49,8 +49,8 @@ def test_data_validation_for_classification(openmlid, as_frame): # Categorical columns are sorted to the beginning if as_frame: - validator.feature_validator.feat_type is not None - 
ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_type)) + validator.feature_validator.feat_types is not None + ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_types)) if len(ordered_unique_elements) > 1: assert ordered_unique_elements[0] == 'categorical' @@ -91,8 +91,8 @@ def test_data_validation_for_regression(openmlid, as_frame): # Categorical columns are sorted to the beginning if as_frame: - validator.feature_validator.feat_type is not None - ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_type)) + validator.feature_validator.feat_types is not None + ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_types)) if len(ordered_unique_elements) > 1: assert ordered_unique_elements[0] == 'categorical'