Skip to content

Commit d68c691

Browse files
authored
Additional metrics during train (#194)
* Added additional metrics to fit dictionary * Added in test also
1 parent 7f9305d commit d68c691

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,16 @@ def __init__(self, backend: Backend,
377377
'backend': self.backend,
378378
'logger_port': logger_port,
379379
})
380+
381+
# Update fit dictionary with metrics passed to the evaluator
382+
metrics_dict: Dict[str, List[str]] = {'additional_metrics': []}
383+
metrics_dict['additional_metrics'].append(self.metric.name)
384+
if all_supported_metrics:
385+
assert self.additional_metrics is not None
386+
for metric in self.additional_metrics:
387+
metrics_dict['additional_metrics'].append(metric.name)
388+
self.fit_dictionary.update(metrics_dict)
389+
380390
assert self.pipeline_class is not None, "Could not infer pipeline class"
381391
pipeline_config = pipeline_config if pipeline_config is not None \
382392
else self.pipeline_class.get_default_pipeline_options()

test/test_api/test_api.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ def test_pipeline_fit(openml_id,
501501
run_time_limit_secs=50,
502502
budget_type='epochs',
503503
budget=30,
504-
disable_file_output=disable_file_output
504+
disable_file_output=disable_file_output,
505+
eval_metric='balanced_accuracy'
505506
)
506507
assert isinstance(dataset, BaseDataset)
507508
assert isinstance(run_info, RunInfo)
@@ -511,6 +512,7 @@ def test_pipeline_fit(openml_id,
511512
assert 'SUCCESS' in str(run_value.status)
512513

513514
if not disable_file_output:
515+
514516
if resampling_strategy in CrossValTypes:
515517
pytest.skip("Bug, Can't predict with cross validation pipeline")
516518
assert isinstance(pipeline, BaseEstimator)
@@ -522,11 +524,14 @@ def test_pipeline_fit(openml_id,
522524
assert isinstance(score, float)
523525
assert score > 0.8
524526
else:
525-
assert isinstance(pipeline, BasePipeline)
526527
# To make sure we fitted the model, there should be a
527-
# run summary object with accuracy
528+
# run summary object
528529
run_summary = pipeline.named_steps['trainer'].run_summary
529530
assert run_summary is not None
531+
# test to ensure balanced_accuracy is reported during training
532+
assert 'balanced_accuracy' in run_summary.performance_tracker['train_metrics'][1].keys()
533+
534+
assert isinstance(pipeline, BasePipeline)
530535
X_test = dataset.test_tensors[0]
531536
preds = pipeline.predict(X_test)
532537
assert isinstance(preds, np.ndarray)

0 commit comments

Comments
 (0)