diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index c0548012d..8f4dbe9e6 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -551,7 +551,30 @@ def _do_dummy_prediction(self, num_run: int) -> None:
                 % (str(status), str(additional_info))
             )
 
-    def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) -> int:
+    def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_time_limit_secs: int
+                                   ) -> int:
+        """
+        Fits traditional machine learning algorithms to the provided dataset, while
+        complying with the allocated time budget.
+
+        This method currently only supports classification.
+
+        Args:
+            num_run: (int)
+                An identifier to indicate the current machine learning algorithm
+                being processed
+            time_left: (int)
+                Hard time limit, in seconds, for fitting the traditional machine
+                learning algorithms. Depending on how fast each algorithm trains,
+                this budget may allow several models to be fitted.
+            func_eval_time_limit_secs: (int)
+                Maximum time, in seconds, that each individual algorithm is
+                allowed to train
+
+        Returns:
+            num_run: (int)
+                The run identifier, incremented by the number of machine learning
+                models that were fitted.
+        """
 
         # Mypy Checkings -- Traditional prediction is only called for search
         # where the following objects are created
@@ -559,81 +582,119 @@ def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) ->
         assert self._logger is not None
         assert self._dask_client is not None
 
-        self._logger.info("Starting to create dummy predictions.")
-
         memory_limit = self._memory_limit
         if memory_limit is not None:
             memory_limit = int(math.ceil(memory_limit))
         available_classifiers = get_available_classifiers()
-        dask_futures = list()
-
-        time_for_traditional_classifier_sec = int(time_for_traditional / len(available_classifiers))
+        dask_futures = []
+
+        total_number_classifiers = len(available_classifiers) + num_run
         for n_r, classifier in enumerate(available_classifiers, start=num_run):
+
+            # Only launch a task if there is time
            start_time = time.time()
-            scenario_mock = unittest.mock.Mock()
-            scenario_mock.wallclock_limit = time_for_traditional_classifier_sec
-            # This stats object is a hack - maybe the SMAC stats object should
-            # already be generated here!
-            stats = Stats(scenario_mock)
-            stats.start_timing()
-            ta = ExecuteTaFuncWithQueue(
-                backend=self._backend,
-                seed=self.seed,
-                metric=self._metric,
-                logger_port=self._logger_port,
-                cost_for_crash=get_cost_of_crash(self._metric),
-                abort_on_first_run_crash=False,
-                initial_num_run=num_run,
-                stats=stats,
-                memory_limit=memory_limit,
-                disable_file_output=True if len(self._disable_file_output) > 0 else False,
-                all_supported_metrics=self._all_supported_metrics
-            )
-            dask_futures.append((classifier, self._dask_client.submit(ta.run, config=classifier,
-                                                                      cutoff=time_for_traditional_classifier_sec)))
+            if time_left >= func_eval_time_limit_secs:
+                self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={func_eval_time_limit_secs}")
+                scenario_mock = unittest.mock.Mock()
+                scenario_mock.wallclock_limit = time_left
+                # This stats object is a hack - maybe the SMAC stats object should
+                # already be generated here!
+                stats = Stats(scenario_mock)
+                stats.start_timing()
+                ta = ExecuteTaFuncWithQueue(
+                    backend=self._backend,
+                    seed=self.seed,
+                    metric=self._metric,
+                    logger_port=self._logger_port,
+                    cost_for_crash=get_cost_of_crash(self._metric),
+                    abort_on_first_run_crash=False,
+                    initial_num_run=n_r,
+                    stats=stats,
+                    memory_limit=memory_limit,
+                    disable_file_output=True if len(self._disable_file_output) > 0 else False,
+                    all_supported_metrics=self._all_supported_metrics
+                )
+                dask_futures.append([
+                    classifier,
+                    self._dask_client.submit(
+                        ta.run, config=classifier,
+                        cutoff=func_eval_time_limit_secs,
+                    )
+                ])
+
+                # Increment the launched job index
+                num_run = n_r
+
+            # When managing time, we need to take into account the allocated time resources,
+            # which depend on the number of cores. 'dask_futures' is a proxy for the number
+            # of workers/n_jobs that we have, in that if there are 4 cores allocated, we can run at most
+            # 4 tasks in parallel. Every 'cutoff' seconds, we generate up to 4 tasks.
+            # If we only have 4 workers and there are 4 futures in dask_futures, it means that every
+            # worker has a task. We do not want to launch another job until a worker is available. To this
+            # end, the following if-statement queries the number of active jobs, and forces a wait for a
+            # job to complete via future.result(), so that a new worker is available for the next iteration.
+            if len(dask_futures) >= self.n_jobs:
+
+                # How many workers to wait for before starting to fit the next iteration of models
+                workers_to_wait = 1
+                if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit_secs:
+                    # If on the last iteration, flush out all tasks
+                    workers_to_wait = len(dask_futures)
+
+                while workers_to_wait >= 1:
+                    workers_to_wait -= 1
+                    # We launch dask jobs only when there are resources available.
+                    # This allows us to control the time allocation properly, and to
+                    # terminate the traditional machine learning pipeline early
+                    cls, future = dask_futures.pop(0)
+                    status, cost, runtime, additional_info = future.result()
+                    if status == StatusType.SUCCESS:
+                        self._logger.info(
+                            f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}")
+                    else:
+                        if additional_info.get('exitcode') == -6:
+                            self._logger.error(
+                                "Traditional prediction for %s failed with run state %s. "
+                                "The error suggests that the provided memory limits were too tight. Please "
+                                "increase the 'ml_memory_limit' and try again. If this does not solve your "
+                                "problem, please open an issue and paste the additional output. "
+                                "Additional output: %s.",
+                                cls, str(status), str(additional_info),
+                            )
+                        else:
+                            self._logger.error(
+                                "Traditional prediction for %s failed with run state %s and additional output: %s.",
+                                cls, str(status), str(additional_info),
+                            )
 
             # In the case of a serial execution, calling submit halts the run for a resource
             # dynamically adjust time in this case
-            time_for_traditional_classifier_sec -= int(time.time() - start_time)
-            num_run = n_r
+            time_left -= int(time.time() - start_time)
+
+            # Exit if no more time is available for a new classifier
+            if time_left < func_eval_time_limit_secs:
+                self._logger.warning("Not enough time to fit all traditional machine learning models. "
+ "Please consider increasing the run time to further improve performance.") + break - for (classifier, future) in dask_futures: - status, cost, runtime, additional_info = future.result() - if status == StatusType.SUCCESS: - self._logger.info("Finished creating predictions for {}".format(classifier)) - else: - if additional_info.get('exitcode') == -6: - self._logger.error( - "Traditional prediction for %s failed with run state %s. " - "The error suggests that the provided memory limits were too tight. Please " - "increase the 'ml_memory_limit' and try again. If this does not solve your " - "problem, please open an issue and paste the additional output. " - "Additional output: %s.", - classifier, str(status), str(additional_info), - ) - else: - # TODO: add check for timeout, and provide feedback to user to consider increasing the time limit - self._logger.error( - "Traditional prediction for %s failed with run state %s and additional output: %s.", - classifier, str(status), str(additional_info), - ) return num_run def _search( - self, - optimize_metric: str, - dataset: BaseDataset, - budget_type: Optional[str] = None, - budget: Optional[float] = None, - total_walltime_limit: int = 100, - func_eval_time_limit: int = 60, - traditional_per_total_budget: float = 0.1, - memory_limit: Optional[int] = 4096, - smac_scenario_args: Optional[Dict[str, Any]] = None, - get_smac_object_callback: Optional[Callable] = None, - all_supported_metrics: bool = True, - precision: int = 32, - disable_file_output: List = [], - load_models: bool = True, + self, + optimize_metric: str, + dataset: BaseDataset, + budget_type: Optional[str] = None, + budget: Optional[float] = None, + total_walltime_limit: int = 100, + func_eval_time_limit_secs: Optional[int] = None, + enable_traditional_pipeline: bool = True, + memory_limit: Optional[int] = 4096, + smac_scenario_args: Optional[Dict[str, Any]] = None, + get_smac_object_callback: Optional[Callable] = None, + all_supported_metrics: bool = True, + precision: int = 32, + disable_file_output: List = [], + load_models: bool = True, ) -> 'BaseTask': """ Search for the best pipeline configuration for the given dataset. @@ -660,16 +721,24 @@ def _search( in seconds for the search of appropriate models. By increasing this value, autopytorch has a higher chance of finding better models. - func_eval_time_limit (int), (default=60): Time limit + func_eval_time_limit_secs (int), (default=None): Time limit for a single call to the machine learning model. Model fitting will be terminated if the machine learning algorithm runs over the time limit. Set this value high enough so that typical machine learning algorithms can be fit on the training data. - traditional_per_total_budget (float), (default=0.1): - Percent of total walltime to be allocated for - running traditional classifiers. + When set to None, this time will automatically be set to + total_walltime_limit // 2 to allow enough time to fit + at least 2 individual machine learning algorithms. + Set to np.inf in case no time limit is desired. + enable_traditional_pipeline (bool), (default=True): + We fit traditional machine learning algorithms + (LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM) + prior building PyTorch Neural Networks. You can disable this + feature by turning this flag to False. All machine learning + algorithms that are fitted during search() are considered for + ensemble building. memory_limit (Optional[int]), (default=4096): Memory limit in MB for the machine learning algorithm. 
                 autopytorch will stop fitting the machine learning algorithm if it tries
@@ -755,6 +824,28 @@ def _search(
         else:
             self._is_dask_client_internally_created = False
 
+        # Handle time resource allocation
+        elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
+        time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
+        if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
+            self._logger.warning(
+                'Time limit for a single run is higher than total time '
+                'limit. Capping the limit for a single run to the total '
+                'time given to SMAC (%f)' % time_left_for_modelfit
+            )
+            func_eval_time_limit_secs = time_left_for_modelfit
+
+        # Make sure that at least 2 models are created for the ensemble process
+        num_models = time_left_for_modelfit // func_eval_time_limit_secs
+        if num_models < 2:
+            func_eval_time_limit_secs = time_left_for_modelfit // 2
+            self._logger.warning(
+                "Capping the func_eval_time_limit_secs to {} to have "
+                "time for at least 2 models to ensemble.".format(
+                    func_eval_time_limit_secs
+                )
+            )
+
         # ============> Run dummy predictions
         num_run = 1
         dummy_task_name = 'runDummy'
@@ -764,16 +855,22 @@
 
         # ============> Run traditional ml
 
-        traditional_task_name = 'runTraditional'
-        self._stopwatch.start_task(traditional_task_name)
-        elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
-        time_for_traditional = int(traditional_per_total_budget * max(0, (self._time_for_task - elapsed_time)))
-        if time_for_traditional <= 0:
-            if traditional_per_total_budget > 0:
-                raise ValueError("Not enough time allocated to run traditional algorithms")
-        elif traditional_per_total_budget != 0:
-            num_run = self._do_traditional_prediction(num_run=num_run + 1, time_for_traditional=time_for_traditional)
-        self._stopwatch.stop_task(traditional_task_name)
+        if enable_traditional_pipeline:
+            if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
+                self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
Skipping...") + else: + traditional_task_name = 'runTraditional' + self._stopwatch.start_task(traditional_task_name) + elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) + # We want time for at least 1 Neural network in SMAC + time_for_traditional = int( + self._time_for_task - elapsed_time - func_eval_time_limit_secs + ) + num_run = self._do_traditional_prediction( + num_run=num_run + 1, func_eval_time_limit_secs=func_eval_time_limit_secs, + time_left=time_for_traditional, + ) + self._stopwatch.stop_task(traditional_task_name) # ============> Starting ensemble elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) @@ -830,7 +927,7 @@ def _search( dataset_name=dataset.dataset_name, backend=self._backend, total_walltime_limit=total_walltime_limit, - func_eval_time_limit=func_eval_time_limit, + func_eval_time_limit_secs=func_eval_time_limit_secs, dask_client=self._dask_client, memory_limit=self._memory_limit, n_jobs=self.n_jobs, diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index 04d30dd9e..deeb5244b 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -122,8 +122,8 @@ def search( budget_type: Optional[str] = None, budget: Optional[float] = None, total_walltime_limit: int = 100, - func_eval_time_limit: int = 60, - traditional_per_total_budget: float = 0.1, + func_eval_time_limit_secs: Optional[int] = None, + enable_traditional_pipeline: bool = True, memory_limit: Optional[int] = 4096, smac_scenario_args: Optional[Dict[str, Any]] = None, get_smac_object_callback: Optional[Callable] = None, @@ -156,16 +156,24 @@ def search( in seconds for the search of appropriate models. By increasing this value, autopytorch has a higher chance of finding better models. - func_eval_time_limit (int), (default=60): Time limit + func_eval_time_limit_secs (int), (default=None): Time limit for a single call to the machine learning model. Model fitting will be terminated if the machine learning algorithm runs over the time limit. Set this value high enough so that typical machine learning algorithms can be fit on the training data. - traditional_per_total_budget (float), (default=0.1): - Percent of total walltime to be allocated for - running traditional classifiers. + When set to None, this time will automatically be set to + total_walltime_limit // 2 to allow enough time to fit + at least 2 individual machine learning algorithms. + Set to np.inf in case no time limit is desired. + enable_traditional_pipeline (bool), (default=True): + We fit traditional machine learning algorithms + (LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM) + before building PyTorch Neural Networks. You can disable this + feature by turning this flag to False. All machine learning + algorithms that are fitted during search() are considered for + ensemble building. memory_limit (Optional[int]), (default=4096): Memory limit in MB for the machine learning algorithm. 
                 autopytorch will stop fitting the machine learning algorithm if it tries
@@ -228,8 +236,8 @@ def search(
             budget_type=budget_type,
             budget=budget,
             total_walltime_limit=total_walltime_limit,
-            func_eval_time_limit=func_eval_time_limit,
-            traditional_per_total_budget=traditional_per_total_budget,
+            func_eval_time_limit_secs=func_eval_time_limit_secs,
+            enable_traditional_pipeline=enable_traditional_pipeline,
             memory_limit=memory_limit,
             smac_scenario_args=smac_scenario_args,
             get_smac_object_callback=get_smac_object_callback,
diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py
index 394a7230f..afef8ce9f 100644
--- a/autoPyTorch/api/tabular_regression.py
+++ b/autoPyTorch/api/tabular_regression.py
@@ -103,26 +103,27 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An
     def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
         return TabularRegressionPipeline(dataset_properties=dataset_properties)
 
-    def search(self,
-               optimize_metric: str,
-               X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               dataset_name: Optional[str] = None,
-               budget_type: Optional[str] = None,
-               budget: Optional[float] = None,
-               total_walltime_limit: int = 100,
-               func_eval_time_limit: int = 60,
-               traditional_per_total_budget: float = 0.1,
-               memory_limit: Optional[int] = 4096,
-               smac_scenario_args: Optional[Dict[str, Any]] = None,
-               get_smac_object_callback: Optional[Callable] = None,
-               all_supported_metrics: bool = True,
-               precision: int = 32,
-               disable_file_output: List = [],
-               load_models: bool = True,
-               ) -> 'BaseTask':
+    def search(
+        self,
+        optimize_metric: str,
+        X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        dataset_name: Optional[str] = None,
+        budget_type: Optional[str] = None,
+        budget: Optional[float] = None,
+        total_walltime_limit: int = 100,
+        func_eval_time_limit_secs: Optional[int] = None,
+        enable_traditional_pipeline: bool = False,
+        memory_limit: Optional[int] = 4096,
+        smac_scenario_args: Optional[Dict[str, Any]] = None,
+        get_smac_object_callback: Optional[Callable] = None,
+        all_supported_metrics: bool = True,
+        precision: int = 32,
+        disable_file_output: List = [],
+        load_models: bool = True,
+    ) -> 'BaseTask':
         """
         Search for the best pipeline configuration for the given dataset.
 
@@ -147,16 +148,20 @@ def search(self,
                 in seconds for the search of appropriate models.
                 By increasing this value, autopytorch has a higher
                 chance of finding better models.
-            func_eval_time_limit (int), (default=60): Time limit
+            func_eval_time_limit_secs (int), (default=None): Time limit
                 for a single call to the machine learning model.
                 Model fitting will be terminated if the machine
                 learning algorithm runs over the time limit. Set
                 this value high enough so that typical machine
                 learning algorithms can be fit on the training
                 data.
-            traditional_per_total_budget (float), (default=0.1):
-                Percent of total walltime to be allocated for
-                running traditional classifiers.
+                When set to None, this time will automatically be set to
+                total_walltime_limit // 2 to allow enough time to fit
+                at least 2 individual machine learning algorithms.
+                Set to np.inf in case no time limit is desired.
+            enable_traditional_pipeline (bool), (default=False):
+                Not enabled for regression. This flag is here to comply
+                with the API.
             memory_limit (Optional[int]), (default=4096): Memory limit
                 in MB for the machine learning algorithm. autopytorch
                 will stop fitting the machine learning algorithm if it tries
@@ -219,8 +224,8 @@ def search(self,
             budget_type=budget_type,
             budget=budget,
             total_walltime_limit=total_walltime_limit,
-            func_eval_time_limit=func_eval_time_limit,
-            traditional_per_total_budget=traditional_per_total_budget,
+            func_eval_time_limit_secs=func_eval_time_limit_secs,
+            enable_traditional_pipeline=enable_traditional_pipeline,
             memory_limit=memory_limit,
             smac_scenario_args=smac_scenario_args,
             get_smac_object_callback=get_smac_object_callback,
diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index 37990d779..3c6bba258 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -377,8 +377,9 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
         """
 
-        if not isinstance(self.configuration, Configuration):
-            return {self.metric.name: self.metric._worst_possible_result}
+        if isinstance(self.configuration, int):
+            # We do not calculate performance of the dummy configurations
+            return {self.metric.name: self.metric._optimum - self.metric._sign * self.metric._worst_possible_result}
 
         if self.additional_metrics is not None:
             metrics = self.additional_metrics
diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py
index fdb33d9a0..b4f53b6e4 100644
--- a/autoPyTorch/evaluation/tae.py
+++ b/autoPyTorch/evaluation/tae.py
@@ -442,7 +442,7 @@ def run(
         empty_queue(queue)
 
         self.logger.debug(
-            'Finished function evaluation. Status: %s, Cost: %f, Runtime: %f, Additional %s',
-            status, cost, runtime, additional_run_info,
+            'Finished function evaluation %s. Status: %s, Cost: %f, Runtime: %f, Additional %s',
+            str(num_run), status, cost, runtime, additional_run_info,
         )
         return status, cost, runtime, additional_run_info
diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py
index 0945ff9d6..88b5e81da 100644
--- a/autoPyTorch/evaluation/train_evaluator.py
+++ b/autoPyTorch/evaluation/train_evaluator.py
@@ -110,6 +110,10 @@ def fit_predict_and_loss(self) -> None:
 
             status = StatusType.SUCCESS
 
+            self.logger.debug("In train evaluator fit_predict_and_loss, num_run: {} loss:{}".format(
+                self.num_run,
+                loss
+            ))
             self.finish_up(
                 loss=loss,
                 train_loss=train_loss,
@@ -236,7 +240,10 @@ def fit_predict_and_loss(self) -> None:
                 self.pipeline = self._get_pipeline()
 
             status = StatusType.SUCCESS
-            self.logger.debug("In train evaluator fit_predict_and_loss, loss:{}".format(opt_loss))
+            self.logger.debug("In train evaluator fit_predict_and_loss, num_run: {} loss:{}".format(
+                self.num_run,
+                opt_loss
+            ))
             self.finish_up(
                 loss=opt_loss,
                 train_loss=train_loss,
diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py
index c2f20f07f..c00965bbb 100644
--- a/autoPyTorch/optimizer/smbo.py
+++ b/autoPyTorch/optimizer/smbo.py
@@ -84,7 +84,7 @@ def __init__(self,
                  dataset_name: str,
                  backend: Backend,
                  total_walltime_limit: float,
-                 func_eval_time_limit: float,
+                 func_eval_time_limit_secs: float,
                  memory_limit: typing.Optional[int],
                  metric: autoPyTorchMetric,
                  watcher: StopWatch,
@@ -120,7 +120,7 @@ def __init__(self,
                 An interface with disk
             total_walltime_limit (float):
                 The maximum allowed time for this job
-            func_eval_time_limit (float):
+            func_eval_time_limit_secs (float):
                 How much each individual task is allowed to last
             memory_limit (typing.Optional[int]):
                 Maximum allowed CPU memory this task can use
@@ -180,7 +180,7 @@ def __init__(self,
         # and a bunch of useful limits
         self.worst_possible_result = get_cost_of_crash(self.metric)
         self.total_walltime_limit = int(total_walltime_limit)
-        self.func_eval_time_limit = int(func_eval_time_limit)
+        self.func_eval_time_limit_secs = int(func_eval_time_limit_secs)
         self.memory_limit = memory_limit
         self.watcher = watcher
         self.seed = seed
@@ -265,7 +265,7 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None
         scenario_dict = {
             'abort_on_first_run_crash': False,
             'cs': self.config_space,
-            'cutoff_time': self.func_eval_time_limit,
+            'cutoff_time': self.func_eval_time_limit_secs,
             'deterministic': 'true',
             'instances': instances,
             'memory_limit': self.memory_limit,
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/__init__.py b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/__init__.py
index c973fef00..0187df10d 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/__init__.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/__init__.py
@@ -13,12 +13,22 @@
     SVMModel)
 
 _classifiers = {
+    # Sorted so that the more robust models come first.
+    # Depending on the allocated time budget, only the
+    # top models from this dict are to be fitted.
+    # LGBM is the most robust model, with
+    # internal measures to prevent crashes and overfitting.
+    # Additionally, it is one of the state-of-the-art
+    # methods for tabular prediction.
+    # Then follows CatBoost, for categorical-heavy
+    # datasets. The other models are complementary and
+    # their ordering is not critical.
+    'lgb': LGBModel,
     'catboost': CatboostModel,
+    'random_forest': RFModel,
     'extra_trees': ExtraTreesModel,
+    'svm_classifier': SVMModel,
     'knn_classifier': KNNModel,
-    'lgb': LGBModel,
-    'random_forest': RFModel,
-    'svm_classifier': SVMModel
 }
 
 _addons = ThirdPartyComponents(BaseClassifier)
diff --git a/autoPyTorch/pipeline/traditional_tabular_classification.py b/autoPyTorch/pipeline/traditional_tabular_classification.py
index 8ef19a38e..1ac6aac50 100644
--- a/autoPyTorch/pipeline/traditional_tabular_classification.py
+++ b/autoPyTorch/pipeline/traditional_tabular_classification.py
@@ -211,8 +211,13 @@ def get_pipeline_representation(self) -> Dict[str, str]:
         """
         estimator_name = 'TraditionalTabularClassification'
         if self.steps[0][1].choice is not None:
-            estimator_name = cast(str,
-                                  self.steps[0][1].choice.model.get_properties()['shortname'])
+            if self.steps[0][1].choice.model is None:
+                estimator_name = self.steps[0][1].choice.__class__.__name__
+            else:
+                estimator_name = cast(
+                    str,
+                    self.steps[0][1].choice.model.get_properties()['shortname']
+                )
         return {
             'Preprocessing': 'None',
             'Estimator': estimator_name,
diff --git a/examples/tabular/20_basics/example_tabular_classification.py b/examples/tabular/20_basics/example_tabular_classification.py
index 047f01842..1e5b08cac 100644
--- a/examples/tabular/20_basics/example_tabular_classification.py
+++ b/examples/tabular/20_basics/example_tabular_classification.py
@@ -57,7 +57,7 @@
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=300,
-    func_eval_time_limit=50
+    func_eval_time_limit_secs=50
 )
 
 ############################################################################
diff --git a/examples/tabular/20_basics/example_tabular_regression.py b/examples/tabular/20_basics/example_tabular_regression.py
index 7bd48155f..d2afedd52 100644
--- a/examples/tabular/20_basics/example_tabular_regression.py
+++ b/examples/tabular/20_basics/example_tabular_regression.py
@@ -66,8 +66,8 @@
     y_test=y_test_scaled.copy(),
     optimize_metric='r2',
     total_walltime_limit=300,
-    func_eval_time_limit=50,
-    traditional_per_total_budget=0
+    func_eval_time_limit_secs=50,
+    enable_traditional_pipeline=False,
 )
 
 ############################################################################
diff --git a/examples/tabular/40_advanced/example_custom_configuration_space.py b/examples/tabular/40_advanced/example_custom_configuration_space.py
index 772c268b9..6a3764b94 100644
--- a/examples/tabular/40_advanced/example_custom_configuration_space.py
+++ b/examples/tabular/40_advanced/example_custom_configuration_space.py
@@ -88,7 +88,7 @@ def get_search_space_updates():
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=300,
-    func_eval_time_limit=50
+    func_eval_time_limit_secs=50
 )
 
 ############################################################################
@@ -119,7 +119,7 @@ def get_search_space_updates():
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=300,
-    func_eval_time_limit=50
+    func_eval_time_limit_secs=50
 )
 
 ############################################################################
diff --git a/examples/tabular/40_advanced/example_resampling_strategy.py b/examples/tabular/40_advanced/example_resampling_strategy.py
index 5217afab2..9c6c00959 100644
--- a/examples/tabular/40_advanced/example_resampling_strategy.py
+++ b/examples/tabular/40_advanced/example_resampling_strategy.py
@@ -66,7 +66,7 @@
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=150,
-    func_eval_time_limit=30
+    func_eval_time_limit_secs=30
 )
 
 ############################################################################
@@ -104,7 +104,7 @@
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=150,
-    func_eval_time_limit=30
+    func_eval_time_limit_secs=30
 )
 
 ############################################################################
@@ -145,7 +145,7 @@
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=150,
-    func_eval_time_limit=30
+    func_eval_time_limit_secs=30
 )
 
 ############################################################################
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index d30593bb0..6af387298 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -62,8 +62,8 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
         X_test=X_test, y_test=y_test,
         optimize_metric='accuracy',
         total_walltime_limit=150,
-        func_eval_time_limit=50,
-        traditional_per_total_budget=0
+        func_eval_time_limit_secs=50,
+        enable_traditional_pipeline=False,
     )
 
     # Internal dataset has expected settings
@@ -230,8 +230,8 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
         X_test=X_test, y_test=y_test,
         optimize_metric='r2',
         total_walltime_limit=50,
-        func_eval_time_limit=10,
-        traditional_per_total_budget=0
+        func_eval_time_limit_secs=10,
+        enable_traditional_pipeline=False,
     )
 
     # Internal dataset has expected settings
@@ -390,7 +390,7 @@ def test_tabular_input_support(openml_id, backend):
         X_test=X_test, y_test=y_test,
         optimize_metric='accuracy',
         total_walltime_limit=150,
-        func_eval_time_limit=50,
-        traditional_per_total_budget=0,
+        func_eval_time_limit_secs=50,
+        enable_traditional_pipeline=False,
         load_models=False,
     )
diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py
index ef508dc7b..da04be256 100644
--- a/test/test_pipeline/test_tabular_classification.py
+++ b/test/test_pipeline/test_tabular_classification.py
@@ -144,7 +144,10 @@ def test_pipeline_predict_proba(self, fit_dictionary_tabular):
         config = cs.sample_configuration()
         pipeline.set_hyperparameters(config)
 
-        pipeline.fit(fit_dictionary_tabular)
+        try:
+            pipeline.fit(fit_dictionary_tabular)
+        except Exception as e:
+            pytest.fail(f"Failed on config={config} with {e}")
 
         # we expect the output to have the same batch size as the test input,
         # and number of outputs per batch sample equal to the number of classes ("num_classes" in dataset_properties)
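For context, a minimal usage sketch of the renamed interface (not itself part of the patch). It mirrors the updated examples above; the OpenML dataset id and the train/test split are illustrative only.

    import sklearn.datasets
    import sklearn.model_selection

    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    # Illustrative dataset; any tabular classification data works here.
    X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, random_state=42)

    api = TabularClassificationTask()
    api.search(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        optimize_metric='accuracy',
        total_walltime_limit=300,
        # Renamed from func_eval_time_limit. Passing None instead lets autoPyTorch
        # cap it so that at least two models fit within total_walltime_limit.
        func_eval_time_limit_secs=50,
        # Replaces traditional_per_total_budget: traditional models now draw from
        # the single shared time budget instead of a fixed percentage slice.
        enable_traditional_pipeline=True,
    )

The shared budget also removes the old failure mode where a small traditional_per_total_budget could raise "Not enough time allocated to run traditional algorithms"; with this patch, traditional models are simply skipped once time_left drops below func_eval_time_limit_secs.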