
Commit 790c70b

Replace enums with string literals
1 parent 36bed5b commit 790c70b

10 files changed (+48, -65 lines)
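The change applied across all ten files is the same: each small Enum is replaced by a PEP 695 `type` alias over `typing.Literal`, so callers pass plain strings while static type checkers still restrict the accepted values. A minimal, self-contained sketch of the pattern (the names here are illustrative, not part of fair_forge):

from typing import Literal

# PEP 695 type alias (Python 3.12+): type checkers narrow `kind` to these strings.
type Kind = Literal["basic", "proportional"]


def describe(kind: Kind) -> str:
    match kind:
        case "basic":
            return "basic split"
        case "proportional":
            return "proportional split"
        case _:
            # Unreachable for well-typed callers, but protects untyped call sites,
            # mirroring the defensive `case _` branches added in this commit.
            raise ValueError(f"Invalid kind: {kind}")


print(describe("basic"))   # OK
# describe("Basic")        # rejected by mypy/pyright; raises ValueError at runtime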

fair_forge/datasets.py

Lines changed: 7 additions & 8 deletions
@@ -1,6 +1,5 @@
-from enum import Enum
 from pathlib import Path
-from typing import NamedTuple, Protocol
+from typing import Literal, NamedTuple, Protocol
 
 import numpy as np
 from numpy.typing import NDArray
@@ -46,9 +45,7 @@ class GroupDataset(NamedTuple):
     feature_names: list[str]
 
 
-class AdultGroup(Enum):
-    SEX = "Sex"
-    RACE = "Race"
+type AdultGroup = Literal["Sex", "Race"]
 
 
 def load_adult(
@@ -66,7 +63,7 @@ def load_adult(
     Returns:
         A Dataset object containing the Adult dataset.
     """
-    name = f"Adult {group.value}"
+    name = f"Adult {group}"
     if binarize_nationality:
         name += ", binary nationality"
     if binarize_race:
@@ -116,16 +113,18 @@ def load_adult(
     groups: NDArray[np.int32]
     to_drop: str
     match group:
-        case AdultGroup.SEX:
+        case "Sex":
            groups = (
                df.get_column("sex").cat.starts_with("Male").cast(pl.Int32).to_numpy()
            )
            to_drop = "sex"
-        case AdultGroup.RACE:
+        case "Race":
            # `.to_physical()` converts the categorical column to its physical representation,
            # which is UInt32 by default in Polars.
            groups = df.get_column("race").to_physical().cast(pl.Int32).to_numpy()
            to_drop = "race"
+        case _:
+            raise ValueError(f"Invalid group: {group}")
     if not group_in_features:
         df = df.drop(to_drop)
         column_grouping_prefixes.remove(to_drop)
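With `AdultGroup` now a `Literal["Sex", "Race"]` alias, callers pass the string directly and the `.value` unwrapping in the dataset name disappears. A hedged usage sketch, reusing the `ff` import alias and the keyword arguments visible in tests/test_datasets.py below:

import fair_forge as ff

# "Sex" and "Race" are the only values the AdultGroup alias admits.
data = ff.load_adult(
    group="Race",
    group_in_features=False,
    binarize_nationality=False,
    binarize_race=False,
)

# A misspelling such as group="race" is flagged by the type checker and, via the
# new `case _` branch, raises ValueError("Invalid group: race") at runtime.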

fair_forge/eval.py

Lines changed: 10 additions & 11 deletions
@@ -1,6 +1,5 @@
 from collections.abc import Mapping, Sequence
-from enum import Enum
-from typing import Any, cast
+from typing import Any, Literal, cast
 
 import polars as pl
 
@@ -12,12 +11,7 @@
 
 __all__ = ["Split", "evaluate"]
 
-
-class Split(Enum):
-    """Enum for different split methods used in evaluation."""
-
-    BASIC = "basic"
-    PROPORTIONAL = "proportional"
+type Split = Literal["basic", "proportional"] | SplitMethod
 
 
 def evaluate(
@@ -28,7 +22,7 @@ def evaluate(
     *,
     preprocessor: Preprocessor | None = None,
     repeat: int = 1,
-    split: Split | SplitMethod = Split.PROPORTIONAL,
+    split: Split = "proportional",
     seed: int = 42,
     train_percentage: float = 0.8,
     remove_score_suffix: bool = True,
@@ -41,10 +35,15 @@
         split_seed = seed + repeat_index
         split_method: SplitMethod
         match split:
-            case Split.BASIC:
+            case "basic":
                split_method = basic_split
-            case Split.PROPORTIONAL:
+            case "proportional":
                split_method = proportional_split
+            case str() as split_method:
+                raise ValueError(
+                    f"Invalid split method: {split_method}. "
+                    "Use 'basic', 'proportional', or a custom SplitMethod instance."
+                )
            case _:
                split_method = split
        train_idx, test_idx = split_method(
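`Split` now covers both the built-in names and any custom `SplitMethod`, and the new `case str()` branch turns an unknown string into an immediate error instead of silently treating it as a callable. A sketch of a custom splitter; the parameter list below is an assumption (only the forwarding of `seed` and `train_percentage` is visible in this hunk), so check the `SplitMethod` protocol before copying:

import numpy as np
from numpy.typing import NDArray


# Hypothetical splitter matching the SplitMethod protocol: it is passed through
# unchanged by the final `case _` branch of the match above.
def every_other_split(
    groups: NDArray[np.int32],
    y: NDArray[np.int32],
    train_percentage: float,
    seed: int,
) -> tuple[NDArray[np.intp], NDArray[np.intp]]:
    idx = np.arange(len(y))
    return idx[::2], idx[1::2]  # (train indices, test indices)


# evaluate(..., split="basic")            -> basic_split
# evaluate(..., split="proportional")     -> proportional_split
# evaluate(..., split=every_other_split)  -> used as-is
# evaluate(..., split="stratified")       -> ValueError: Invalid split method: stratified. ...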

fair_forge/methods.py

Lines changed: 2 additions & 9 deletions
@@ -1,8 +1,7 @@
 """Protocols and implementations of methods for fairness-aware machine learning."""
 
 from dataclasses import asdict, dataclass
-from enum import Enum
-from typing import Any, Protocol, Self
+from typing import Any, Literal, Protocol, Self
 
 import numpy as np
 from numpy.typing import NDArray
@@ -49,13 +48,7 @@ def fit(
     ) -> Self: ...
 
 
-class FairnessType(Enum):
-    DP = "dp"
-    """Demographic Parity (DP)"""
-    EQ_OPP = "eq_opp"
-    """Equal Opportunity (EQ_OPP)"""
-    EQ_ODDS = "eq_odds"
-    """Equalized Odds (EQ_ODDS)"""
+type FairnessType = Literal["dp", "eq_opp", "eq_odds"]
 
 
 @dataclass
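A practical side effect of dropping the Enum, visible in `f"Adult {group}"` and `f"renyi_{self.base}"` elsewhere in this commit: the literal strings flow straight into f-strings, dicts, and parameter grids with no `.value` unwrapping. A small illustrative sketch (the `Config` class is hypothetical, not a fair_forge class; the alias is repeated locally to keep it self-contained):

from dataclasses import asdict, dataclass
from typing import Literal

type FairnessType = Literal["dp", "eq_opp", "eq_odds"]


@dataclass
class Config:  # illustrative stand-in for a method's hyperparameter dataclass
    fairness: FairnessType = "dp"


# Plain strings serialize as-is; no Enum round-tripping is needed.
assert asdict(Config()) == {"fairness": "dp"}
assert f"fairness={Config().fairness}" == "fairness=dp"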

fair_forge/metrics.py

Lines changed: 10 additions & 14 deletions
@@ -1,7 +1,7 @@
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
-from enum import Enum, Flag, auto
-from typing import Protocol, override
+from enum import Flag, auto
+from typing import Literal, Protocol, override
 
 import numpy as np
 from numpy.typing import NDArray
@@ -57,11 +57,7 @@ def __call__(
     ) -> Float: ...
 
 
-class LabelType(Enum):
-    """The variable that is compared to the predictions in order to check how similar they are."""
-
-    S = "s"
-    Y = "y"
+type LabelType = Literal["group", "y"]
 
 
 @dataclass
@@ -72,12 +68,12 @@ class RenyiCorrelation(GroupMetric):
     titled "On Measures of Dependence" by Alfréd Rényi.
     """
 
-    base: LabelType = LabelType.S
+    base: LabelType = "group"
 
     @property
     def __name__(self) -> str:
        """The name of the metric."""
-        return f"renyi_{self.base.value}"
+        return f"renyi_{self.base}"
 
     @override
     def __call__(
@@ -287,7 +283,7 @@ def as_group_metric(
     """Turn a sequence of metrics into a list of group metrics."""
     metrics = []
     for metric in base_metrics:
-        if agg & MetricAgg.DIFF:
+        if MetricAgg.DIFF in agg:
            metrics.append(
                _BinaryAggMetric(
                    metric=metric,
@@ -296,7 +292,7 @@ def as_group_metric(
                    aggregator=lambda i, j: j - i,
                )
            )
-        if agg & MetricAgg.RATIO:
+        if MetricAgg.RATIO in agg:
            metrics.append(
                _BinaryAggMetric(
                    metric=metric,
@@ -305,7 +301,7 @@
                    aggregator=lambda i, j: i / j if j != 0 else np.float64(np.nan),
                )
            )
-        if agg & MetricAgg.MIN:
+        if MetricAgg.MIN in agg:
            metrics.append(
                _MulticlassAggMetric(
                    metric=metric,
@@ -314,7 +310,7 @@
                    aggregator=np.min,
                )
            )
-        if agg & MetricAgg.MAX:
+        if MetricAgg.MAX in agg:
            metrics.append(
                _MulticlassAggMetric(
                    metric=metric,
@@ -323,7 +319,7 @@
                    aggregator=np.max,
                )
            )
-        if agg & MetricAgg.INDIVIDUAL:
+        if MetricAgg.INDIVIDUAL in agg:
            metrics.append(
                _BinaryAggMetric(
                    metric=metric,
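Separately from the Enum removal, the `agg & MetricAgg.X` checks become membership tests. For single-bit members the two spellings agree; `in` is simply the clearer idiom, and for compound flags it acts as a subset test rather than an any-overlap test. A standalone sketch with a stand-in Flag (the real MetricAgg definition is not shown in this diff); Python 3.12+ semantics, which the `type` aliases in this commit already require:

from enum import Flag, auto


class Agg(Flag):  # illustrative stand-in for MetricAgg
    DIFF = auto()
    RATIO = auto()
    MIN = auto()


selected = Agg.DIFF | Agg.MIN

# Single members: `&` truthiness and `in` agree.
assert (Agg.DIFF in selected) == bool(selected & Agg.DIFF)    # both True
assert (Agg.RATIO in selected) == bool(selected & Agg.RATIO)  # both False

# Compound flags differ: `&` is truthy on any overlap, `in` requires every bit.
combo = Agg.DIFF | Agg.RATIO
assert bool(selected & combo) is True   # overlaps on DIFF
assert (combo in selected) is False     # RATIO bit is missing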

fair_forge/nn/beutel.py

Lines changed: 6 additions & 4 deletions
@@ -154,7 +154,7 @@ class Beutel(BaseEstimator, GroupBasedTransform):
     adv_size: list[int] = field(default_factory=lambda: [40])
     pred_size: list[int] = field(default_factory=lambda: [40])
     adv_weight: float = 1.0
-    fairness: FairnessType = FairnessType.DP
+    fairness: FairnessType = "dp"
     batch_size: int = 64
     iters: int = 500
     random_state: int = 42
@@ -182,12 +182,14 @@ def loss_fn(model: Model, x: Array, y: Array, s: Array) -> Array:
            ).mean()
 
            match self.fairness:
-                case FairnessType.EQ_OPP:
+                case "eq_opp":
                    mask = y > 0.5
-                case FairnessType.EQ_ODDS:
+                case "eq_odds":
                    raise NotImplementedError("Not implemented Eq. Odds yet")
-                case FairnessType.DP:
+                case "dp":
                    mask = jnp.ones(s.shape, dtype=jnp.bool)
+                case _:
+                    raise ValueError("Invalid fairness value")
            if s_size > 1:
                adversary_loss = optax.softmax_cross_entropy_with_integer_labels(
                    logits=s_hat, labels=s, where=mask
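This commit guards the string-based match with a runtime `case _: raise ValueError(...)`. Because `FairnessType` is now a closed `Literal`, the same guarantee can alternatively be surfaced at type-check time with `typing.assert_never`; the sketch below shows that alternative (it is not what the commit does, and the alias is repeated locally to keep the example self-contained):

from typing import Literal, assert_never

type FairnessType = Literal["dp", "eq_opp", "eq_odds"]


def describe(fairness: FairnessType) -> str:
    match fairness:
        case "dp":
            return "demographic parity"
        case "eq_opp":
            return "equal opportunity"
        case "eq_odds":
            return "equalized odds"
        case unreachable:
            # mypy/pyright report an error here if a new literal is added to
            # FairnessType but not handled above; at runtime this still raises
            # for untyped callers that pass an unexpected value.
            assert_never(unreachable)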

fair_forge/preprocessing/group_pre_method.py

Lines changed: 4 additions & 10 deletions
@@ -1,7 +1,6 @@
 from dataclasses import dataclass
-from enum import Enum
 import itertools
-from typing import Any, Self
+from typing import Any, Literal, Self
 
 import numpy as np
 from numpy.typing import NDArray
@@ -54,17 +53,12 @@ def set_params(self, **params: Any) -> Self:
         return ret
 
 
-class UpsampleStrategy(Enum):
-    """Strategy for upsampling."""
-
-    UNIFORM = "uniform"
-    # PREFERENTIAL = "preferential"
-    NAIVE = "naive"
+type UpsampleStrategy = Literal["uniform", "naive"]  # , "preferential"]
 
 
 @dataclass
 class Upsampler(BaseEstimator, GroupDatasetModifier):
-    strategy: UpsampleStrategy = UpsampleStrategy.UNIFORM
+    strategy: UpsampleStrategy = "uniform"
     random_state: int = 0
 
     def fit(
@@ -89,7 +83,7 @@ def fit(
         vals = list([d[1] for d in data])
 
         for mask, length, y_eq_y, s_eq_s in data:
-            if self.strategy is UpsampleStrategy.NAIVE:
+            if self.strategy == "naive":
                percentages.append((mask, (np.max(vals) / length).astype(np.float64)))
            else:
                num_samples = len(y)
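Note the comparison change in the last hunk: Enum members are conventionally compared with `is`, but plain strings must be compared with `==`, since identity of equal strings is an interning detail. A two-line illustration (not fair_forge code):

# Equality is reliable for strings; identity is not.
strategy = "".join(["na", "ive"])   # built at runtime, not a cached literal
assert strategy == "naive"          # always True
# `strategy is "naive"` may be False and triggers a SyntaxWarning in CPython.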

tests/test_datasets.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 
 def test_adult_gender():
     data = ff.load_adult(
-        group=ff.AdultGroup.SEX,
+        group="Sex",
         group_in_features=False,
         binarize_nationality=False,
         binarize_race=False,
@@ -32,7 +32,7 @@ def test_adult_gender():
 
 def test_adult_race():
     data = ff.load_adult(
-        group=ff.AdultGroup.RACE,
+        group="Race",
         group_in_features=False,
         binarize_nationality=False,
         binarize_race=False,
@@ -56,7 +56,7 @@ def test_adult_race():
 
 def test_adult_race_binary():
     data = ff.load_adult(
-        group=ff.AdultGroup.RACE,
+        group="Race",
         group_in_features=False,
         binarize_nationality=True,
         binarize_race=True,
@@ -80,7 +80,7 @@ def test_adult_race_binary():
 
 def test_adult_gender_in_features():
     data = ff.load_adult(
-        group=ff.AdultGroup.SEX,
+        group="Sex",
         group_in_features=True,
         binarize_nationality=True,
         binarize_race=False,

tests/test_eval.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def test_pipeline_with_dummy():
         metrics=metrics,
         group_metrics=group_metrics,
         repeat=2,
-        split=ff.Split.BASIC,
+        split="basic",
         seed=42,
         train_percentage=0.8,
         remove_score_suffix=True,

tests/test_metrics.py

Lines changed: 3 additions & 3 deletions
@@ -9,15 +9,15 @@ def test_renyi():
     y_pred = np.array([1, 0, 1, 0, 0, 1], dtype=np.int32)
     groups = np.array([1, 0, 1, 0, 0, 1], dtype=np.int32)
 
-    renyi_y = ff.RenyiCorrelation(ff.LabelType.Y)
+    renyi_y = ff.RenyiCorrelation("y")
     result = renyi_y(y_true=y_true, y_pred=y_pred, groups=groups)
     np.testing.assert_allclose(result, 1 / 3)
     assert renyi_y.__name__ == "renyi_y"
 
-    renyi_s = ff.RenyiCorrelation(ff.LabelType.S)
+    renyi_s = ff.RenyiCorrelation("group")
     result = renyi_s(y_true=y_true, y_pred=y_pred, groups=groups)
     np.testing.assert_allclose(result, 1.0)
-    assert renyi_s.__name__ == "renyi_s"
+    assert renyi_s.__name__ == "renyi_group"
 
 
 def test_prob_pos():

tests/test_pre_methods.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ def test_upsampler():
     y = np.array([0, 0, 1, 1, 1], dtype=np.int32)
     groups = np.array([0, 1, 0, 1, 1], dtype=np.int32)
     lr = LogisticRegression(random_state=41, max_iter=10)
-    upsampler = ff.Upsampler(strategy=ff.UpsampleStrategy.NAIVE, random_state=41)
+    upsampler = ff.Upsampler(strategy="naive", random_state=41)
     pipeline = ff.GroupPipeline(group_data_modifier=upsampler, estimator=lr)
     pipeline.set_params(random_state=42)
     assert pipeline.get_params()["estimator__random_state"] == 42
