from sklearn.preprocessing import label_binarize
from sklearn.utils.fixes import np_version
from sklearn.utils.validation import check_random_state
-from sklearn.utils.testing import (assert_allclose, assert_array_equal,
-                                   assert_no_warnings, assert_equal,
-                                   assert_raises, assert_warns_message,
-                                   ignore_warnings, assert_not_equal,
-                                   assert_raise_message)
-from sklearn.metrics import (accuracy_score, average_precision_score,
-                             brier_score_loss, cohen_kappa_score,
-                             jaccard_similarity_score, precision_score,
-                             recall_score, roc_auc_score)
+from sklearn.utils.testing import assert_allclose, assert_array_equal
+from sklearn.utils.testing import assert_no_warnings, assert_raises
+from sklearn.utils.testing import assert_warns_message, ignore_warnings
+from sklearn.utils.testing import assert_raise_message
+from sklearn.metrics import accuracy_score, average_precision_score
+from sklearn.metrics import brier_score_loss, cohen_kappa_score
+from sklearn.metrics import jaccard_similarity_score, precision_score
+from sklearn.metrics import recall_score, roc_auc_score


from imblearn.metrics import sensitivity_specificity_support
from imblearn.metrics import sensitivity_score
from imblearn.metrics import specificity_score
from imblearn.metrics import make_index_balanced_accuracy
from imblearn.metrics import classification_report_imbalanced

+from pytest import approx
+

RND_SEED = 42
R_TOL = 1e-2

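For context on the new assertion style: pytest's approx wraps an expected value so that a plain == comparison passes within a tolerance, which is what replaces assert_equal for floating-point scores below. A minimal sketch (the values here are illustrative, not taken from the test suite):

    from pytest import approx

    assert 0.54756 == approx(0.54756 + 1e-9)  # default relative tolerance (1e-6) absorbs float noise
    assert 0.333 == approx(1 / 3, rel=1e-2)   # the tolerance can also be widened explicitly
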
@@ -113,11 +114,11 @@ def test_sensitivity_specificity_score_binary():


def test_sensitivity_specificity_f_binary_single_class():
    # Such a case may occur with non-stratified cross-validation
-    assert_equal(1., sensitivity_score([1, 1], [1, 1]))
-    assert_equal(0., specificity_score([1, 1], [1, 1]))
+    assert sensitivity_score([1, 1], [1, 1]) == 1.
+    assert specificity_score([1, 1], [1, 1]) == 0.

-    assert_equal(0., sensitivity_score([-1, -1], [-1, -1]))
-    assert_equal(0., specificity_score([-1, -1], [-1, -1]))
+    assert sensitivity_score([-1, -1], [-1, -1]) == 0.
+    assert specificity_score([-1, -1], [-1, -1]) == 0.


@ignore_warnings
@@ -166,9 +167,8 @@ def test_sensitivity_specificity_ignored_labels():
                    rtol=R_TOL)

    # ensure the above were meaningful tests:
-    for average in ['macro', 'weighted', 'micro']:
-        assert_not_equal(
-            specificity_13(average=average), specificity_all(average=average))
+    for each in ['macro', 'weighted', 'micro']:
+        assert specificity_13(average=each) != specificity_all(average=each)


def test_sensitivity_specificity_error_multilabels():
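The three averaging modes exercised in the loop above map onto specificity_score's average parameter: 'macro' averages the per-class scores uniformly, 'weighted' weights them by class support, and 'micro' pools all decisions before scoring. A minimal sketch with made-up labels (not the suite's fixtures):

    from imblearn.metrics import specificity_score

    y_true = [0, 0, 1, 1, 2]  # illustrative multiclass labels
    y_pred = [0, 1, 1, 1, 2]
    for avg in ['macro', 'weighted', 'micro']:
        # each averaging strategy generally yields a different score
        print(avg, specificity_score(y_true, y_pred, average=avg))
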
@@ -333,15 +333,15 @@ def test_classification_report_imbalanced_multiclass():
        y_pred,
        labels=np.arange(len(iris.target_names)),
        target_names=iris.target_names)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
    # print classification report with label detection
    expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
                       '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
                       '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
                       '0.53 0.80 0.47 0.62 0.41 75')

    report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report


def test_classification_report_imbalanced_multiclass_with_digits():
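The _format_report helper used throughout these assertions is presumably a whitespace normalizer that lets the multi-line report be compared against the single-line expected strings. A sketch of such a helper (this implementation is an assumption, not copied from the suite):

    def _format_report(report):
        # collapse all runs of whitespace so the report compares as one line
        return ' '.join(report.split())
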
@@ -361,14 +361,14 @@ def test_classification_report_imbalanced_multiclass_with_digits():
        labels=np.arange(len(iris.target_names)),
        target_names=iris.target_names,
        digits=5)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
    # print classification report with label detection
    expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
                       '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
                       '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
                       '0.53 0.80 0.47 0.62 0.41 75')
    report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report


def test_classification_report_imbalanced_multiclass_with_string_label():
@@ -382,15 +382,15 @@ def test_classification_report_imbalanced_multiclass_with_string_label():
                       '0.19 31 red 0.42 0.90 0.55 0.57 0.63 0.37 20 '
                       'avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
    report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report

    expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 '
                       '0.86 0.74 24 b 0.33 0.10 0.86 0.15 0.44 0.19 31 '
                       'c 0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total '
                       '0.51 0.53 0.80 0.47 0.62 0.41 75')
    report = classification_report_imbalanced(
        y_true, y_pred, target_names=["a", "b", "c"])
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report


def test_classification_report_imbalanced_multiclass_with_unicode_label():
@@ -411,7 +411,7 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
                             classification_report_imbalanced, y_true, y_pred)
    else:
        report = classification_report_imbalanced(y_true, y_pred)
-        assert_equal(_format_report(report), expected_report)
+        assert _format_report(report) == expected_report


def test_classification_report_imbalanced_multiclass_with_long_string_label():
@@ -427,7 +427,7 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
                       '0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')

    report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report


def test_iba_sklearn_metrics():
@@ -436,22 +436,22 @@ def test_iba_sklearn_metrics():
    acc = make_index_balanced_accuracy(alpha=0.5, squared=True)(
        accuracy_score)
    score = acc(y_true, y_pred)
-    assert_equal(score, 0.54756)
+    assert score == approx(0.54756)

    jss = make_index_balanced_accuracy(alpha=0.5, squared=True)(
        jaccard_similarity_score)
    score = jss(y_true, y_pred)
-    assert_equal(score, 0.54756)
+    assert score == approx(0.54756)

    pre = make_index_balanced_accuracy(alpha=0.5, squared=True)(
        precision_score)
    score = pre(y_true, y_pred)
-    assert_equal(score, 0.65025)
+    assert score == approx(0.65025)

    rec = make_index_balanced_accuracy(alpha=0.5, squared=True)(
        recall_score)
    score = rec(y_true, y_pred)
-    assert_equal(score, 0.41616000000000009)
+    assert score == approx(0.41616000000000009)


def test_iba_error_y_score_prob():
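For reference, the decorator pattern tested above takes a scikit-learn metric and returns a version weighted by the index balanced accuracy (IBA), exactly as the hunk calls it. A minimal usage sketch with illustrative labels (not the suite's fixtures):

    from imblearn.metrics import make_index_balanced_accuracy
    from sklearn.metrics import accuracy_score

    y_true = [0, 1, 1, 0, 1]  # made-up labels for demonstration
    y_pred = [0, 1, 0, 0, 1]
    # alpha and squared mirror the parameters used in the test
    iba_accuracy = make_index_balanced_accuracy(alpha=0.5, squared=True)(accuracy_score)
    print(iba_accuracy(y_true, y_pred))  # accuracy rescaled by the IBA weighting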