You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Generative Adversarial Networks (GANs) are well-known for their success in realistic image generation. However, they can also be applied to generate tabular data. Here will give opportunity to try some of them.
Both samplers `OriginalGenerator` and `GANGenerator` have same input parameters:
45
+
All samplers `OriginalGenerator`, `ForestDiffusionGenerator` and `GANGenerator` have same input parameters.
46
+
47
+
1.**GANGenerator** based on **CTGAN**
48
+
2.**ForestDiffusionGenerator** based on **Forest Diffusion**
45
49
46
50
***gen_x_times**: float = 1.1 - how much data to generate, output might be less because of postprocessing and
47
51
adversarial filtering
@@ -132,43 +136,14 @@ To run experiment follow these steps:
132
136
add more datasets, adjust validation type and categorical encoders.
133
137
5. Observe metrics across all experiment in console or in `./Research/results/fit_predict_scores.txt`
134
138
135
-
**Task formalization**
136
-
137
-
Let say we have **T_train** and **T_test** (train and test set respectively). We need to train the model on **T_train**
138
-
and make predictions on **T_test**. However, we will increase the train by generating new data by GAN, somehow similar
139
-
to **T_test**, without using ground truth labels.
140
139
141
140
**Experiment design**
142
141
143
-
In the case of having a smaller **T_train** and a different data distribution, we can use CTGAN to generate additional data **T_synth**. First, we train CTGAN on **T_train** with ground truth labels (step 1), then generate additional data **T_synth** (step 2). Secondly, we train boosting in an adversarial way on concatenated **T_train** and **T_synth** (target set to 0) with **T_test** (target set to 1) (steps 3 & 4). The goal is to apply the newly trained adversarial boosting to obtain rows more like **T_test**. Note that initial ground truth labels aren't used for adversarial training. As a result, we take top rows from **T_train** and **T_synth** sorted by correspondence to **T_test** (steps 5 & 6), and train new boosting on them and check results on **T_test**.
144
-
145
142

146
143
147
144
**Picture 1.1** Experiment design and workflow
148
145
149
-
Of course for the benchmark purposes we will test ordinal training without these tricks and another original pipeline
150
-
but without CTGAN (in step 3 we won"t use **T_sync**).
151
-
152
-
**Datasets**
153
-
154
-
All datasets came from different domains. They have a different number of observations, number of categorical and
155
-
numerical features. The objective for all datasets - binary classification. Preprocessing of datasets were simple:
156
-
removed all time-based columns from datasets. Remaining columns were either categorical or numerical.
157
-
158
-
**Table 1.1** Used datasets
159
-
160
-
| Name | Total points | Train points | Test points | Number of features | Number of categorical features | Short description |
|[Mortgages](https://www.crowdanalytix.com/contests/propensity-to-fund-mortgages)| 45.6k | 27.4k | 18.2k | 20 | 9 | Predict if house mortgage is founded |
167
-
|[Taxi](https://www.crowdanalytix.com/contests/mckinsey-big-data-hackathon)| 892.5k | 535.5k | 357k | 8 | 5 | Predict the probability of an offer being accepted by a certain driver |
168
-
|[Poverty_A](https://www.drivendata.org/competitions/50/worldbank-poverty-prediction/page/99/)| 37.6k | 22.5k | 15.0k | 41 | 38 | Predict whether or not a given household for a given country is poor or not |
169
-
170
146
## Results
171
-
172
147
To determine the best sampling strategy, ROC AUC scores of each dataset were scaled (min-max scale) and then averaged
[7] Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila. Analyzing and Improving the
255
-
Image Quality of StyleGAN (2019) arXiv:1912.04958v2 [cs.CV]
256
-
257
-
[8] ODS.ai: Open data science, https://ods.ai/
208
+
[2] Alexia Jolicoeur-Martineau and Kilian Fatras and Tal Kachman. Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees ((2023) https://github.com/SamsungSAILMontreal/ForestDiffusion[cs.LG]
258
209
210
+
[3] Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular data using Conditional GAN. NeurIPS, (2019)
0 commit comments