Skip to content
Merged
8 changes: 8 additions & 0 deletions thinc/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def test_to_categorical(label_smoothing):
):
to_categorical(numpy.asarray([0, 0, 0]), label_smoothing=0.01),

error_msg = ("For 5 number of classes "
"label_smoothing parameter has to be less than "
"0.8, but found 0.8.")
with pytest.raises(
ValueError, match=error_msg
):
to_categorical(numpy.asarray([0, 1, 2, 3, 4]), label_smoothing=0.8)


def test_convert_recursive():
is_match = lambda obj: obj == "foo"
Expand Down
18 changes: 12 additions & 6 deletions thinc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,15 @@ def to_categorical(
*,
label_smoothing: float = 0.0,
) -> FloatsXd:
if not 0.0 <= label_smoothing < 0.5:
raise ValueError(
"label_smoothing should be greater or "
"equal to 0.0 and less than 0.5, "
f"but {label_smoothing} was provided."
)

if n_classes is None:
n_classes = int(numpy.max(Y) + 1) # type: ignore

if label_smoothing < 0.0:
raise ValueError(
"Label-smoothing parameter has to be greater than or equal to 0"
)

if label_smoothing == 0.0:
if n_classes == 0:
raise ValueError("n_classes should be at least 1")
Expand All @@ -234,6 +233,13 @@ def to_categorical(
)
nongold_prob = label_smoothing / (n_classes - 1)

if (1 - label_smoothing) < nongold_prob:
raise ValueError(
f"For {n_classes} number of classes "
"label_smoothing parameter has to be less than "
f"{1 - nongold_prob}, but found {label_smoothing}."
)

xp = get_array_module(Y)
label_distr = xp.full((n_classes, n_classes), nongold_prob, dtype="float32")
xp.fill_diagonal(label_distr, 1 - label_smoothing)
Expand Down