11from collections .abc import Callable , Sequence
22from dataclasses import dataclass
3- from enum import Enum , Flag , auto
4- from typing import Protocol , override
3+ from enum import Flag , auto
4+ from typing import Literal , Protocol , override
55
66import numpy as np
77from numpy .typing import NDArray
@@ -57,11 +57,7 @@ def __call__(
5757 ) -> Float : ...
5858
5959
60- class LabelType (Enum ):
61- """The variable that is compared to the predictions in order to check how similar they are."""
62-
63- S = "s"
64- Y = "y"
60+ type LabelType = Literal ["group" , "y" ]
6561
6662
6763@dataclass
@@ -72,12 +68,12 @@ class RenyiCorrelation(GroupMetric):
7268 titled "On Measures of Dependence" by Alfréd Rényi.
7369 """
7470
75- base : LabelType = LabelType . S
71+ base : LabelType = "group"
7672
7773 @property
7874 def __name__ (self ) -> str :
7975 """The name of the metric."""
80- return f"renyi_{ self .base . value } "
76+ return f"renyi_{ self .base } "
8177
8278 @override
8379 def __call__ (
@@ -287,7 +283,7 @@ def as_group_metric(
287283 """Turn a sequence of metrics into a list of group metrics."""
288284 metrics = []
289285 for metric in base_metrics :
290- if agg & MetricAgg .DIFF :
286+ if MetricAgg .DIFF in agg :
291287 metrics .append (
292288 _BinaryAggMetric (
293289 metric = metric ,
@@ -296,7 +292,7 @@ def as_group_metric(
296292 aggregator = lambda i , j : j - i ,
297293 )
298294 )
299- if agg & MetricAgg .RATIO :
295+ if MetricAgg .RATIO in agg :
300296 metrics .append (
301297 _BinaryAggMetric (
302298 metric = metric ,
@@ -305,7 +301,7 @@ def as_group_metric(
305301 aggregator = lambda i , j : i / j if j != 0 else np .float64 (np .nan ),
306302 )
307303 )
308- if agg & MetricAgg .MIN :
304+ if MetricAgg .MIN in agg :
309305 metrics .append (
310306 _MulticlassAggMetric (
311307 metric = metric ,
@@ -314,7 +310,7 @@ def as_group_metric(
314310 aggregator = np .min ,
315311 )
316312 )
317- if agg & MetricAgg .MAX :
313+ if MetricAgg .MAX in agg :
318314 metrics .append (
319315 _MulticlassAggMetric (
320316 metric = metric ,
@@ -323,7 +319,7 @@ def as_group_metric(
323319 aggregator = np .max ,
324320 )
325321 )
326- if agg & MetricAgg .INDIVIDUAL :
322+ if MetricAgg .INDIVIDUAL in agg :
327323 metrics .append (
328324 _BinaryAggMetric (
329325 metric = metric ,
0 commit comments