
[ADD] Post-Hoc ensemble fitting #260


Closed · wants to merge 5 commits
Changes from 1 commit
53 changes: 33 additions & 20 deletions autoPyTorch/api/base_task.py
@@ -415,6 +415,25 @@ def _close_dask_client(self) -> None:
self._is_dask_client_internally_created = False
del self._is_dask_client_internally_created

def _cleanup(self) -> None:
"""

Closes the different servers created during api search.

Returns:
None
"""
if self._logger is not None:
self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
else:
self._close_dask_client()
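
For readers skimming the diff, a minimal self-contained sketch of the teardown pattern this method centralizes (stand-in class and methods, not autoPyTorch's actual implementation): every exit path funnels through one `_cleanup()` call, and the dask client is closed while the logger is still alive so that the shutdown itself can be logged.

    import logging
    from typing import Optional

    class TaskSketch:
        """Stand-in for BaseTask; only the teardown wiring is shown."""

        def __init__(self) -> None:
            self._logger: Optional[logging.Logger] = None
            self._dask_client: Optional[object] = None  # stand-in for a dask.distributed.Client

        def _close_dask_client(self) -> None:
            # A real implementation would call self._dask_client.close().
            self._dask_client = None

        def _clean_logger(self) -> None:
            self._logger = None

        def _cleanup(self) -> None:
            # Close dask first so the shutdown can still be logged,
            # then tear down the logger itself.
            if self._logger is not None:
                self._logger.info("Closing the dask infrastructure")
                self._close_dask_client()
                self._logger.info("Finished closing the dask infrastructure")
                self._clean_logger()
            else:
                self._close_dask_client()

        def __del__(self) -> None:
            # Both branches tolerate an absent logger and client, so calling
            # _cleanup() again at interpreter shutdown is harmless.
            self._cleanup()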

def _load_models(self) -> bool:

"""
@@ -1017,18 +1036,12 @@ def _search(
pd.DataFrame(self.ensemble_performance_history).to_json(
os.path.join(self._backend.internals_directory, 'ensemble_history.json'))

self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

if load_models:
self._logger.info("Loading models...")
self._load_models()
self._logger.info("Finished loading models...")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
self._cleanup()

return self

@@ -1103,7 +1116,7 @@ def refit(
# the ordering of the data.
fit_and_suppress_warnings(self._logger, model, X, y=None)

self._clean_logger()
self._cleanup()

return self

@@ -1168,14 +1181,15 @@ def fit(self,

fit_and_suppress_warnings(self._logger, pipeline, X, y=None)

self._clean_logger()
self._cleanup()
return pipeline

def fit_ensemble(
self,
ensemble_nbest: Optional[int] = None,
ensemble_nbest: int = 50,
ensemble_size: int = 50,
precision: int = 32
precision: int = 32,
load_models: bool = True
) -> 'BaseTask':
"""
Enables post-hoc fitting of the ensemble after the `search()`
@@ -1215,7 +1229,7 @@ def fit_ensemble(
optimize_metric=self.opt_metric,
precision=precision,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest if ensemble_nbest is not None else self.ensemble_nbest,
ensemble_nbest=ensemble_nbest,
)

manager.build_ensemble(self._dask_client)
@@ -1226,8 +1240,9 @@
" check the log file and command line output for error messages.")
self.ensemble_performance_history, _, _, _ = result

self._load_models()
self._close_dask_client()
if load_models:
self._load_models()
self._cleanup()
return self
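
A hypothetical end-to-end usage of the post-hoc API this hunk completes (a sketch, not part of the diff: the dataset, metric, and time budget are illustrative, while the `fit_ensemble` keyword arguments are the ones defined above):

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

    api = TabularClassificationTask()
    api.search(X_train=X_train, y_train=y_train,
               optimize_metric='accuracy', total_walltime_limit=300)

    # Post-hoc: rebuild the ensemble from the runs produced by search().
    api.fit_ensemble(ensemble_size=50, ensemble_nbest=50,
                     precision=32, load_models=True)
    preds = api.predict(X_test)
    assert isinstance(preds, np.ndarray)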

def _init_ensemble_builder(
@@ -1259,7 +1274,8 @@ def _init_ensemble_builder(
EnsembleBuilderManager

"""
assert self._logger is not None, "logger should be initialised to fit ensemble"
if self._logger is None:
raise ValueError("logger should be initialised to fit ensemble")
if self.dataset is None:
raise ValueError("ensemble can only be initialised after or during `search()`. "
"Please call the `search()` method of {}.".format(self.__class__.__name__))
@@ -1346,7 +1362,7 @@ def predict(

predictions = self.ensemble_.predict(all_predictions)

self._clean_logger()
self._cleanup()

return predictions

@@ -1383,10 +1399,7 @@ def __getstate__(self) -> Dict[str, Any]:
return self.__dict__

def __del__(self) -> None:
# Clean up the logger
self._clean_logger()

self._close_dask_client()
self._cleanup()

# When multiprocessing work is done, the
# objects are deleted. We don't want to delete run areas
11 changes: 7 additions & 4 deletions test/test_api/test_api.py
@@ -494,8 +494,7 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
with open(model_path, 'rb') as model_handler:
clone(pickle.load(model_handler))

estimator._close_dask_client()
estimator._clean_logger()
estimator._cleanup()

del estimator

@@ -694,8 +693,7 @@ def test_do_traditional_pipeline(fit_dictionary_tabular):
if not at_least_one_model_checked:
pytest.fail("Not even one single traditional pipeline was fitted")

estimator._close_dask_client()
estimator._clean_logger()
estimator._cleanup()

del estimator

@@ -758,6 +756,11 @@ def test_fit_ensemble(backend, n_samples, dataset_name):
estimator._backend.internals_directory, 'ensembles'))])
assert any('ensemble_' in f or '_ensemble.npy' in f for f in os.listdir(estimator._backend.internals_directory))

preds = estimator.predict(X_test)
assert isinstance(preds, np.ndarray)

assert len(estimator.ensemble_performance_history) > 0
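
The last assertion checks the history that `_search()` also dumps to `ensemble_history.json`. A small follow-on sketch (reusing the test's `estimator`; illustrative only) of inspecting it:

    import pandas as pd

    # One dict per ensemble-builder iteration, the same records that
    # _search() serialises to ensemble_history.json.
    history = pd.DataFrame(estimator.ensemble_performance_history)
    print(history.head())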


@pytest.mark.parametrize('dataset_name', ('iris',))
def test_fit_ensemble_failure(backend, n_samples, dataset_name):