Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 74 additions & 16 deletions bigframes/ml/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""Metrics functions for evaluating models. This module is styled after
scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html."""

from __future__ import annotations

import inspect
import typing
from typing import Tuple, Union
from typing import Literal, overload, Tuple, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification
Expand Down Expand Up @@ -259,31 +261,64 @@ def recall_score(
recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score)


@overload
def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = ...,
    average: Literal["binary"] = ...,
) -> float:
    # Overload: with average="binary" a single float precision is returned
    # for the class identified by pos_label.
    ...


@overload
def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = ...,
    average: None = ...,
) -> pd.Series:
    # Overload: with average=None a per-class precision Series is returned;
    # pos_label is ignored in this mode.
    ...


def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = 1,
    average: Literal["binary"] | None = "binary",
) -> pd.Series | float:
    # Normalize DataFrame/Series inputs to a pair of Series before dispatch.
    y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)

    # average=None: precision computed independently for every label.
    if average is None:
        return _precision_score_per_class(y_true_series, y_pred_series)

    # average="binary": precision for the single positive class only.
    if average == "binary":
        return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)

    # Other sklearn averaging modes ("micro", "macro", ...) are not implemented.
    raise NotImplementedError(
        f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}"
    )


# Reuse the vendored scikit-learn docstring so the public API matches sklearn.
precision_score.__doc__ = inspect.getdoc(
    vendored_metrics_classification.precision_score
)


def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
is_accurate = y_true == y_pred
unique_labels = (
bpd.concat([y_true_series, y_pred_series], join="outer")
bpd.concat([y_true, y_pred], join="outer")
.drop_duplicates()
.sort_values(inplace=False)
)
index = unique_labels.to_list()

precision = (
is_accurate.groupby(y_pred_series).sum()
/ is_accurate.groupby(y_pred_series).count()
is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
).to_pandas()

precision_score = pd.Series(0, index=index)
Expand All @@ -293,9 +328,32 @@ def precision_score(
return precision_score


precision_score.__doc__ = inspect.getdoc(
vendored_metrics_classification.precision_score
)
def _precision_score_binary_pos_only(
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
) -> float:
if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may create extra queries with y_true.drop_duplicates().to_list() in line 340. We may want to merge them.

Can you take a look at how many queries are created when running this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the result: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB. it feels weird because no query jobs are printed out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Local execution? @TrevorBergeron

raise ValueError(
"Target is multiclass but average='binary'. Please choose another average setting."
)

total_labels = set(
y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably avoid drop_duplicates, it has overhead from trying to preserve ordering, try unique(keep_order=False) instead. Also try to minimize query count

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code updated. This is the execution output: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB.

It's weird that no query job links are provided.


if len(total_labels) != 2:
raise ValueError(
"Target is multiclass but average='binary'. Please choose another average setting."
)

if pos_label not in total_labels:
raise ValueError(
f"pos_labe={pos_label} is not a valid label. It should be one of {list(total_labels)}"
)

target_elem_idx = y_pred == pos_label
is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]

return is_accurate.sum() / is_accurate.count()


def f1_score(
Expand Down
65 changes: 65 additions & 0 deletions tests/system/small/ml/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,71 @@ def test_precision_score_series(session):
)


@pytest.mark.parametrize(
    ("pos_label", "expected_score"),
    [
        ("a", 1 / 3),
        ("b", 0),
    ],
)
def test_precision_score_binary(session, pos_label, expected_score):
    # Binary-average precision should score only the requested positive label.
    local_df = pd.DataFrame(
        {
            "y_true": ["a", "a", "a", "b", "b"],
            "y_pred": ["b", "b", "a", "a", "a"],
        }
    )
    bf_df = session.read_pandas(local_df)

    score = metrics.precision_score(
        bf_df["y_true"], bf_df["y_pred"], average="binary", pos_label=pos_label
    )

    assert score == pytest.approx(expected_score)


def test_precision_score_binary_default_arguments(session):
    # Defaults are average="binary" with pos_label=1.
    local_df = pd.DataFrame(
        {
            "y_true": [1, 1, 1, 0, 0],
            "y_pred": [0, 0, 1, 1, 1],
        }
    )
    bf_df = session.read_pandas(local_df)

    score = metrics.precision_score(bf_df["y_true"], bf_df["y_pred"])

    assert score == pytest.approx(1 / 3)


@pytest.mark.parametrize(
    ("y_true", "y_pred", "pos_label"),
    [
        pytest.param(
            pd.Series([1, 2, 3]), pd.Series([1, 0]), 1, id="y_true-non-binary-label"
        ),
        pytest.param(
            pd.Series([1, 0]), pd.Series([1, 2, 3]), 1, id="y_pred-non-binary-label"
        ),
        pytest.param(
            pd.Series([1, 0]), pd.Series([1, 2]), 1, id="combined-non-binary-label"
        ),
        pytest.param(pd.Series([1, 0]), pd.Series([1, 0]), 2, id="invalid-pos_label"),
    ],
)
def test_precision_score_binary_invalid_input_raise_error(
    session, y_true, y_pred, pos_label
):
    # Non-binary targets and unknown positive labels must be rejected.
    bf_y_true = session.read_pandas(y_true)
    bf_y_pred = session.read_pandas(y_pred)

    with pytest.raises(ValueError):
        metrics.precision_score(
            bf_y_true, bf_y_pred, average="binary", pos_label=pos_label
        )


def test_f1_score(session):
pd_df = pd.DataFrame(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def precision_score(
default='binary'
This parameter is required for multiclass/multilabel targets.
Possible values are 'None', 'micro', 'macro', 'samples', 'weighted', 'binary'.
Only average=None is supported.
Only None and 'binary' are supported.

Returns:
precision: float (if average is not None) or Series of float of shape \
Expand Down