Commit adf8c65

committed
Forest DiffusionModel RC commit
1 parent 2789612 commit adf8c65

3 files changed

Lines changed: 62 additions & 90 deletions

README.md

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
33
[![Downloads](https://pepy.tech/badge/tabgan)](https://pepy.tech/project/tabgan)
44

5-
# GANs for tabular data
5+
# GANs and Diffusions for tabular data
66

77
<img src="./images/tabular_gan.png" height="15%" width="15%">
88
Generative Adversarial Networks (GANs) are well-known for their success in realistic image generation. However, they can also be applied to generate tabular data. This library gives you the opportunity to try some of them.
@@ -29,9 +29,10 @@ test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD
2929
# generate data
3030
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test, )
3131
new_train2, new_target2 = GANGenerator().generate_data_pipe(train, target, test, )
32+
new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test, )
3233

3334
# example with all params defined
34-
new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
35+
new_train4, new_target4 = GANGenerator(gen_x_times=1.1, cat_cols=None,
3536
bot_filter_quantile=0.001, top_filter_quantile=0.999, is_post_process=True,
3637
adversarial_model_params={
3738
"metrics": "AUC", "max_depth": 2, "max_bin": 100,
@@ -41,7 +42,10 @@ new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
4142
test, deep_copy=True, only_adversarial=False, use_adversarial=True)
4243
```
4344

44-
Both samplers `OriginalGenerator` and `GANGenerator` have same input parameters:
45+
All samplers (`OriginalGenerator`, `ForestDiffusionGenerator` and `GANGenerator`) have the same input parameters.
46+
47+
1. **GANGenerator** based on **CTGAN**
48+
2. **ForestDiffusionGenerator** based on **Forest Diffusion**
4549

4650
* **gen_x_times**: float = 1.1 - how much data to generate, output might be less because of postprocessing and
4751
adversarial filtering
@@ -132,43 +136,14 @@ To run experiment follow these steps:
132136
add more datasets, adjust validation type and categorical encoders.
133137
5. Observe metrics across all experiments in the console or in `./Research/results/fit_predict_scores.txt`
134138

135-
**Task formalization**
136-
137-
Let say we have **T_train** and **T_test** (train and test set respectively). We need to train the model on **T_train**
138-
and make predictions on **T_test**. However, we will increase the train by generating new data by GAN, somehow similar
139-
to **T_test**, without using ground truth labels.
140139

141140
**Experiment design**
142141

143-
In the case of having a smaller **T_train** and a different data distribution, we can use CTGAN to generate additional data **T_synth**. First, we train CTGAN on **T_train** with ground truth labels (step 1), then generate additional data **T_synth** (step 2). Secondly, we train boosting in an adversarial way on concatenated **T_train** and **T_synth** (target set to 0) with **T_test** (target set to 1) (steps 3 & 4). The goal is to apply the newly trained adversarial boosting to obtain rows more like **T_test**. Note that initial ground truth labels aren't used for adversarial training. As a result, we take top rows from **T_train** and **T_synth** sorted by correspondence to **T_test** (steps 5 & 6), and train new boosting on them and check results on **T_test**.
144-
145142
![Experiment design and workflow](./images/workflow.png?raw=true)
146143

147144
**Picture 1.1** Experiment design and workflow
148145

149-
Of course for the benchmark purposes we will test ordinal training without these tricks and another original pipeline
150-
but without CTGAN (in step 3 we won"t use **T_sync**).
151-
152-
**Datasets**
153-
154-
All datasets came from different domains. They have a different number of observations, number of categorical and
155-
numerical features. The objective for all datasets - binary classification. Preprocessing of datasets were simple:
156-
removed all time-based columns from datasets. Remaining columns were either categorical or numerical.
157-
158-
**Table 1.1** Used datasets
159-
160-
| Name | Total points | Train points | Test points | Number of features | Number of categorical features | Short description |
161-
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
162-
| [Telecom](https://www.kaggle.com/blastchar/telco-customer-churn) | 7.0k | 4.2k | 2.8k | 20 | 16 | Churn prediction for telecom data |
163-
| [Adult](https://www.kaggle.com/wenruliu/adult-income-dataset) | 48.8k | 29.3k | 19.5k | 15 | 8 | Predict if persons" income is bigger 50k |
164-
| [Employee](https://www.kaggle.com/c/amazon-employee-access-challenge/data) | 32.7k | 19.6k | 13.1k | 10 | 9 | Predict an employee"s access needs, given his/her job role|
165-
| [Credit](https://www.kaggle.com/c/home-credit-default-risk/data) | 307.5k | 184.5k | 123k | 121 | 18 | Loan repayment |
166-
| [Mortgages](https://www.crowdanalytix.com/contests/propensity-to-fund-mortgages) | 45.6k | 27.4k | 18.2k | 20 | 9 | Predict if house mortgage is founded |
167-
| [Taxi](https://www.crowdanalytix.com/contests/mckinsey-big-data-hackathon) | 892.5k | 535.5k | 357k | 8 | 5 | Predict the probability of an offer being accepted by a certain driver |
168-
| [Poverty_A](https://www.drivendata.org/competitions/50/worldbank-poverty-prediction/page/99/) | 37.6k | 22.5k | 15.0k | 41 | 38 | Predict whether or not a given household for a given country is poor or not |
169-
170146
## Results
171-
172147
To determine the best sampling strategy, ROC AUC scores of each dataset were scaled (min-max scale) and then averaged
173148
across the datasets.
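
The scaling-and-averaging step described here can be sketched in plain Python. This is only an illustration under stated assumptions: the strategy names, dataset names and ROC AUC values below are hypothetical, and the helper functions are not part of the library.

```python
from statistics import mean


def min_max_scale(scores):
    """Scale a {strategy: ROC AUC} mapping to [0, 1] within one dataset."""
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        # All strategies tied on this dataset; returning 0.0 is an assumed convention.
        return {name: 0.0 for name in scores}
    return {name: (value - lo) / (hi - lo) for name, value in scores.items()}


def average_scaled_scores(scores_by_dataset):
    """Min-max scale each dataset's scores, then average per strategy."""
    scaled = [min_max_scale(scores) for scores in scores_by_dataset.values()]
    strategies = scaled[0].keys()
    return {name: mean(d[name] for d in scaled) for name in strategies}


# Hypothetical ROC AUC values, for illustration only.
example_scores = {
    "dataset_1": {"none": 0.70, "gan": 0.80, "sample_original": 0.75},
    "dataset_2": {"none": 0.60, "gan": 0.62, "sample_original": 0.64},
}

ranking = average_scaled_scores(example_scores)
for strategy, score in sorted(ranking.items(), key=lambda item: -item[1]):
    print(f"{strategy}: {score:.3f}")
```

Scaling within each dataset first keeps a dataset with a wide ROC AUC spread from dominating the average.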
174149

@@ -224,35 +199,12 @@ arxiv publication:
224199
primaryClass={cs.LG}
225200
}
226201
```
227-
library itself:
228-
```bibtex
229-
@misc{Diyago2020tabgan,
230-
author = {Ashrapov, Insaf},
231-
title = {GANs for tabular data},
232-
howpublished = {\url{https://github.com/Diyago/GAN-for-tabular-data}},
233-
year = {2020}
234-
}
235-
```
236202

237203
## References
238204

239-
[1] Jonathan Hui. GAN — What is Generative Adversarial Networks GAN? (2018), medium article
240-
241-
[2]Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville,
242-
Yoshua Bengio. Generative Adversarial Networks (2014). arXiv:1406.2661
243-
244-
[3] Lei Xu LIDS, Kalyan Veeramachaneni. Synthesizing Tabular Data using Generative Adversarial Networks (2018). arXiv:
205+
[1] Lei Xu LIDS, Kalyan Veeramachaneni. Synthesizing Tabular Data using Generative Adversarial Networks (2018). arXiv:
245206
1811.11264v1 [cs.LG]
246207

247-
[4] Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular Data using Conditional
248-
GAN (2019). arXiv:1907.00503v2 [cs.LG]
249-
250-
[5] Denis Vorotyntsev. Benchmarking Categorical Encoders. Medium post
251-
252-
[6] Insaf Ashrapov. GAN-for-tabular-data. Github repository.
253-
254-
[7] Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila. Analyzing and Improving the
255-
Image Quality of StyleGAN (2019) arXiv:1912.04958v2 [cs.CV]
256-
257-
[8] ODS.ai: Open data science, https://ods.ai/
208+
[2] Alexia Jolicoeur-Martineau, Kilian Fatras, Tal Kachman. Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (2023). https://github.com/SamsungSAILMontreal/ForestDiffusion [cs.LG]
258209

210+
[3] Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular Data using Conditional GAN. NeurIPS (2019)

src/tabgan/sampler.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,11 @@ def generate_data(
307307
self.TEMP_TARGET = None
308308
logging.info("Fitting ForestDiffusion model")
309309
if self.cat_cols is None:
310-
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50,
310+
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=None, n_t=50,
311311
duplicate_K=100,
312312
diffusion_type='flow', n_jobs=-1)
313313
else:
314-
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50,
314+
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=None, n_t=50,
315315
duplicate_K=100,
316316
# todo fix bug with cat cols
317317
#cat_indexes=self.get_column_indexes(train_df, self.cat_cols),
@@ -393,39 +393,42 @@ def get_columns_if_exists(df, col) -> pd.DataFrame:
393393
logging.info(train)
394394
target = pd.DataFrame(np.random.randint(0, 2, size=(train_size, 1)), columns=list("Y"))
395395
test = pd.DataFrame(np.random.randint(0, 100, size=(train_size, 4)), columns=list("ABCD"))
396-
# _sampler(OriginalGenerator(gen_x_times=15), train, target, test)
397-
# _sampler(
398-
# GANGenerator(gen_x_times=10, only_generated_data=False,
399-
# gen_params={"batch_size": 500, "patience": 25, "epochs": 500, }), train, target, test
400-
# )
401-
#
402-
# _sampler(OriginalGenerator(gen_x_times=15), train, None, train)
403-
# _sampler(
404-
# GANGenerator(cat_cols=["A"], gen_x_times=20, only_generated_data=True),
405-
# train,
406-
# None,
407-
# train,
408-
# )
396+
_sampler(OriginalGenerator(gen_x_times=15), train, target, test)
397+
_sampler(
398+
GANGenerator(gen_x_times=10, only_generated_data=False,
399+
gen_params={"batch_size": 500, "patience": 25, "epochs": 500, }), train, target, test
400+
)
401+
402+
_sampler(OriginalGenerator(gen_x_times=15), train, None, train)
403+
_sampler(
404+
GANGenerator(cat_cols=["A"], gen_x_times=20, only_generated_data=True),
405+
train,
406+
None,
407+
train,
408+
)
409409
_sampler(
410410
ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=1, only_generated_data=True),
411411
train,
412412
None,
413413
train,
414414
)
415+
_sampler(
416+
ForestDiffusionGenerator(gen_x_times=10, only_generated_data=False,
417+
gen_params={"batch_size": 500, "patience": 25, "epochs": 500, }), train, target, test
418+
)
419+
420+
min_date = pd.to_datetime('2019-01-01')
421+
max_date = pd.to_datetime('2021-12-31')
422+
423+
d = (max_date - min_date).days + 1
424+
425+
train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d')
426+
train = get_year_mnth_dt_from_date(train, 'Date')
415427

416-
#
417-
# min_date = pd.to_datetime('2019-01-01')
418-
# max_date = pd.to_datetime('2021-12-31')
419-
#
420-
# d = (max_date - min_date).days + 1
421-
#
422-
# train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d')
423-
# train = get_year_mnth_dt_from_date(train, 'Date')
424-
#
425-
# new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
426-
# top_filter_quantile=0.999,
427-
# is_post_process=True, pregeneration_frac=2, only_generated_data=False). \
428-
# generate_data_pipe(train.drop('Date', axis=1), None,
429-
# train.drop('Date', axis=1)
430-
# )
431-
# new_train = collect_dates(new_train)
428+
new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
429+
top_filter_quantile=0.999,
430+
is_post_process=True, pregeneration_frac=2, only_generated_data=False). \
431+
generate_data_pipe(train.drop('Date', axis=1), None,
432+
train.drop('Date', axis=1)
433+
)
434+
new_train = collect_dates(new_train)

tests/test_sampler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import pandas as pd
1111

12-
from src.tabgan.sampler import OriginalGenerator, Sampler, GANGenerator
12+
from src.tabgan.sampler import OriginalGenerator, Sampler, GANGenerator, ForestDiffusionGenerator
1313

1414

1515
class TestOriginalGenerator(TestCase):
@@ -94,3 +94,20 @@ def test_generate_data(self):
9494
self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique()))
9595
self.assertTrue(gen_train.shape[0] > new_train.shape[0])
9696
self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique()))
97+
98+
class TestSamplerForestDiffusion(TestCase):
99+
def setUp(self):
100+
self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD'))
101+
self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y'))
102+
self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD'))
103+
self.gen = ForestDiffusionGenerator(gen_x_times=15)
104+
self.sampler = self.gen.get_object_generator()
105+
106+
def test_generate_data(self):
107+
new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(),
108+
self.target.copy(), self.test)
109+
gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df)
110+
self.assertEqual(gen_train.shape[0], gen_target.shape[0])
111+
self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique()))
112+
self.assertTrue(gen_train.shape[0] > new_train.shape[0])
113+
self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique()))
