This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 68
feat: support average='binary' in precision_score() #2080
Merged
Merged
Changes from 1 commit
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
7d3492f
feat: support 'binary' for precision_score
sycai ab90b72
add test
sycai 2904b9d
Merge branch 'main' into sycai_precision_score_binary
sycai 95a005c
Merge branch 'main' into sycai_precision_score_binary
sycai 06392d2
use unique(keep_order=False) to count unique items
sycai 8d5d573
use local variables to hold unique classes
sycai a9943bd
Merge branch 'main' into sycai_precision_score_binary
sycai 96758ff
use concat before checking unique labels
sycai e1c032b
fix test
sycai 58adcba
Merge branch 'main' into sycai_precision_score_binary
sycai 8633dea
Merge branch 'main' into sycai_precision_score_binary
sycai 2ec095f
Merge branch 'main' into sycai_precision_score_binary
sycai 741a198
Merge branch 'main' into sycai_precision_score_binary
sycai bd41f8a
Merge branch 'main' into sycai_precision_score_binary
sycai e99b59a
Merge branch 'main' into sycai_precision_score_binary
sycai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,9 +15,11 @@ | |
| """Metrics functions for evaluating models. This module is styled after | ||
| scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import inspect | ||
| import typing | ||
| from typing import Tuple, Union | ||
| from typing import Literal, overload, Tuple, Union | ||
|
|
||
| import bigframes_vendored.constants as constants | ||
| import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification | ||
|
|
@@ -259,31 +261,64 @@ def recall_score( | |
| recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score) | ||
|
|
||
|
|
||
| @overload | ||
| def precision_score( | ||
| y_true: Union[bpd.DataFrame, bpd.Series], | ||
| y_pred: Union[bpd.DataFrame, bpd.Series], | ||
| y_true: bpd.DataFrame | bpd.Series, | ||
| y_pred: bpd.DataFrame | bpd.Series, | ||
| *, | ||
| average: typing.Optional[str] = "binary", | ||
| pos_label: int | float | bool | str = ..., | ||
| average: Literal["binary"] = ..., | ||
| ) -> float: | ||
| ... | ||
|
|
||
|
|
||
| @overload | ||
| def precision_score( | ||
| y_true: bpd.DataFrame | bpd.Series, | ||
| y_pred: bpd.DataFrame | bpd.Series, | ||
| *, | ||
| pos_label: int | float | bool | str = ..., | ||
| average: None = ..., | ||
| ) -> pd.Series: | ||
| # TODO(ashleyxu): support more average type, default to "binary" | ||
| if average is not None: | ||
| raise NotImplementedError( | ||
| f"Only average=None is supported. {constants.FEEDBACK_LINK}" | ||
| ) | ||
| ... | ||
|
|
||
|
|
||
| def precision_score( | ||
| y_true: bpd.DataFrame | bpd.Series, | ||
| y_pred: bpd.DataFrame | bpd.Series, | ||
| *, | ||
| pos_label: int | float | bool | str = 1, | ||
| average: Literal["binary"] | None = "binary", | ||
| ) -> pd.Series | float: | ||
| y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred) | ||
|
|
||
| is_accurate = y_true_series == y_pred_series | ||
| if average is None: | ||
| return _precision_score_per_class(y_true_series, y_pred_series) | ||
|
|
||
| if average == "binary": | ||
| return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label) | ||
|
|
||
| raise NotImplementedError( | ||
| f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}" | ||
| ) | ||
|
|
||
|
|
||
| precision_score.__doc__ = inspect.getdoc( | ||
| vendored_metrics_classification.precision_score | ||
| ) | ||
|
|
||
|
|
||
| def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series: | ||
| is_accurate = y_true == y_pred | ||
| unique_labels = ( | ||
| bpd.concat([y_true_series, y_pred_series], join="outer") | ||
| bpd.concat([y_true, y_pred], join="outer") | ||
| .drop_duplicates() | ||
| .sort_values(inplace=False) | ||
| ) | ||
| index = unique_labels.to_list() | ||
|
|
||
| precision = ( | ||
| is_accurate.groupby(y_pred_series).sum() | ||
| / is_accurate.groupby(y_pred_series).count() | ||
| is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count() | ||
| ).to_pandas() | ||
|
|
||
| precision_score = pd.Series(0, index=index) | ||
|
|
@@ -293,9 +328,32 @@ def precision_score( | |
| return precision_score | ||
|
|
||
|
|
||
| precision_score.__doc__ = inspect.getdoc( | ||
| vendored_metrics_classification.precision_score | ||
| ) | ||
| def _precision_score_binary_pos_only( | ||
| y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str | ||
| ) -> float: | ||
| if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2: | ||
| raise ValueError( | ||
| "Target is multiclass but average='binary'. Please choose another average setting." | ||
| ) | ||
|
|
||
| total_labels = set( | ||
| y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list() | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should probably avoid
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code updated. This is the execution output: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB. It's weird that no query job links are provided. |
||
|
|
||
| if len(total_labels) != 2: | ||
| raise ValueError( | ||
| "Target is multiclass but average='binary'. Please choose another average setting." | ||
| ) | ||
|
|
||
| if pos_label not in total_labels: | ||
| raise ValueError( | ||
| f"pos_labe={pos_label} is not a valid label. It should be one of {list(total_labels)}" | ||
| ) | ||
|
|
||
| target_elem_idx = y_pred == pos_label | ||
| is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx] | ||
|
|
||
| return is_accurate.sum() / is_accurate.count() | ||
|
|
||
|
|
||
| def f1_score( | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may create extra queries with y_true.drop_duplicates().to_list() in line 340. We may want to merge them.
Can you take a look at how many queries are created when running this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the result: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB. it feels weird because no query jobs are printed out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Local execution? @TrevorBergeron