16
16
from sklearn .preprocessing import label_binarize
17
17
from sklearn .utils .fixes import np_version
18
18
from sklearn .utils .validation import check_random_state
19
- from sklearn .utils .testing import (assert_allclose , assert_array_equal ,
20
- assert_no_warnings , assert_equal ,
21
- assert_raises , assert_warns_message ,
22
- ignore_warnings , assert_not_equal ,
23
- assert_raise_message )
24
- from sklearn .metrics import (accuracy_score , average_precision_score ,
25
- brier_score_loss , cohen_kappa_score ,
26
- jaccard_similarity_score , precision_score ,
27
- recall_score , roc_auc_score )
19
+ from sklearn .utils .testing import assert_allclose , assert_array_equal
20
+ from sklearn .utils .testing import assert_no_warnings , assert_raises
21
+ from sklearn .utils .testing import assert_warns_message , ignore_warnings
22
+ from sklearn .utils .testing import assert_raise_message
23
+ from sklearn .metrics import accuracy_score , average_precision_score
24
+ from sklearn .metrics import brier_score_loss , cohen_kappa_score
25
+ from sklearn .metrics import jaccard_similarity_score , precision_score
26
+ from sklearn .metrics import recall_score , roc_auc_score
28
27
29
28
from imblearn .metrics import sensitivity_specificity_support
30
29
from imblearn .metrics import sensitivity_score
@@ -113,11 +112,11 @@ def test_sensitivity_specificity_score_binary():
113
112
114
113
def test_sensitivity_specificity_f_binary_single_class():
    """Degenerate binary case: only one class present in y_true and y_pred.

    Such a case may occur with non-stratified cross-validation; the scores
    must still be well defined rather than raising.
    """
    # All samples belong to the positive class: perfect sensitivity,
    # no negatives to be specific about.
    only_pos = [1, 1]
    assert sensitivity_score(only_pos, only_pos) == 1.
    assert specificity_score(only_pos, only_pos) == 0.

    # All samples belong to the negative class: no positives, so both
    # scores fall back to zero.
    only_neg = [-1, -1]
    assert sensitivity_score(only_neg, only_neg) == 0.
    assert specificity_score(only_neg, only_neg) == 0.
121
120
122
121
123
122
@ignore_warnings
@@ -166,9 +165,8 @@ def test_sensitivity_specificity_ignored_labels():
166
165
rtol = R_TOL )
167
166
168
167
# ensure the above were meaningful tests:
169
- for average in ['macro' , 'weighted' , 'micro' ]:
170
- assert_not_equal (
171
- specificity_13 (average = average ), specificity_all (average = average ))
168
+ for each in ['macro' , 'weighted' , 'micro' ]:
169
+ assert specificity_13 (average = each ) != specificity_all (average = each )
172
170
173
171
174
172
def test_sensitivity_specificity_error_multilabels ():
@@ -333,15 +331,15 @@ def test_classification_report_imbalanced_multiclass():
333
331
y_pred ,
334
332
labels = np .arange (len (iris .target_names )),
335
333
target_names = iris .target_names )
336
- assert_equal ( _format_report (report ), expected_report )
334
+ assert _format_report (report ) == expected_report
337
335
# print classification report with label detection
338
336
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
339
337
'0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
340
338
'0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
341
339
'0.53 0.80 0.47 0.62 0.41 75' )
342
340
343
341
report = classification_report_imbalanced (y_true , y_pred )
344
- assert_equal ( _format_report (report ), expected_report )
342
+ assert _format_report (report ) == expected_report
345
343
346
344
347
345
def test_classification_report_imbalanced_multiclass_with_digits ():
@@ -361,14 +359,14 @@ def test_classification_report_imbalanced_multiclass_with_digits():
361
359
labels = np .arange (len (iris .target_names )),
362
360
target_names = iris .target_names ,
363
361
digits = 5 )
364
- assert_equal ( _format_report (report ), expected_report )
362
+ assert _format_report (report ) == expected_report
365
363
# print classification report with label detection
366
364
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
367
365
'0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
368
366
'0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
369
367
'0.53 0.80 0.47 0.62 0.41 75' )
370
368
report = classification_report_imbalanced (y_true , y_pred )
371
- assert_equal ( _format_report (report ), expected_report )
369
+ assert _format_report (report ) == expected_report
372
370
373
371
374
372
def test_classification_report_imbalanced_multiclass_with_string_label ():
@@ -382,15 +380,15 @@ def test_classification_report_imbalanced_multiclass_with_string_label():
382
380
'0.19 31 red 0.42 0.90 0.55 0.57 0.63 0.37 20 '
383
381
'avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75' )
384
382
report = classification_report_imbalanced (y_true , y_pred )
385
- assert_equal ( _format_report (report ), expected_report )
383
+ assert _format_report (report ) == expected_report
386
384
387
385
expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 '
388
386
'0.86 0.74 24 b 0.33 0.10 0.86 0.15 0.44 0.19 31 '
389
387
'c 0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total '
390
388
'0.51 0.53 0.80 0.47 0.62 0.41 75' )
391
389
report = classification_report_imbalanced (
392
390
y_true , y_pred , target_names = ["a" , "b" , "c" ])
393
- assert_equal ( _format_report (report ), expected_report )
391
+ assert _format_report (report ) == expected_report
394
392
395
393
396
394
def test_classification_report_imbalanced_multiclass_with_unicode_label ():
@@ -411,7 +409,7 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
411
409
classification_report_imbalanced , y_true , y_pred )
412
410
else :
413
411
report = classification_report_imbalanced (y_true , y_pred )
414
- assert_equal ( _format_report (report ), expected_report )
412
+ assert _format_report (report ) == expected_report
415
413
416
414
417
415
def test_classification_report_imbalanced_multiclass_with_long_string_label ():
@@ -427,7 +425,7 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
427
425
'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75' )
428
426
429
427
report = classification_report_imbalanced (y_true , y_pred )
430
- assert_equal ( _format_report (report ), expected_report )
428
+ assert _format_report (report ) == expected_report
431
429
432
430
433
431
def test_iba_sklearn_metrics ():
@@ -436,22 +434,22 @@ def test_iba_sklearn_metrics():
436
434
acc = make_index_balanced_accuracy (alpha = 0.5 , squared = True )(
437
435
accuracy_score )
438
436
score = acc (y_true , y_pred )
439
- assert_equal ( score , 0.54756 )
437
+ assert score == 0.54756
440
438
441
439
jss = make_index_balanced_accuracy (alpha = 0.5 , squared = True )(
442
440
jaccard_similarity_score )
443
441
score = jss (y_true , y_pred )
444
- assert_equal ( score , 0.54756 )
442
+ assert score == 0.54756
445
443
446
444
pre = make_index_balanced_accuracy (alpha = 0.5 , squared = True )(
447
445
precision_score )
448
446
score = pre (y_true , y_pred )
449
- assert_equal ( score , 0.65025 )
447
+ assert score == 0.65025
450
448
451
449
rec = make_index_balanced_accuracy (alpha = 0.5 , squared = True )(
452
450
recall_score )
453
451
score = rec (y_true , y_pred )
454
- assert_equal ( score , 0.41616000000000009 )
452
+ assert score == 0.41616000000000009
455
453
456
454
457
455
def test_iba_error_y_score_prob ():
0 commit comments