Commit ca5eb8c

Add more pre-trained models
1 parent b567ad4 commit ca5eb8c

6 files changed

Lines changed: 471 additions & 40 deletions

README.md

Lines changed: 111 additions & 8 deletions
@@ -41,21 +41,101 @@ conda env export --no-builds | grep -v "prefix" > environment.yml
4141
pip list > pip_list.txt
4242
```
4343

44-
# Training on Audioset
45-
Download and prepare the dataset as explained in the [audioset page](audioset/)
46-
The base PaSST model can be trained for example like this:
47-
```bash
48-
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
44+
# Getting started
45+
Each dataset has an experiment file such as `ex_audioset.py` and `ex_openmic.py` and a dataset folder with a readme file.
46+
In general, you can probe the experiment file for help:
47+
```shell
48+
python ex_audioset.py help
4949
```
50+
5051
You can override any configuration option using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html).
5152
In order to see the available options either use [omniboard](https://github.com/vivekratnavel/omniboard) or use:
5253
```shell
5354
python ex_audioset.py print_config
5455
```
56+
There are many pre-defined configuration options in `config_updates.py`. These include different models, setups, etc.
57+
You can list these configurations with:
58+
```shell
59+
python ex_audioset.py print_named_configs
60+
```
61+
The overall configuration looks like this:
62+
```yaml
63+
...
64+
seed = 542198583 # the random seed for this experiment
65+
slurm_job_id = ''
66+
speed_test_batch_size = 100
67+
swa = True
68+
swa_epoch_start = 50
69+
swa_freq = 5
70+
use_mixup = True
71+
warm_up_len = 5
72+
weight_decay = 0.0001
73+
basedataset:
74+
base_dir = 'audioset_hdf5s/' # base directory of the dataset, change it or make a link
75+
eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
76+
wavmix = 1
77+
....
78+
roll_conf:
79+
axis = 1
80+
shift = None
81+
shift_range = 50
82+
datasets:
83+
test:
84+
batch_size = 20
85+
dataset = {CMD!}'/basedataset.get_test_set'
86+
num_workers = 16
87+
validate = True
88+
training:
89+
batch_size = 12
90+
dataset = {CMD!}'/basedataset.get_full_training_set'
91+
num_workers = 16
92+
sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
93+
shuffle = None
94+
train = True
95+
models:
96+
mel:
97+
freqm = 48
98+
timem = 192
99+
hopsize = 320
100+
htk = False
101+
n_fft = 1024
102+
n_mels = 128
103+
norm = 1
104+
sr = 32000
105+
...
106+
net:
107+
arch = 'passt_s_swa_p16_128_ap476'
108+
fstride = 10
109+
in_channels = 1
110+
input_fdim = 128
111+
input_tdim = 998
112+
n_classes = 527
113+
s_patchout_f = 4
114+
s_patchout_t = 40
115+
tstride = 10
116+
u_patchout = 0
117+
...
118+
trainer:
119+
accelerator = None
120+
accumulate_grad_batches = 1
121+
amp_backend = 'native'
122+
amp_level = 'O2'
123+
auto_lr_find = False
124+
auto_scale_batch_size = False
125+
...
126+
```
127+
There are many things that can be updated from the command line.
55128
In short:
56-
- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api).
57-
- `models.net` are the passt options.
58-
- `models.mel` are the preprocessing options.
129+
- All the configuration options under `trainer` are PyTorch Lightning trainer [API](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api) options. For example, to turn off cuDNN benchmarking, add `trainer.benchmark=False` to the command line.
130+
- `models.net` are the PaSST (or the chosen NN) options.
131+
- `models.mel` are the preprocessing options (mel spectrograms).
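
To illustrate what a dotted override such as `trainer.precision=16` does to the nested configuration, here is a minimal standalone sketch. It does not use sacred and is not how sacred is implemented; `apply_overrides` is a hypothetical helper, the sample `config` values are made up for the example, and value parsing with `ast.literal_eval` is an assumption that only approximates sacred's behaviour.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    updated = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        try:
            # Interpret numbers, booleans, None, and quoted strings.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything that is not a Python literal stays a plain string.
            value = raw_value
        *parents, leaf = dotted_key.split(".")
        node = updated
        for name in parents:
            # Walk down the nested dicts, creating missing levels.
            node = node.setdefault(name, {})
        node[leaf] = value
    return updated


# Hypothetical starting configuration (illustrative values only).
config = {
    "trainer": {"precision": 32, "benchmark": True},
    "models": {"net": {"arch": "passt_s_swa_p16_128_ap476"}, "mel": {"n_mels": 128}},
}

new_config = apply_overrides(
    config,
    [
        "trainer.precision=16",
        "trainer.benchmark=False",
        "models.net.arch=passt_deit_bd_p16_384",
    ],
)
print(new_config["trainer"])
print(new_config["models"]["net"])
```

Each override only touches the key it names, so `models.mel` is left unchanged in this example.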
132+
133+
# Training on Audioset
134+
Download and prepare the dataset as explained in the [audioset page](audioset/)
135+
The base PaSST model can be trained for example like this:
136+
```bash
137+
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
138+
```
59139

60140
For example using only unstructured patchout of 400:
61141
```bash
@@ -68,6 +148,7 @@ Multi-gpu training can be enabled by setting the environment variable `DDP`, for
68148
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
69149
```
70150

151+
71152
# Pre-trained models
72153
Please check the [releases page](releases/), to download pre-trained models.
73154
In general, you can get a pretrained model on Audioset using
@@ -79,6 +160,28 @@ model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=
79160
```
80161
This will automatically download a pretrained PaSST model on Audioset with an mAP of ```0.476```. The model was trained with ```s_patchout_t=40, s_patchout_f=4```, but you can change these to better fit your task or computational needs.
81162

163+
There are several pretrained models available with different strides (overlap) and with/without SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.
164+
For example, in `passt_s_swa_p16_s16_128_ap473`: `p16` means the patch size is `16x16`, `s16` means no overlap (stride=16), `128` is the number of mel bands, and `ap473` refers to the performance of this model on Audioset (mAP=0.473).
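
The naming scheme can be decoded mechanically. The following is a small sketch; `describe_arch` and the regular expression are hypothetical helpers written for illustration, not part of the repository, and they assume every name follows the `passt_s[_swa]_pNN[_sNN]_MMM_apXXX` pattern seen in the model list.

```python
import re

# Pattern inferred from the model names listed in the README (an assumption).
ARCH_PATTERN = re.compile(
    r"^passt_s(?P<swa>_swa)?_p(?P<patch>\d+)(?:_s(?P<stride>\d+))?"
    r"_(?P<mels>\d+)_ap(?P<ap>\d+)$"
)


def describe_arch(name):
    """Split a pretrained-model name into its documented components."""
    match = ARCH_PATTERN.match(name)
    if match is None:
        raise ValueError(f"unrecognised architecture name: {name}")
    stride = match.group("stride")
    ap_digits = match.group("ap")
    return {
        "swa": match.group("swa") is not None,
        "patch_size": int(match.group("patch")),
        # Names without an `sNN` part carry no stride information here.
        "stride": int(stride) if stride is not None else None,
        "n_mels": int(match.group("mels")),
        # `ap473` encodes mAP 0.473, `ap4763` would encode 0.4763.
        "mAP": int(ap_digits) / 10 ** len(ap_digits),
    }


info = describe_arch("passt_s_swa_p16_s16_128_ap473")
print(info)
```

For names without an `sNN` component, such as `passt_s_swa_p16_128_ap476`, the sketch returns `stride=None`; the named configs in `config_updates.py` pair those models with `fstride=10, tstride=10`.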
165+
166+
In general, you can get this pretrained model using:
167+
```python
168+
from models.passt import get_model
169+
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
170+
```
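
To get a feel for how `fstride` and `tstride` affect the number of patches the transformer has to process, the sketch below counts patches for the input size shown in the configuration listing (`input_fdim=128`, `input_tdim=998`). `patch_grid` is a hypothetical helper, and the formula assumes an unpadded convolutional patch embedding, which may not match the repository's implementation exactly.

```python
def patch_grid(input_fdim=128, input_tdim=998, patch_size=16, fstride=10, tstride=10):
    """Count patches along frequency and time for an unpadded patch embedding."""
    n_freq = (input_fdim - patch_size) // fstride + 1
    n_time = (input_tdim - patch_size) // tstride + 1
    return n_freq, n_time


for stride in (10, 12, 14, 16):
    n_freq, n_time = patch_grid(fstride=stride, tstride=stride)
    overlap = 16 - stride  # pixels shared by neighbouring 16x16 patches
    print(f"stride={stride}: {n_freq} x {n_time} = {n_freq * n_time} patches, overlap={overlap}")
```

Under these assumptions, larger strides mean less overlap and a shorter input sequence, which is the trade-off between accuracy and compute that the different pretrained models expose.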
171+
Using the framework, you can evaluate this model using:
172+
```shell
173+
python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
174+
```
175+
176+
Two ensembles of these models are provided as well:
177+
A large ensemble giving `mAP=.4956`
178+
```shell
179+
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
180+
```
181+
An ensemble of models with `stride=10` giving `mAP=.4864`
182+
```shell
183+
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
184+
```
82185

83186
# Contact
84187
The repo will be updated; in the meantime, if you have any questions or problems, feel free to open an issue on GitHub or contact the authors directly.

audioset/dataset.py

Lines changed: 3 additions & 3 deletions
@@ -330,7 +330,7 @@ def roll_func(b):
330330

331331

332332
@dataset.command
333-
def get_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192,mel_bins=128):
333+
def get_training_set(normalize, roll, wavmix=False):
334334
ds = get_base_training_set()
335335
get_ir_sample()
336336
if normalize:
@@ -346,7 +346,7 @@ def get_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192,mel_bin
346346

347347

348348
@dataset.command
349-
def get_full_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192, mel_bins=128):
349+
def get_full_training_set(normalize, roll, wavmix=False):
350350
ds = get_base_full_training_set()
351351
get_ir_sample()
352352
if normalize:
@@ -362,7 +362,7 @@ def get_full_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192, m
362362

363363

364364
@dataset.command
365-
def get_test_set(normalize, roll, mel_bins=128):
365+
def get_test_set(normalize):
366366
ds = get_base_test_set()
367367
if normalize:
368368
print("normalized test!")

config_updates.py

Lines changed: 152 additions & 5 deletions
@@ -1,30 +1,178 @@
11
from sacred.config_helpers import DynamicIngredient, CMD
22

3+
34
def add_configs(ex):
5+
'''
6+
This function adds generic configurations for the experiments, such as mix-up, architectures, etc.
7+
@param ex: Ba3l Experiment
8+
@return:
9+
'''
410

511
@ex.named_config
612
def nomixup():
13+
'Don\'t apply mix-up (spectrogram level).'
714
use_mixup = False
815
mixup_alpha = 0.3
916

17+
@ex.named_config
18+
def mixup():
19+
' Apply mix-up (spectrogram level).'
20+
use_mixup = True
21+
mixup_alpha = 0.3
1022

1123
@ex.named_config
1224
def mini_train():
13-
# just to debug
25+
'limit training/validation to 5 batches for debugging.'
1426
trainer = dict(limit_train_batches=5, limit_val_batches=5)
1527

16-
1728
@ex.named_config
1829
def passt():
30+
'use PaSST model'
31+
models = {
32+
"net": DynamicIngredient("models.passt.model_ing")
33+
}
34+
35+
@ex.named_config
36+
def passt_s_ap476():
37+
'use PaSST model pretrained on Audioset (with SWA) ap=476'
38+
# python ex_audioset.py evaluate_only with passt_s_ap476
1939
models = {
20-
"net": DynamicIngredient("models.vit.passt.model_ing")
40+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476", fstride=10,
41+
tstride=10)
2142
}
43+
44+
@ex.named_config
45+
def passt_s_ap4763():
46+
'use PaSST model pretrained on Audioset (with SWA) ap=4763'
47+
# test with: python ex_audioset.py evaluate_only with passt_s_ap4763
48+
models = {
49+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763", fstride=10,
50+
tstride=10)
51+
}
52+
53+
@ex.named_config
54+
def passt_s_ap472():
55+
'use PaSST model pretrained on Audioset (no SWA) ap=472'
56+
# test with: python ex_audioset.py evaluate_only with passt_s_ap472
57+
models = {
58+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472", fstride=10,
59+
tstride=10)
60+
}
61+
62+
@ex.named_config
63+
def passt_s_p16_s16_128_ap468():
64+
'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'
65+
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
66+
models = {
67+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468", fstride=16,
68+
tstride=16)
69+
}
70+
71+
@ex.named_config
72+
def passt_s_swa_p16_s16_128_ap473():
73+
'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'
74+
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
75+
models = {
76+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473", fstride=16,
77+
tstride=16)
78+
}
79+
80+
@ex.named_config
81+
def passt_s_swa_p16_s14_128_ap471():
82+
'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471 '
83+
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
84+
models = {
85+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471", fstride=14,
86+
tstride=14)
87+
}
88+
89+
@ex.named_config
90+
def passt_s_p16_s14_128_ap469():
91+
'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '
92+
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
93+
models = {
94+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469", fstride=14,
95+
tstride=14)
96+
}
97+
98+
@ex.named_config
99+
def passt_s_swa_p16_s12_128_ap473():
100+
'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '
101+
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
102+
models = {
103+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473", fstride=12,
104+
tstride=12)
105+
}
106+
107+
@ex.named_config
108+
def passt_s_p16_s12_128_ap470():
109+
'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=470 '
110+
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
111+
models = {
112+
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470", fstride=12,
113+
tstride=12)
114+
}
115+
116+
@ex.named_config
117+
def ensemble_s10():
118+
'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
119+
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
120+
models = {
121+
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10", fstride=None,
122+
tstride=None, instance_cmd="get_ensemble_model",
123+
# don't call get_model but rather get_ensemble_model
124+
arch_list=[
125+
("passt_s_swa_p16_128_ap476", 10, 10),
126+
("passt_s_swa_p16_128_ap4761", 10, 10),
127+
("passt_s_p16_128_ap472", 10, 10),
128+
]
129+
)
130+
}
131+
@ex.named_config
132+
def ensemble_many():
133+
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
134+
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
135+
models = {
136+
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
137+
tstride=None, instance_cmd="get_ensemble_model",
138+
# don't call get_model but rather get_ensemble_model
139+
arch_list=[
140+
("passt_s_swa_p16_128_ap476", 10, 10),
141+
("passt_s_swa_p16_128_ap4761", 10, 10),
142+
("passt_s_p16_128_ap472", 10, 10),
143+
("passt_s_p16_s12_128_ap470", 12, 12),
144+
("passt_s_swa_p16_s12_128_ap473", 12, 12),
145+
("passt_s_p16_s14_128_ap469", 14, 14),
146+
("passt_s_swa_p16_s14_128_ap471", 14, 14),
147+
("passt_s_swa_p16_s16_128_ap473", 16, 16),
148+
("passt_s_p16_s16_128_ap468", 16, 16),
149+
]
150+
)
151+
}
152+
@ex.named_config
153+
def ensemble_s16_14():
154+
'use ensemble of PaSST models pretrained on Audioset with stride 16 and 14 mAP=.4863'
155+
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
156+
models = {
157+
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16", fstride=None,
158+
tstride=None, instance_cmd="get_ensemble_model",
159+
# don't call get_model but rather get_ensemble_model
160+
arch_list=[
161+
("passt_s_p16_s14_128_ap469", 14, 14),
162+
("passt_s_swa_p16_s14_128_ap471", 14, 14),
163+
("passt_s_swa_p16_s16_128_ap473", 16, 16),
164+
("passt_s_p16_s16_128_ap468", 16, 16),
165+
]
166+
)
167+
}
168+
22169
@ex.named_config
23170
def dynamic_roll():
171+
# dynamically roll the spectrograms/waveforms
172+
# updates the dataset config
24173
basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000)
25174
)
26175

27-
28176
# extra commands
29177

30178
@ex.command
@@ -46,4 +194,3 @@ def test_loaders_train_speed():
46194
print(f"{i}/{len(itr)}", end="\r")
47195
end = time.time()
48196
print("total time:", end - start)
49-
