From 623490e7aaeef93668ede45ec7aabd3abd579d20 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Fri, 23 Sep 2022 16:32:28 +0200 Subject: [PATCH 1/5] fixed cut mix --- .../training/trainer/RowCutMixTrainer.py | 18 ++++++++++-------- .../training/trainer/RowCutOutTrainer.py | 16 +++++++++------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py index bb4ccdb9a..c02bf133a 100644 --- a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py @@ -37,17 +37,19 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, if beta <= 0 or r > self.alpha: return X, {'y_a': y, 'y_b': y[shuffled_indices], 'lam': 1} - cut_column_indices = torch.as_tensor( - self.random_state.choice( - range(n_columns), - max(1, np.int32(n_columns * lam)), - replace=False, - ), - ) + for i, idx in enumerate(shuffled_indices): + cut_column_indices = torch.as_tensor( + self.random_state.choice( + range(n_columns), + max(1, np.int32(n_columns * lam)), + replace=False, + ), + ) + X[i, cut_column_indices] = X[idx, cut_column_indices] # Replace the values in `cut_indices` columns with # the values from `permed_indices` - X[:, cut_column_indices] = X[shuffled_indices, :][:, cut_column_indices] + # X[:, cut_column_indices] = X[shuffled_indices, :][:, cut_column_indices] # Since we cannot cut exactly `lam x 100 %` of rows, we need to adjust the `lam` lam = 1 - (len(cut_column_indices) / n_columns) diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py index 7b679976e..13511a96f 100644 --- a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py @@ -39,15 +39,17 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, lam = 1 return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam} - size: int = np.shape(X)[1] - cut_column_indices = self.random_state.choice( - range(size), - max(1, np.int32(size * self.patch_ratio)), - replace=False, - ) + n_rows, size = np.shape(X) + for i in range(n_rows): + cut_column_indices = self.random_state.choice( + range(size), + max(1, np.int32(size * self.patch_ratio)), + replace=False, + ) + X[i, cut_column_indices] = 0 + # Mask the selected features as 0 - X[:, cut_column_indices] = 0 lam = 1 y_a = y y_b = y From da419ddea6e27ad6b35610225da8b641a65da5cf Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Fri, 23 Sep 2022 17:27:39 +0200 Subject: [PATCH 2/5] remove unnecessary comment --- .../components/training/trainer/RowCutMixTrainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py index c02bf133a..149d3bd9a 100644 --- a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py @@ -37,6 +37,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, if beta <= 0 or r > self.alpha: return X, {'y_a': y, 'y_b': y[shuffled_indices], 'lam': 1} + + # Replace the values in `cut_indices` columns with + # the values from `permed_indices` for i, idx in enumerate(shuffled_indices): cut_column_indices = torch.as_tensor( self.random_state.choice( @@ 
-47,10 +50,6 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, ) X[i, cut_column_indices] = X[idx, cut_column_indices] - # Replace the values in `cut_indices` columns with - # the values from `permed_indices` - # X[:, cut_column_indices] = X[shuffled_indices, :][:, cut_column_indices] - # Since we cannot cut exactly `lam x 100 %` of rows, we need to adjust the `lam` lam = 1 - (len(cut_column_indices) / n_columns) From 4e31ae861c99a323d2927db7dc3b566a5e571c9e Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Fri, 23 Sep 2022 17:29:13 +0200 Subject: [PATCH 3/5] change all_supported_metrics --- autoPyTorch/api/base_task.py | 4 ++-- autoPyTorch/api/tabular_classification.py | 4 ++-- autoPyTorch/api/tabular_regression.py | 4 ++-- autoPyTorch/api/time_series_forecasting.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 303cefc4e..8618731f5 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -978,7 +978,7 @@ def _search( smac_scenario_args: Optional[Dict[str, Any]] = None, get_smac_object_callback: Optional[Callable] = None, tae_func: Optional[Callable] = None, - all_supported_metrics: bool = True, + all_supported_metrics: bool = False, precision: int = 32, disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, load_models: bool = True, @@ -1076,7 +1076,7 @@ def _search( TargetAlgorithm to be optimised. If None, `eval_function` available in autoPyTorch/evaluation/train_evaluator is used. Must be child class of AbstractEvaluator. - all_supported_metrics (bool: default=True): + all_supported_metrics (bool: default=False): If True, all metrics supporting current task will be calculated for each pipeline and results will be available via cv_results precision (int: default=32): diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index facb59f99..aa6796ae2 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -254,7 +254,7 @@ def search( memory_limit: int = 4096, smac_scenario_args: Optional[Dict[str, Any]] = None, get_smac_object_callback: Optional[Callable] = None, - all_supported_metrics: bool = True, + all_supported_metrics: bool = False, precision: int = 32, disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, load_models: bool = True, @@ -354,7 +354,7 @@ def search( TargetAlgorithm to be optimised. If None, `eval_function` available in autoPyTorch/evaluation/train_evaluator is used. Must be child class of AbstractEvaluator. - all_supported_metrics (bool: default=True): + all_supported_metrics (bool: default=False): If True, all metrics supporting current task will be calculated for each pipeline and results will be available via cv_results precision (int: default=32): diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index faf36097a..d6c30aa3a 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -253,7 +253,7 @@ def search( memory_limit: int = 4096, smac_scenario_args: Optional[Dict[str, Any]] = None, get_smac_object_callback: Optional[Callable] = None, - all_supported_metrics: bool = True, + all_supported_metrics: bool = False, precision: int = 32, disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, load_models: bool = True, @@ -353,7 +353,7 @@ def search( TargetAlgorithm to be optimised. 
If None, `eval_function` available in autoPyTorch/evaluation/train_evaluator is used. Must be child class of AbstractEvaluator. - all_supported_metrics (bool: default=True): + all_supported_metrics (bool: default=False): If True, all metrics supporting current task will be calculated for each pipeline and results will be available via cv_results precision (int: default=32): diff --git a/autoPyTorch/api/time_series_forecasting.py b/autoPyTorch/api/time_series_forecasting.py index 67f6e5eaa..d564f8f47 100644 --- a/autoPyTorch/api/time_series_forecasting.py +++ b/autoPyTorch/api/time_series_forecasting.py @@ -289,7 +289,7 @@ def search( memory_limit: Optional[int] = 4096, smac_scenario_args: Optional[Dict[str, Any]] = None, get_smac_object_callback: Optional[Callable] = None, - all_supported_metrics: bool = True, + all_supported_metrics: bool = False, precision: int = 32, disable_file_output: List = [], load_models: bool = True, @@ -396,7 +396,7 @@ def search( instances, num_params, runhistory, seed and ta. This is an advanced feature. Use only if you are familiar with [SMAC](https://automl.github.io/SMAC3/master/index.html). - all_supported_metrics (bool), (default=True): if True, all + all_supported_metrics (bool), (default=False): if True, all metrics supporting current task will be calculated for each pipeline and results will be available via cv_results precision (int), (default=32): Numeric precision used when loading From 50d34c80ab6f8a695e5e1cf32623a89bdb9290d9 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Mon, 17 Oct 2022 14:23:31 +0200 Subject: [PATCH 4/5] fix roc_auc for multiclass --- autoPyTorch/api/tabular_regression.py | 12 ++++++++++++ .../components/setup/network/base_network.py | 5 ++--- .../pipeline/components/training/metrics/base.py | 2 +- .../pipeline/components/training/metrics/metrics.py | 2 +- .../pipeline/components/training/metrics/utils.py | 4 ++-- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index d6c30aa3a..d6b14cd9b 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -483,3 +483,15 @@ def predict( # Allow to predict in the original domain -- that is, the user is not interested # in our encoded values return self.input_validator.target_validator.inverse_transform(predicted_values) + + def score( + self, + y_pred: np.ndarray, + y_test: Union[np.ndarray, pd.DataFrame] + ) -> Dict[str, float]: + if self.input_validator is None or not self.input_validator._is_fitted: + raise ValueError("predict() is only supported after calling search. 
Kindly call first " + "the estimator search() method.") + y_pred = self.input_validator.target_validator.transform(y_pred) + y_test = self.input_validator.target_validator.transform(y_test) + return super().score(y_pred=y_pred, y_test=y_test) \ No newline at end of file diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index 7ec872b96..0d4d3b34d 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -56,15 +56,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent: self.network = torch.nn.Sequential(X['network_embedding'], X['network_backbone'], X['network_head']) + if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS: + self.network = torch.nn.Sequential(self.network, nn.Softmax(dim=1)) # Properly set the network training device if self.device is None: self.device = get_device_from_fit_dictionary(X) self.to(self.device) - if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS: - self.final_activation = nn.Softmax(dim=1) - self.is_fitted_ = True return self diff --git a/autoPyTorch/pipeline/components/training/metrics/base.py b/autoPyTorch/pipeline/components/training/metrics/base.py index 0cac3c560..4f9037cd8 100644 --- a/autoPyTorch/pipeline/components/training/metrics/base.py +++ b/autoPyTorch/pipeline/components/training/metrics/base.py @@ -173,7 +173,7 @@ def __call__( Score function applied to prediction of estimator on X. """ y_type = type_of_target(y_true) - if y_type not in ("binary", "multilabel-indicator"): + if y_type not in ("binary", "multilabel-indicator") and self.name != 'roc_auc': raise ValueError("{0} format is not supported".format(y_type)) if y_type == "binary": diff --git a/autoPyTorch/pipeline/components/training/metrics/metrics.py b/autoPyTorch/pipeline/components/training/metrics/metrics.py index 5fa60a24d..ed0c068f2 100644 --- a/autoPyTorch/pipeline/components/training/metrics/metrics.py +++ b/autoPyTorch/pipeline/components/training/metrics/metrics.py @@ -57,7 +57,7 @@ # Score functions that need decision values -roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True) +roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True, multi_class= 'ovo') average_precision = make_metric('average_precision', sklearn.metrics.average_precision_score, needs_threshold=True) diff --git a/autoPyTorch/pipeline/components/training/metrics/utils.py b/autoPyTorch/pipeline/components/training/metrics/utils.py index e72c1afce..2a4865aa5 100644 --- a/autoPyTorch/pipeline/components/training/metrics/utils.py +++ b/autoPyTorch/pipeline/components/training/metrics/utils.py @@ -99,8 +99,8 @@ def get_metrics(dataset_properties: Dict[str, Any], if names is not None: for name in names: if name not in supported_metrics.keys(): - raise ValueError("Invalid name entered for task {}, currently " - "supported metrics for task include {}".format(dataset_properties['task_type'], + raise ValueError("Invalid name {} entered for task {}, currently " + "supported metrics for task include {}".format(name, dataset_properties['task_type'], list(supported_metrics.keys()))) else: metric = supported_metrics[name] From ac3c31086fe371b2bcd87926ec67a4682c0bdfe7 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Mon, 17 Oct 2022 14:26:20 +0200 Subject: [PATCH 5/5] remove unnecessary code 
--- autoPyTorch/api/tabular_regression.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index d6b14cd9b..d6c30aa3a 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -483,15 +483,3 @@ def predict( # Allow to predict in the original domain -- that is, the user is not interested # in our encoded values return self.input_validator.target_validator.inverse_transform(predicted_values) - - def score( - self, - y_pred: np.ndarray, - y_test: Union[np.ndarray, pd.DataFrame] - ) -> Dict[str, float]: - if self.input_validator is None or not self.input_validator._is_fitted: - raise ValueError("predict() is only supported after calling search. Kindly call first " - "the estimator search() method.") - y_pred = self.input_validator.target_validator.transform(y_pred) - y_test = self.input_validator.target_validator.transform(y_test) - return super().score(y_pred=y_pred, y_test=y_test) \ No newline at end of file
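
A minimal NumPy-only sketch of the per-row CutMix/CutOut behaviour introduced in PATCH 1/5. The names (row_cutmix, row_cutout, rng) and the standalone functions are illustrative, not the trainers' actual API; the real data_preparation() methods additionally return the target pair y_a/y_b and a lam value, and operate on the batch handed to the trainer.

    import numpy as np

    rng = np.random.default_rng(0)


    def row_cutmix(X, lam):
        # Each row copies a fresh random subset of columns from its shuffled
        # partner row, mirroring the patched loop in RowCutMixTrainer.
        n_rows, n_columns = X.shape
        X_mixed = X.copy()
        shuffled_indices = rng.permutation(n_rows)
        n_cut = max(1, int(n_columns * lam))
        for i, partner in enumerate(shuffled_indices):
            cut_column_indices = rng.choice(n_columns, n_cut, replace=False)
            X_mixed[i, cut_column_indices] = X[partner, cut_column_indices]
        return X_mixed, shuffled_indices


    def row_cutout(X, patch_ratio):
        # Each row gets its own random subset of columns zeroed out,
        # mirroring the patched loop in RowCutOutTrainer.
        n_rows, n_columns = X.shape
        X_masked = X.copy()
        n_cut = max(1, int(n_columns * patch_ratio))
        for i in range(n_rows):
            cut_column_indices = rng.choice(n_columns, n_cut, replace=False)
            X_masked[i, cut_column_indices] = 0.0
        return X_masked


    X = rng.normal(size=(4, 6))
    X_mixed, partners = row_cutmix(X, lam=0.5)
    X_masked = row_cutout(X, patch_ratio=0.3)

The point of the fix is that cut_column_indices is sampled inside the row loop, so rows no longer all share one set of mixed or masked columns.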
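
PATCH 3/5 flips the default of all_supported_metrics from True to False across the four API entry points, so the extra per-pipeline metrics described in the docstring are only computed when the caller opts back in. A hedged usage sketch, assuming the usual TabularClassificationTask.search() signature from the autoPyTorch examples; the synthetic dataset is only there to make the snippet self-contained.

    import numpy as np
    from sklearn.datasets import make_classification

    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    # Synthetic 3-class problem, purely illustrative.
    X, y = make_classification(n_samples=200, n_features=8, n_informative=5,
                               n_classes=3, random_state=1)

    api = TabularClassificationTask()
    api.search(
        X_train=X,
        y_train=y,
        optimize_metric='roc_auc',
        total_walltime_limit=120,
        all_supported_metrics=True,  # opt back in now that the default is False
    )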
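
PATCH 4/5 makes roc_auc usable on multiclass problems: the metric is built with multi_class='ovo', the type_of_target check in metrics/base.py no longer rejects multiclass targets for roc_auc, and nn.Softmax(dim=1) is appended to the network so predictions come out as per-class probabilities. A plain scikit-learn sketch of the call the wrapped scorer ends up making; y_true and y_proba are made-up values for illustration.

    import numpy as np
    from sklearn.metrics import roc_auc_score

    y_true = np.array([0, 1, 2, 2, 1, 0])
    # Per-class probabilities, one row per sample, each row summing to 1
    # (the shape a softmax output layer produces).
    y_proba = np.array([
        [0.8, 0.1, 0.1],
        [0.2, 0.6, 0.2],
        [0.1, 0.2, 0.7],
        [0.2, 0.3, 0.5],
        [0.3, 0.5, 0.2],
        [0.6, 0.3, 0.1],
    ])

    # One-vs-one, macro-averaged AUC, i.e. the strategy selected by multi_class='ovo'.
    print(roc_auc_score(y_true, y_proba, multi_class='ovo'))

Without multi_class set, roc_auc_score raises a ValueError for more than two classes, which is the failure mode these patches address.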