Skip to content

Commit 32fccd4

Browse files
committed
change EBM implementation used from the interpret repo directly to using the autogluon implementation of EBMs
1 parent 1a394bd commit 32fccd4

7 files changed

Lines changed: 10 additions & 239 deletions

File tree

tabarena/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ tabpfn = [
4343
"tabpfn-extensions[many_class] @ git+https://github.com/PriorLabs/tabpfn-extensions.git",
4444
]
4545
tabicl = ["tabicl>=2.0.0"]
46-
ebm = ["interpret-core>=0.7.3"]
46+
ebm = ["autogluon.tabular[interpret]>=1.5,<1.6"]
4747
search_spaces = ["configspace>=1.2,<2.0"]
4848
realmlp = ["pytabkit>=1.5.0,<2.0"]
4949
tabdpt = [
@@ -62,7 +62,7 @@ benchmark = [
6262
"tabpfn>=7.0.0",
6363
"tabpfn-extensions[many_class] @ git+https://github.com/PriorLabs/tabpfn-extensions.git",
6464
"tabicl>=2.0.0",
65-
"interpret-core>=0.7.3",
65+
"autogluon.tabular[interpret]>=1.5,<1.6",
6666
"configspace>=1.2,<2.0",
6767
"pytabkit>=1.5.0,<2.0",
6868
"tabdpt>=1.1.10",

tabarena/tabarena/benchmark/models/ag/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from tabarena.benchmark.models.ag.ebm.ebm_model import ExplainableBoostingMachineModel
43
from tabarena.benchmark.models.ag.knn_new.knn_model import KNNNewModel
54
from tabarena.benchmark.models.ag.modernnca.modernnca_model import ModernNCAModel
65
from tabarena.benchmark.models.ag.perpetual_booster.perpetual_booster_model import (
@@ -19,7 +18,6 @@
1918
from tabarena.benchmark.models.ag.xrfm.xrfm_model import XRFMModel
2019

2120
__all__ = [
22-
"ExplainableBoostingMachineModel",
2321
"KNNNewModel",
2422
"ModernNCAModel",
2523
"PerpetualBoosterModel",

tabarena/tabarena/benchmark/models/ag/ebm/__init__.py

Whitespace-only changes.

tabarena/tabarena/benchmark/models/ag/ebm/ebm_model.py

Lines changed: 0 additions & 224 deletions
This file was deleted.

tabarena/tabarena/benchmark/models/model_registry.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from autogluon.tabular.registry import ModelRegistry, ag_model_registry
66

77
from tabarena.benchmark.models.ag import (
8-
ExplainableBoostingMachineModel,
98
KNNNewModel,
109
ModernNCAModel,
1110
PerpetualBoosterModel,
@@ -24,7 +23,6 @@
2423
tabarena_model_registry: ModelRegistry = copy.deepcopy(ag_model_registry)
2524

2625
_models_to_add = [
27-
ExplainableBoostingMachineModel,
2826
RealMLPModel,
2927
TabICLModel,
3028
TabDPTModel,

tabarena/tabarena/models/ebm/generate.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from __future__ import annotations
2-
31
from autogluon.common.space import Categorical, Int, Real
4-
5-
from tabarena.benchmark.models.ag.ebm.ebm_model import ExplainableBoostingMachineModel
2+
from autogluon.tabular.models import EBMModel
63
from tabarena.utils.config_utils import ConfigGenerator
74

85
name = "EBM"
9-
manual_configs = []
6+
manual_configs = [
7+
{},
8+
]
109
search_space = {
1110
"max_leaves": Int(2, 3, default=2),
1211
"smoothing_rounds": Int(0, 1000, default=200),
@@ -42,9 +41,9 @@
4241
}
4342

4443
gen_ebm = ConfigGenerator(
45-
model_cls=ExplainableBoostingMachineModel,
44+
model_cls=EBMModel,
4645
search_space=search_space,
47-
manual_configs=[{}],
46+
manual_configs=manual_configs,
4847
)
4948

5049

tst/benchmark/models/test_ebm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ def test_ebm():
66

77
try:
88
from autogluon.tabular.testing import FitHelper
9-
from tabarena.benchmark.models.ag.ebm.ebm_model import ExplainableBoostingMachineModel
10-
model_cls = ExplainableBoostingMachineModel
9+
from autogluon.tabular.models import EBMModel
10+
model_cls = EBMModel
1111
FitHelper.verify_model(model_cls=model_cls, model_hyperparameters=model_hyperparameters)
1212
except ImportError as err:
1313
pytest.skip(

0 commit comments

Comments
 (0)