---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```
# tabnet
<!-- badges: start -->
[](https://github.com/mlverse/tabnet/actions)
[](https://lifecycle.r-lib.org/articles/stages.html)
[](https://CRAN.R-project.org/package=tabnet)
[](https://cran.r-project.org/package=tabnet) [](https://discord.com/invite/s3D5cKhBkx)
<!-- badges: end -->
An R implementation of: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas Pfister)](https://doi.org/10.48550/arXiv.1908.07442).\
The code in this repository started as an R port, using the [torch](https://github.com/mlverse/torch) package, of the [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet) implementation.
TabNet is now augmented with:
- [Coherent Hierarchical Multi-label Classification Networks](https://proceedings.neurips.cc//paper/2020/file/6dd4e10e3296fa63738371ec0d5df818-Paper.pdf) [(Eleonora Giunchiglia et al.)](https://doi.org/10.48550/arXiv.2010.10151) for hierarchical outcomes
- [Optimizing ROC Curves with a Sort-Based Surrogate Loss for Binary Classification and Changepoint Detection (J Hillman, TD Hocking)](https://jmlr.org/papers/v24/21-0751.html) for imbalanced binary classification.
## Installation
Install [{tabnet} from CRAN](https://CRAN.R-project.org/package=tabnet) with:
``` r
install.packages("tabnet")
```
The development version can be installed from [GitHub](https://github.com/mlverse/tabnet) with:
``` r
# install.packages("pak")
pak::pak("mlverse/tabnet")
```
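{tabnet} runs on {torch}, which needs its LibTorch backend to be installed. A minimal sketch, assuming you want to check for that backend and install it explicitly right after installing the package:

``` r
# Check whether the LibTorch backend used by {torch} is available,
# and download it only when it is missing.
if (!torch::torch_is_installed()) {
  torch::install_torch()
}
```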
## Basic Binary Classification Example
Here we show a **binary classification** example on the `attrition` dataset, using a **recipe** to specify the model inputs.
```{r model-fit}
#| fig.alt: "A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
library(tabnet)
suppressPackageStartupMessages(library(recipes))
library(yardstick)
library(ggplot2)
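set.seed(2026)                 # R's RNG, used by sample.int() for the train/test split
torch::torch_manual_seed(2026) # torch's RNG, used for weight initialization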
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx,]
test <- attrition[test_idx,]
rec <- recipe(Attrition ~ ., data = train) %>%
step_normalize(all_numeric(), -all_outcomes())
fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3)
autoplot(fit)
```
The plot gives you immediate insight into model over-fitting and, if any occurs, shows the model checkpoints available before it.
Keep in mind that **regression** as well as **multi-class classification** are also available, and that you can also specify the dataset through a **data.frame** or a **formula**. You will find examples in the package vignettes.
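A minimal sketch of the two alternative input interfaces, assuming the same `attrition` data and a deliberately short, illustrative `epochs = 5`; the object names `fit_formula` and `fit_df` are hypothetical:

``` r
library(tabnet)
data("attrition", package = "modeldata")

# Formula interface: outcome on the left, predictors on the right
fit_formula <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 5)

# data.frame interface: predictors in `x`, outcome vector in `y`
predictors <- attrition[, setdiff(names(attrition), "Attrition")]
fit_df <- tabnet_fit(x = predictors, y = attrition$Attrition, epochs = 5)
```

Both calls should yield the same kind of model object as the recipe-based fit, so the prediction and explanation steps below apply to them unchanged.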
## Model performance results
As the model supports the standard `predict()` method, you can rely on your usual metric functions to assess model performance. Here we use {yardstick}:
```{r}
metrics <- metric_set(accuracy, precision, recall)
augment(fit, test) %>%
metrics(Attrition, estimate = .pred_class)
augment(fit, test, type = "prob") %>%
roc_auc(Attrition, .pred_No)
```
## Explain model on test-set with attention map
TabNet offers intrinsic explainability through the visualization of its attention map, either **aggregated**:
```{r model-explain}
#| fig.alt: "A heatmap as explainability plot showing, for each variable of the test-set on the y axis, its importance for each observation on the x axis. The value is a mask aggregate."
explain <- tabnet_explain(fit, test)
autoplot(explain)
```
or at **each layer** through the `type = "steps"` option:
```{r step-explain}
#| fig.alt: "A small-multiple heatmap as explainability plot for each step of the TabNet network. Each plot shows, for each variable of the test-set on the y axis, its importance for each observation on the x axis."
autoplot(explain, type = "steps")
```
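The same attention information can also be read as numbers rather than plots. A minimal sketch, assuming the object returned by `tabnet_explain()` holds the aggregated mask in `M_explain` with one row per observation and one column per predictor; `fit_small`, `explain_small` and `importance` are hypothetical names:

``` r
library(tabnet)
data("attrition", package = "modeldata")

# A short fit, only to have a model to explain
fit_small <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 5)
explain_small <- tabnet_explain(fit_small, attrition)

# Assumption: `M_explain` is observations x predictors; averaging over rows
# gives one global importance value per predictor.
importance <- colMeans(explain_small$M_explain)
head(sort(importance, decreasing = TRUE), 5)
```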
## Self-supervised pretraining
When a substantial part of your dataset has no outcome, TabNet offers a self-supervised pretraining step that lets the model capture the predictors' intrinsic features and interactions ahead of the supervised task.
```{r step-pretrain}
#| fig.alt: "A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2)
autoplot(pretrain)
```
This is a toy example, as the `train` dataset does contain outcomes. The vignette [`vignette("selfsupervised_training")`](https://mlverse.github.io/tabnet/articles/selfsupervised_training.html) walks you through the complete workflow step by step.
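Pretraining is followed by supervised fine-tuning. A minimal sketch, assuming the pretrained object is passed to `tabnet_fit()` through its `tabnet_model` argument; it reuses the full `attrition` data for both steps only for brevity, whereas a real workflow would pretrain on the rows without outcome:

``` r
library(tabnet)
suppressPackageStartupMessages(library(recipes))
data("attrition", package = "modeldata")

rec <- recipe(Attrition ~ ., data = attrition) %>%
  step_normalize(all_numeric(), -all_outcomes())

# Step 1: self-supervised pretraining (the outcome is not used here)
pretrained <- tabnet_pretrain(rec, attrition, epochs = 5, valid_split = 0.1)

# Step 2: supervised fine-tuning, starting from the pretrained weights
fine_tuned <- tabnet_fit(rec, attrition, tabnet_model = pretrained,
                         epochs = 5, valid_split = 0.1)
```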
## {tidymodels} integration
The integration with tidymodels workflows lets you compare {tabnet} models with challenger models.
Don't miss the [`vignette("tidymodels-interface")`](https://mlverse.github.io/tabnet/articles/tidymodels-interface.html) for that.
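A minimal sketch of such a workflow, assuming the `tabnet()` {parsnip} model specification with its `"torch"` engine, and {workflows} to bundle it with a recipe; `epochs = 5` is illustrative only:

``` r
library(tabnet)
library(parsnip)
library(workflows)
suppressPackageStartupMessages(library(recipes))
data("attrition", package = "modeldata")

rec <- recipe(Attrition ~ ., data = attrition) %>%
  step_normalize(all_numeric(), -all_outcomes())

# parsnip model specification backed by the torch engine
mod <- tabnet(epochs = 5) %>%
  set_engine("torch") %>%
  set_mode("classification")

# Bundle preprocessing and model, then fit the whole workflow
wf_fit <- workflow() %>%
  add_recipe(rec) %>%
  add_model(mod) %>%
  fit(data = attrition)

predict(wf_fit, attrition[1:5, ])
```

Swapping `mod` for another {parsnip} specification while keeping `rec` is what makes challenger comparisons straightforward.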
## Missing data in predictors
{tabnet} leverages its masking mechanism to deal with missing data, so you don't have to remove rows with missing values in the predictor variables.
See [`vignette("Missing_data_predictors")`](https://mlverse.github.io/tabnet/articles/Missing_data_predictors.html).
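A minimal sketch, assuming a model fitted on complete data can still score rows whose predictors contain `NA` (the vignette also covers missing values at training time); `Age` and `MonthlyIncome` are simply two numeric predictors of `attrition` chosen for illustration:

``` r
library(tabnet)
data("attrition", package = "modeldata")

fit_na <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 5)

# Introduce missing values in two predictors of a few new rows
new_data <- attrition[1:10, ]
new_data$Age[1:3] <- NA
new_data$MonthlyIncome[4:6] <- NA

# The rows with NA predictors are kept and still get a prediction
predict(fit_na, new_data)
```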
## Imbalanced binary classification
{tabnet} includes an Area Under the $\min(\mathrm{FPR}, \mathrm{FNR})$ (AUM) loss function, `nn_aum_loss()`, dedicated to imbalanced binary classification tasks.
Try it out in [`vignette("aum_loss")`](https://mlverse.github.io/tabnet/articles/aum_loss.html).
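A minimal sketch on the imbalanced `attrition` outcome, assuming `nn_aum_loss()` is supplied as a loss module instance through the `loss` hyperparameter of `tabnet_fit()`; check `vignette("aum_loss")` for the exact setup before relying on it:

``` r
library(tabnet)
data("attrition", package = "modeldata")

# The outcome is imbalanced: far fewer "Yes" than "No"
table(attrition$Attrition)

# Assumption: `loss` accepts a module instance, like the default losses
fit_aum <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 5,
                      loss = nn_aum_loss())
```

With an imbalanced outcome, compare such models on a ranking metric like `roc_auc()` rather than on accuracy alone.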
## Comparison with other implementations
| Group | Feature | {tabnet} | dreamquark-ai | fast-tabnet |
|---------------|---------------|:-------------:|:-------------:|:-------------:|
| Input format | data-frame | ✅ | ✅ | ✅ |
| | formula | ✅ | | |
| | recipe | ✅ | | |
| | Node | ✅ | | |
| | missing values in predictors | ✅ | | |
| Output format | data-frame | ✅ | ✅ | ✅ |
| | workflow | ✅ | | |
| ML Tasks | self-supervised learning | ✅ | ✅ | |
| | classification (binary, multi-class) | ✅ | ✅ | ✅ |
| | imbalanced binary classification | ✅ | | |
| | regression | ✅ | ✅ | ✅ |
| | multi-outcome | ✅ | ✅ | |
| | hierarchical multi-label classification | ✅ | | |
| Model management | from / to file | ✅ | ✅ | v |
| | resume from snapshot | ✅ | | |
| | training diagnostic | ✅ | | |
| Interpretability | | ✅ | ✅ | ✅ |
| Performance | | 1 x | 2 - 4 x | |
| Code quality | test coverage | 85% | | |
| | continuous integration | 4 OS including GPU | | |
: Alternative TabNet implementation features