Commit 8237f2c

Fix_248 (#263)
1 parent 2a9f04e commit 8237f2c

2 files changed: +79 -8 lines changed

autoPyTorch/utils/implementations.py

Lines changed: 5 additions & 7 deletions
@@ -25,17 +25,17 @@ class LossWeightStrategyWeighted():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         if isinstance(y, torch.Tensor):
             y = y.detach().cpu().numpy() if y.is_cuda else y.numpy()
-        if isinstance(y[0], str):
-            y = y.astype('float64')
         counts = np.sum(y, axis=0)
         total_weight = y.shape[0]

-        if len(y.shape) > 1:
+        if len(y.shape) > 1 and y.shape[1] != 1:
+            # In this case, the second axis represents classes
             weight_per_class = total_weight / y.shape[1]
             weights = (np.ones(y.shape[1]) * weight_per_class) / np.maximum(counts, 1)
         else:
+            # Numpy unique returns the sorted classes. This is desirable, as the
+            # weights received by PyTorch are expected as a sorted list of classes
             classes, counts = np.unique(y, axis=0, return_counts=True)
-            classes, counts = classes[::-1], counts[::-1]
             weight_per_class = total_weight / classes.shape[0]
             weights = (np.ones(classes.shape[0]) * weight_per_class) / counts

@@ -50,10 +50,8 @@ class LossWeightStrategyWeightedBinary():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         if isinstance(y, torch.Tensor):
             y = y.detach().cpu().numpy() if y.is_cuda else y.numpy()
-        if isinstance(y[0], str):
-            y = y.astype('float64')
         counts_one = np.sum(y, axis=0)
-        counts_zero = counts_one + (-y.shape[0])
+        counts_zero = y.shape[0] - counts_one
         weights = counts_zero / np.maximum(counts_one, 1)

         return np.array(weights)
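
Both strategies boil down to inverse-frequency arithmetic: the multiclass strategy gives each class total_weight / n_classes divided by its count, and the binary strategy returns the negative-to-positive ratio per output column. The following numpy-only sketch is not part of the commit; it just reproduces two of the expected values from the new tests using the formulas in the hunks above.

import numpy as np

# Multiclass, one-hot target: class 0 appears twice, classes 2 and 3 once, class 1 never.
y = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
counts = np.sum(y, axis=0)                            # [2, 0, 1, 1]
weight_per_class = y.shape[0] / y.shape[1]            # 4 / 4 = 1.0
weights = np.ones(y.shape[1]) * weight_per_class / np.maximum(counts, 1)
print(weights)                                        # [0.5 1.  1.  1. ] -- matches the first test case

# Binary, multilabel target: column 0 has 3 positives and 1 negative, column 1 the reverse.
y_bin = np.array([[1, 0], [1, 0], [1, 0], [0, 1]])
counts_one = np.sum(y_bin, axis=0)                    # [3, 1]
counts_zero = y_bin.shape[0] - counts_one             # [1, 3]
pos_weight = counts_zero / np.maximum(counts_one, 1)
print(pos_weight)                                     # [0.333..., 3.] -- matches the first binary test case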

test/test_pipeline/test_losses.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
+import numpy as np
+
 import pytest

 import torch
 from torch import nn
 from torch.nn.modules.loss import _Loss as Loss

 from autoPyTorch.pipeline.components.training.losses import get_loss, losses
-from autoPyTorch.utils.implementations import get_loss_weight_strategy
+from autoPyTorch.utils.implementations import (
+    LossWeightStrategyWeighted,
+    LossWeightStrategyWeightedBinary,
+    get_loss_weight_strategy,
+)


 @pytest.mark.parametrize('output_type', ['multiclass',
@@ -66,3 +72,70 @@ def test_loss_dict():
     assert isinstance(loss['module'](), Loss)
     assert 'supported_output_types' in loss.keys()
     assert isinstance(loss['supported_output_types'], list)
+
+
+@pytest.mark.parametrize('target,expected_weights', [
+    (
+        # Expected 4 classes where the first one is the majority
+        np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
+        # We reduce the contribution of the first class, which has twice the elements
+        np.array([0.5, 1., 1., 1.]),
+    ),
+    (
+        # Expected 2 classes -- multilabel format
+        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
+        # We reduce the contribution of the first class, which has a 3 to 1 ratio
+        np.array([2 / 3, 2]),
+    ),
+    (
+        # Expected 2 classes -- (-1, 1) format
+        np.array([[1], [1], [1], [0]]),
+        # We reduce the contribution of the second class, which has a 3 to 1 ratio
+        np.array([2, 2 / 3]),
+    ),
+    (
+        # Expected 2 classes -- single column
+        # We have to reduce the contribution of the second class, which has a 5 to 1 ratio
+        np.array([1, 1, 1, 1, 1, 0]),
+        # The majority class (1) is down-weighted, the minority class (0) is up-weighted
+        np.array([3, 6 / 10]),
+    ),
+])
+def test_lossweightstrategyweighted(target, expected_weights):
+    weights = LossWeightStrategyWeighted()(target)
+    np.testing.assert_array_equal(weights, expected_weights)
+    assert nn.CrossEntropyLoss(weight=torch.Tensor(weights))(
+        torch.zeros(target.shape[0], len(weights)).float(),
+        torch.from_numpy(target.argmax(1)).long() if len(target.shape) > 1
+        else torch.from_numpy(target).long()
+    ) > 0
+
+
+@pytest.mark.parametrize('target,expected_weights', [
+    (
+        # Expected 2 classes -- multilabel format
+        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
+        # We reduce the contribution of the first class, which has a 3 to 1 ratio
+        np.array([1 / 3, 3]),
+    ),
+    (
+        # Expected 2 classes -- (-1, 1) format
+        np.array([[1], [1], [1], [0]]),
+        # We reduce the contribution of the second class, which has a 3 to 1 ratio
+        np.array([1 / 3]),
+    ),
+    (
+        # Expected 2 classes -- single column
+        # We have to reduce the contribution of the second class, which has a 5 to 1 ratio
+        np.array([1, 1, 1, 1, 1, 0]),
+        # pos_weight is counts_zero / counts_one = 1 / 5
+        np.array([0.2]),
+    ),
+])
+def test_lossweightstrategyweightedbinary(target, expected_weights):
+    weights = LossWeightStrategyWeightedBinary()(target)
+    np.testing.assert_array_equal(weights, expected_weights)
+    assert nn.BCEWithLogitsLoss(pos_weight=torch.Tensor(weights))(
+        torch.from_numpy(target).float(),
+        torch.from_numpy(target).float(),
+    ) > 0
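
The new tests exercise the strategies end to end: the multiclass weights feed nn.CrossEntropyLoss(weight=...) and the binary weights feed nn.BCEWithLogitsLoss(pos_weight=...). The standalone sketch below mirrors the test bodies above for one multilabel target; it is not part of the commit and only uses the classes and call pattern shown in this diff.

import numpy as np
import torch
from torch import nn

from autoPyTorch.utils.implementations import (
    LossWeightStrategyWeighted,
    LossWeightStrategyWeightedBinary,
)

target = np.array([[1, 0], [1, 0], [1, 0], [0, 1]])

# Multiclass path: per-class weights are passed to CrossEntropyLoss.
weights = LossWeightStrategyWeighted()(target)                  # [2/3, 2]
ce = nn.CrossEntropyLoss(weight=torch.Tensor(weights))
print(ce(torch.zeros(4, 2), torch.from_numpy(target.argmax(1)).long()))

# Binary path: the negative/positive ratio is passed as pos_weight.
pos_weight = LossWeightStrategyWeightedBinary()(target)         # [1/3, 3]
bce = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor(pos_weight))
print(bce(torch.from_numpy(target).float(), torch.from_numpy(target).float()))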
