83 changes: 67 additions & 16 deletions bigframes/ml/metrics/_metrics.py
@@ -15,9 +15,11 @@
"""Metrics functions for evaluating models. This module is styled after
scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html."""

from __future__ import annotations

import inspect
import typing
from typing import Tuple, Union
from typing import Literal, overload, Tuple, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification
@@ -259,31 +261,64 @@ def recall_score(
recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score)


@overload
def precision_score(
y_true: Union[bpd.DataFrame, bpd.Series],
y_pred: Union[bpd.DataFrame, bpd.Series],
y_true: bpd.DataFrame | bpd.Series,
y_pred: bpd.DataFrame | bpd.Series,
*,
average: typing.Optional[str] = "binary",
pos_label: int | float | bool | str = ...,
average: Literal["binary"] = ...,
) -> float:
...


@overload
def precision_score(
y_true: bpd.DataFrame | bpd.Series,
y_pred: bpd.DataFrame | bpd.Series,
*,
pos_label: int | float | bool | str = ...,
average: None = ...,
) -> pd.Series:
# TODO(ashleyxu): support more average type, default to "binary"
if average is not None:
raise NotImplementedError(
f"Only average=None is supported. {constants.FEEDBACK_LINK}"
)
...


def precision_score(
y_true: bpd.DataFrame | bpd.Series,
y_pred: bpd.DataFrame | bpd.Series,
*,
pos_label: int | float | bool | str = 1,
average: Literal["binary"] | None = "binary",
) -> pd.Series | float:
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)

is_accurate = y_true_series == y_pred_series
if average is None:
return _precision_score_per_label(y_true_series, y_pred_series)

if average == "binary":
return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)

raise NotImplementedError(
f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}"
)


precision_score.__doc__ = inspect.getdoc(
vendored_metrics_classification.precision_score
)
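
For orientation, a minimal usage sketch of the updated API (not part of the diff; it assumes an active bigframes `session` and mirrors the data used in the new tests below):

    import pandas as pd
    from bigframes.ml import metrics

    df = session.read_pandas(
        pd.DataFrame({"y_true": [1, 1, 1, 0, 0], "y_pred": [0, 0, 1, 1, 1]})
    )

    # average=None keeps the existing behavior: a pandas Series of per-label precision.
    per_label = metrics.precision_score(df["y_true"], df["y_pred"], average=None)

    # average="binary" (new) returns a single float for the positive class (pos_label=1 by default).
    binary = metrics.precision_score(df["y_true"], df["y_pred"], average="binary")
    assert abs(binary - 1 / 3) < 1e-9  # one of the three positive predictions is correct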


# Per-label precision: for each label, the fraction of predictions of that label
# that are correct (labels that are never predicted get 0).
def _precision_score_per_label(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
is_accurate = y_true == y_pred
unique_labels = (
bpd.concat([y_true_series, y_pred_series], join="outer")
bpd.concat([y_true, y_pred], join="outer")
.drop_duplicates()
.sort_values(inplace=False)
)
index = unique_labels.to_list()

precision = (
is_accurate.groupby(y_pred_series).sum()
/ is_accurate.groupby(y_pred_series).count()
is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
).to_pandas()

precision_score = pd.Series(0, index=index)
@@ -293,9 +328,25 @@ def precision_score(
return precision_score


precision_score.__doc__ = inspect.getdoc(
vendored_metrics_classification.precision_score
)
# Binary-average precision: precision of the single positive class given by `pos_label`.
def _precision_score_binary_pos_only(
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
) -> float:
unique_labels = bpd.concat([y_true, y_pred]).unique(keep_order=False)

if unique_labels.count() != 2:
raise ValueError(
"Target is multiclass but average='binary'. Please choose another average setting."
)

if not (unique_labels == pos_label).any():
raise ValueError(
f"pos_labe={pos_label} is not a valid label. It should be one of {unique_labels.to_list()}"
)

# Restrict to rows predicted as the positive class.
target_elem_idx = y_pred == pos_label
# Precision = TP / (TP + FP): among those rows, the share where the prediction matches y_true.
is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]

return is_accurate.sum() / is_accurate.count()
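
The helper above is plain binary precision, TP / (TP + FP), computed only over the rows predicted as `pos_label`. A pandas-only sketch of the same arithmetic on the data from the binary test below (pandas is used here purely for illustration; the helper itself operates on bigframes Series):

    import pandas as pd

    y_true = pd.Series(["a", "a", "a", "b", "b"])
    y_pred = pd.Series(["b", "b", "a", "a", "a"])
    pos_label = "a"

    predicted_pos = y_pred == pos_label                    # rows predicted as the positive class
    true_pos = (y_true[predicted_pos] == pos_label).sum()  # of those, how many are actually positive
    precision = true_pos / predicted_pos.sum()             # TP / (TP + FP)
    assert abs(precision - 1 / 3) < 1e-9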


def f1_score(
65 changes: 65 additions & 0 deletions tests/system/small/ml/test_metrics.py
@@ -743,6 +743,71 @@ def test_precision_score_series(session):
)


@pytest.mark.parametrize(
("pos_label", "expected_score"),
[
("a", 1 / 3),
("b", 0),
],
)
def test_precision_score_binary(session, pos_label, expected_score):
pd_df = pd.DataFrame(
{
"y_true": ["a", "a", "a", "b", "b"],
"y_pred": ["b", "b", "a", "a", "a"],
}
)
df = session.read_pandas(pd_df)

precision_score = metrics.precision_score(
df["y_true"], df["y_pred"], average="binary", pos_label=pos_label
)

assert precision_score == pytest.approx(expected_score)


def test_precision_score_binary_default_arguments(session):
pd_df = pd.DataFrame(
{
"y_true": [1, 1, 1, 0, 0],
"y_pred": [0, 0, 1, 1, 1],
}
)
df = session.read_pandas(pd_df)

precision_score = metrics.precision_score(df["y_true"], df["y_pred"])

assert precision_score == pytest.approx(1 / 3)


@pytest.mark.parametrize(
("y_true", "y_pred", "pos_label"),
[
pytest.param(
pd.Series([1, 2, 3]), pd.Series([1, 0]), 1, id="y_true-non-binary-label"
),
pytest.param(
pd.Series([1, 0]), pd.Series([1, 2, 3]), 1, id="y_pred-non-binary-label"
),
pytest.param(
pd.Series([1, 0]), pd.Series([1, 2]), 1, id="combined-non-binary-label"
),
pytest.param(pd.Series([1, 0]), pd.Series([1, 0]), 2, id="invalid-pos_label"),
],
)
def test_precision_score_binary_invalid_input_raise_error(
session, y_true, y_pred, pos_label
):

bf_y_true = session.read_pandas(y_true)
bf_y_pred = session.read_pandas(y_pred)

with pytest.raises(ValueError):
metrics.precision_score(
bf_y_true, bf_y_pred, average="binary", pos_label=pos_label
)


def test_f1_score(session):
pd_df = pd.DataFrame(
{
third_party/bigframes_vendored/sklearn/metrics/_classification.py
@@ -201,7 +201,7 @@ def precision_score(
default='binary'
This parameter is required for multiclass/multilabel targets.
Possible values are 'None', 'micro', 'macro', 'samples', 'weighted', 'binary'.
Only average=None is supported.
Only None and 'binary' are supported.

Returns:
precision: float (if average is not None) or Series of float of shape \