diff --git a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py
index e54006d10..1645c00cd 100755
--- a/autoPyTorch/pipeline/components/training/trainer/__init__.py
+++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -293,6 +293,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
                 writer=writer,
             )
 
+            # it's fine if train_loss is None due to `is_max_time_reached()`
+            if train_loss is None:
+                if self.budget_tracker.is_max_time_reached():
+                    break
+                else:
+                    raise RuntimeError("Got an unexpected None in `train_loss`.")
+
             val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
             if self.eval_valid_each_epoch(X):
                 val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
@@ -334,6 +341,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
         if 'cuda' in X['device']:
             torch.cuda.empty_cache()
 
+        if self.run_summary.is_empty():
+            raise RuntimeError("Budget exhausted without finishing an epoch.")
+
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
             val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index 4909f56ce..6be283ebb 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -179,6 +179,16 @@ def repr_last_epoch(self) -> str:
         string += '=' * 40
         return string
 
+    def is_empty(self) -> bool:
+        """
+        Checks if the object is empty or not
+
+        Returns:
+            bool
+        """
+        # if train_loss is empty, we can be sure that RunSummary is empty.
+        return not bool(self.performance_tracker['train_loss'])
+
 
 class BaseTrainerComponent(autoPyTorchTrainingComponent):
@@ -277,7 +287,7 @@ def _scheduler_step(
 
     def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
                     writer: Optional[SummaryWriter],
-                    ) -> Tuple[float, Dict[str, float]]:
+                    ) -> Tuple[Optional[float], Dict[str, float]]:
         """
         Train the model for a single epoch.
 
@@ -317,6 +327,9 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
                     epoch * len(train_loader) + step,
                 )
 
+        if N == 0:
+            return None, {}
+
         self._scheduler_step(step_interval=StepIntervalUnit.epoch, loss=loss_sum / N)
 
         if self.metrics_during_training:
diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py
index 8ae2759db..98bb748c4 100644
--- a/test/test_pipeline/components/training/test_training.py
+++ b/test/test_pipeline/components/training/test_training.py
@@ -236,6 +236,43 @@ def test_train_step(self):
             lr = optimizer.param_groups[0]['lr']
             assert lr == target_lr
 
+    def test_train_epoch_no_step(self):
+        """
+        This test checks if max runtime is reached
+        for an epoch before any train_step has been
+        completed. In this case we would like to
+        return None for train_loss and an empty
+        dictionary for the metrics.
+        """
+        device = torch.device('cpu')
+        model = torch.nn.Linear(1, 1).to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1)
+        data_loader = unittest.mock.MagicMock(spec=torch.utils.data.DataLoader)
+        ms = [3, 5, 6]
+        params = {
+            'metrics': [],
+            'device': device,
+            'task_type': constants.TABULAR_REGRESSION,
+            'labels': torch.Tensor([]),
+            'metrics_during_training': False,
+            'budget_tracker': BudgetTracker(budget_type='runtime', max_runtime=0),
+            'criterion': torch.nn.MSELoss,
+            'optimizer': optimizer,
+            'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=ms, gamma=2),
+            'model': model,
+            'step_interval': StepIntervalUnit.epoch
+        }
+        trainer = StandardTrainer()
+        trainer.prepare(**params)
+
+        loss, metrics = trainer.train_epoch(
+            train_loader=data_loader,
+            epoch=0,
+            writer=None
+        )
+        assert loss is None
+        assert metrics == {}
+
 
 class TestStandardTrainer(BaseTraining):
     def test_regression_epoch_training(self, n_samples):
diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py
index 52288b199..adfe3241b 100644
--- a/test/test_pipeline/test_tabular_classification.py
+++ b/test/test_pipeline/test_tabular_classification.py
@@ -1,6 +1,7 @@
 import os
 import re
 import unittest
+import unittest.mock
 
 from ConfigSpace.hyperparameters import (
     CategoricalHyperparameter,
@@ -491,3 +492,30 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):
 
     # More than 200 epochs would have pass in 5 seconds for this dataset
     assert len(run_summary.performance_tracker['start_time']) > 100
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular_dummy", ["classification"], indirect=True)
+def test_train_pipeline_with_runtime_max_reached(fit_dictionary_tabular_dummy):
+    """
+    This test makes sure that the pipeline raises an
+    error in case no epoch has finished successfully
+    due to max runtime reached
+    """
+
+    # Convert the training to runtime
+    fit_dictionary_tabular_dummy.pop('epochs', None)
+    fit_dictionary_tabular_dummy['budget_type'] = 'runtime'
+    fit_dictionary_tabular_dummy['runtime'] = 5
+    fit_dictionary_tabular_dummy['early_stopping'] = -1
+
+    pipeline = TabularClassificationPipeline(
+        dataset_properties=fit_dictionary_tabular_dummy['dataset_properties'])
+
+    cs = pipeline.get_hyperparameter_search_space()
+    config = cs.get_default_configuration()
+    pipeline.set_hyperparameters(config)
+
+    with unittest.mock.patch('autoPyTorch.pipeline.components.training.trainer.BudgetTracker') as patch:
+        patch.is_max_time_reached.return_value = True
+        with pytest.raises(RuntimeError):
+            pipeline.fit(fit_dictionary_tabular_dummy)