feat: support average='binary' in precision_score() #2080
Diff:

@@ -15,9 +15,11 @@
 """Metrics functions for evaluating models. This module is styled after
 scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html."""

 from __future__ import annotations

 import inspect
 import typing
-from typing import Tuple, Union
+from typing import Literal, overload, Tuple, Union

 import bigframes_vendored.constants as constants
 import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification
@@ -259,31 +261,64 @@ def recall_score(
 recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score)


+@overload
 def precision_score(
-    y_true: Union[bpd.DataFrame, bpd.Series],
-    y_pred: Union[bpd.DataFrame, bpd.Series],
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
     *,
-    average: typing.Optional[str] = "binary",
+    pos_label: int | float | bool | str = ...,
+    average: Literal["binary"] = ...,
+) -> float:
+    ...
+
+
+@overload
+def precision_score(
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
+    *,
+    pos_label: int | float | bool | str = ...,
+    average: None = ...,
 ) -> pd.Series:
-    # TODO(ashleyxu): support more average type, default to "binary"
-    if average is not None:
-        raise NotImplementedError(
-            f"Only average=None is supported. {constants.FEEDBACK_LINK}"
-        )
-
+    ...
+
+
+def precision_score(
+    y_true: bpd.DataFrame | bpd.Series,
+    y_pred: bpd.DataFrame | bpd.Series,
+    *,
+    pos_label: int | float | bool | str = 1,
+    average: Literal["binary"] | None = "binary",
+) -> pd.Series | float:
     y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)

-    is_accurate = y_true_series == y_pred_series
+    if average is None:
+        return _precision_score_per_class(y_true_series, y_pred_series)
+
+    if average == "binary":
+        return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)
+
+    raise NotImplementedError(
+        f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}"
+    )
+
+
+precision_score.__doc__ = inspect.getdoc(
+    vendored_metrics_classification.precision_score
+)
+
+
+def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
+    is_accurate = y_true == y_pred
     unique_labels = (
-        bpd.concat([y_true_series, y_pred_series], join="outer")
+        bpd.concat([y_true, y_pred], join="outer")
         .drop_duplicates()
         .sort_values(inplace=False)
     )
     index = unique_labels.to_list()

     precision = (
-        is_accurate.groupby(y_pred_series).sum()
-        / is_accurate.groupby(y_pred_series).count()
+        is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
     ).to_pandas()

     precision_score = pd.Series(0, index=index)
@@ -293,9 +328,32 @@ def precision_score(
     return precision_score


-precision_score.__doc__ = inspect.getdoc(
-    vendored_metrics_classification.precision_score
-)
+def _precision_score_binary_pos_only(
+    y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
+) -> float:
+    if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
+        raise ValueError(
+            "Target is multiclass but average='binary'. Please choose another average setting."
+        )
+
+    total_labels = set(
+        y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
+    )
+
+    if len(total_labels) != 2:
+        raise ValueError(
+            "Target is multiclass but average='binary'. Please choose another average setting."
+        )
+
+    if pos_label not in total_labels:
+        raise ValueError(
+            f"pos_label={pos_label} is not a valid label. It should be one of {list(total_labels)}"
+        )
+
+    target_elem_idx = y_pred == pos_label
+    is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]
+
+    return is_accurate.sum() / is_accurate.count()


 def f1_score(
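For context, here is a minimal usage sketch of the behavior this PR adds. The import path `bigframes.ml.metrics` and the sample data are illustrative assumptions rather than part of the diff, and running it requires an authenticated BigQuery session.

```python
import bigframes.pandas as bpd
from bigframes.ml import metrics  # assumed public entry point; not shown in this diff

# Illustrative two-class labels (any label pair with a valid pos_label works).
y_true = bpd.Series([1, 0, 1, 1, 0, 1])
y_pred = bpd.Series([1, 0, 0, 1, 1, 1])

# New default: average="binary" scores only the class given by pos_label
# and returns a single float.
binary_precision = metrics.precision_score(
    y_true, y_pred, average="binary", pos_label=1
)

# Previous (still supported) behavior: average=None returns per-class
# precision as a pandas Series indexed by label.
per_class_precision = metrics.precision_score(y_true, y_pred, average=None)
```

With the overloads above, a type checker should report `float` for the first call and `pandas.Series` for the second.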
Review comment:
This may create extra queries with y_true.drop_duplicates().to_list() in line 340. We may want to merge them.
Can you take a look at how many queries are created when running this function?
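One possible way to reduce the round trips, sketched below, is to materialize the distinct labels of each series once with `to_list()` and run every validation on the resulting in-memory lists, instead of issuing separate `drop_duplicates().count()` and `drop_duplicates().to_list()` calls. This is only an illustration of the reviewer's suggestion, not the author's code; it reuses only operations already present in the diff.

```python
from __future__ import annotations

import bigframes.pandas as bpd


def _precision_score_binary_pos_only(
    y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
) -> float:
    # Pull the distinct labels client-side once per series; the validations
    # below then run on plain Python lists and sets.
    true_labels = y_true.drop_duplicates().to_list()
    pred_labels = y_pred.drop_duplicates().to_list()
    total_labels = set(true_labels) | set(pred_labels)

    if len(true_labels) != 2 or len(pred_labels) != 2 or len(total_labels) != 2:
        raise ValueError(
            "Target is multiclass but average='binary'. Please choose another average setting."
        )
    if pos_label not in total_labels:
        raise ValueError(
            f"pos_label={pos_label} is not a valid label. It should be one of {list(total_labels)}"
        )

    # Precision for the positive class: among rows predicted as pos_label,
    # the fraction whose true label matches.
    target_elem_idx = y_pred == pos_label
    is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]
    return is_accurate.sum() / is_accurate.count()
```

Whether this actually saves queries depends on how bigframes defers and caches these expressions, which is what the question about the query count is getting at.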
Reply:
This is the result: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB. It feels weird because no query jobs are printed out.
Reply:
Local execution? @TrevorBergeron