
Commit c6c9ba5

ljhzxcluotao1 authored and committed
[TTS] [Hackathon] Add JETS (PaddlePaddle#3109)
1 parent 97ca0da commit c6c9ba5

25 files changed: +4481 -1 lines changed

examples/csmsc/jets/README.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# JETS with CSMSC
This example contains code used to train a [JETS](https://arxiv.org/abs/2203.16852v1) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).

## Dataset
### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source).

### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes and durations for JETS.
You can download the alignment from here: [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.

## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
    - synthesize waveform from `metadata.jsonl`.
    - synthesize waveform from a text file.

```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to run only one stage. For example, the following command only preprocesses the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
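For instance, with the default config shipped in this example (see `conf/default.yaml` below), the command becomes:
```bash
./local/preprocess.sh conf/default.yaml
```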
When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
    ├── feats_stats.npy
    ├── norm
    └── raw
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the wave, mel spectrogram, speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize the features are computed from the training set and stored in `dump/train/feats_stats.npy`.

Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, the path of feats, feats_lengths, the path of pitch features, the path of energy features, the path of raw waves, speaker, and the id of each utterance.
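Since each line of `metadata.jsonl` is a standalone JSON object, you can inspect a record straight from the shell. A minimal sketch, assuming preprocessing has finished and produced the `dump` layout above:
```bash
# Pretty-print the first record of the normalized training metadata.
head -n 1 dump/train/norm/metadata.jsonl | python3 -m json.tool
```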
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--phones-dict PHONES_DICT]

Train a JETS model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       config file to overwrite default config.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --phones-dict PHONES_DICT
                        phone vocabulary file.
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
5. `--phones-dict` is the path of the phone vocabulary file. An example invocation putting these options together is shown below.
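For reference, a direct invocation of `train.py` might look like this sketch; the output directory `exp/default` is illustrative, and the other paths follow the `dump` layout described above:
```bash
python3 ${BIN_DIR}/train.py \
    --config=conf/default.yaml \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --output-dir=exp/default \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
```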
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveforms from `metadata.jsonl`.

```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```

`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveforms from a text file.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
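With concrete values filled in, an end-to-end run might look like the sketch below; the GPU id and checkpoint name are illustrative, so pick a checkpoint that actually exists under `${train_output_path}/checkpoints/`:
```bash
CUDA_VISIBLE_DEVICES=0 ./local/synthesize_e2e.sh conf/default.yaml exp/default snapshot_iter_350000.pdz
```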
examples/csmsc/jets/conf/default.yaml

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# This configuration was tested on 4 GPUs (V100) with 32GB GPU
# memory. It takes around 2 weeks to finish the training,
# but a 100k-iteration model should already generate reasonable results.
###########################################################
#               FEATURE EXTRACTION SETTING                #
###########################################################

n_mels: 80
fs: 22050          # sr
n_fft: 1024        # FFT size (samples).
n_shift: 256       # Hop size (samples), ~11.6 ms at fs=22050.
win_length: null   # Window length (samples).
                   # If set to null, it will be the same as fft_size.
window: "hann"     # Window function.
fmin: 0            # minimum frequency for Mel basis
fmax: null         # maximum frequency for Mel basis
f0min: 80          # Minimum f0 for pitch extraction.
f0max: 400         # Maximum f0 for pitch extraction.


##########################################################
#                    TTS MODEL SETTING                   #
##########################################################
model:
    # generator related
    generator_type: jets_generator
    generator_params:
        adim: 256                                     # attention dimension
        aheads: 2                                     # number of attention heads
        elayers: 4                                    # number of encoder layers
        eunits: 1024                                  # number of encoder ff units
        dlayers: 4                                    # number of decoder layers
        dunits: 1024                                  # number of decoder ff units
        positionwise_layer_type: conv1d               # type of position-wise layer
        positionwise_conv_kernel_size: 3              # kernel size of position wise conv layer
        duration_predictor_layers: 2                  # number of layers of duration predictor
        duration_predictor_chans: 256                 # number of channels of duration predictor
        duration_predictor_kernel_size: 3             # filter size of duration predictor
        use_masking: True                             # whether to apply masking for padded part in loss calculation
        encoder_normalize_before: True                # whether to perform layer normalization before the input
        decoder_normalize_before: True                # whether to perform layer normalization before the input
        encoder_type: transformer                     # encoder type
        decoder_type: transformer                     # decoder type
        conformer_rel_pos_type: latest                # relative positional encoding type
        conformer_pos_enc_layer_type: rel_pos         # conformer positional encoding type
        conformer_self_attn_layer_type: rel_selfattn  # conformer self-attention type
        conformer_activation_type: swish              # conformer activation type
        use_macaron_style_in_conformer: true          # whether to use macaron style in conformer
        use_cnn_in_conformer: true                    # whether to use CNN in conformer
        conformer_enc_kernel_size: 7                  # kernel size in CNN module of conformer-based encoder
        conformer_dec_kernel_size: 31                 # kernel size in CNN module of conformer-based decoder
        init_type: xavier_uniform                     # initialization type
        init_enc_alpha: 1.0                           # initial value of alpha for encoder
        init_dec_alpha: 1.0                           # initial value of alpha for decoder
        transformer_enc_dropout_rate: 0.2             # dropout rate for transformer encoder layer
        transformer_enc_positional_dropout_rate: 0.2  # dropout rate for transformer encoder positional encoding
        transformer_enc_attn_dropout_rate: 0.2        # dropout rate for transformer encoder attention layer
        transformer_dec_dropout_rate: 0.2             # dropout rate for transformer decoder layer
        transformer_dec_positional_dropout_rate: 0.2  # dropout rate for transformer decoder positional encoding
        transformer_dec_attn_dropout_rate: 0.2        # dropout rate for transformer decoder attention layer
        pitch_predictor_layers: 5                     # number of conv layers in pitch predictor
        pitch_predictor_chans: 256                    # number of channels of conv layers in pitch predictor
        pitch_predictor_kernel_size: 5                # kernel size of conv layers in pitch predictor
        pitch_predictor_dropout: 0.5                  # dropout rate in pitch predictor
        pitch_embed_kernel_size: 1                    # kernel size of conv embedding layer for pitch
        pitch_embed_dropout: 0.0                      # dropout rate after conv embedding layer for pitch
        stop_gradient_from_pitch_predictor: true      # whether to stop the gradient from pitch predictor to encoder
        energy_predictor_layers: 2                    # number of conv layers in energy predictor
        energy_predictor_chans: 256                   # number of channels of conv layers in energy predictor
        energy_predictor_kernel_size: 3               # kernel size of conv layers in energy predictor
        energy_predictor_dropout: 0.5                 # dropout rate in energy predictor
        energy_embed_kernel_size: 1                   # kernel size of conv embedding layer for energy
        energy_embed_dropout: 0.0                     # dropout rate after conv embedding layer for energy
        stop_gradient_from_energy_predictor: false    # whether to stop the gradient from energy predictor to encoder
        generator_out_channels: 1
        generator_channels: 512
        generator_global_channels: -1
        generator_kernel_size: 7
        generator_upsample_scales: [8, 8, 2, 2]
        generator_upsample_kernel_sizes: [16, 16, 4, 4]
        generator_resblock_kernel_sizes: [3, 7, 11]
        generator_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        generator_use_additional_convs: true
        generator_bias: true
        generator_nonlinear_activation: "leakyrelu"
        generator_nonlinear_activation_params:
            negative_slope: 0.1
        generator_use_weight_norm: true
        segment_size: 64          # segment size for random windowed discriminator

    # discriminator related
    discriminator_type: hifigan_multi_scale_multi_period_discriminator
    discriminator_params:
        scales: 1
        scale_downsample_pooling: "AvgPool1D"
        scale_downsample_pooling_params:
            kernel_size: 4
            stride: 2
            padding: 2
        scale_discriminator_params:
            in_channels: 1
            out_channels: 1
            kernel_sizes: [15, 41, 5, 3]
            channels: 128
            max_downsample_channels: 1024
            max_groups: 16
            bias: True
            downsample_scales: [2, 2, 4, 4, 1]
            nonlinear_activation: "leakyrelu"
            nonlinear_activation_params:
                negative_slope: 0.1
            use_weight_norm: True
            use_spectral_norm: False
        follow_official_norm: False
        periods: [2, 3, 5, 7, 11]
        period_discriminator_params:
            in_channels: 1
            out_channels: 1
            kernel_sizes: [5, 3]
            channels: 32
            downsample_scales: [3, 3, 3, 3, 1]
            max_downsample_channels: 1024
            bias: True
            nonlinear_activation: "leakyrelu"
            nonlinear_activation_params:
                negative_slope: 0.1
            use_weight_norm: True
            use_spectral_norm: False
    # others
    sampling_rate: 22050          # needed in the inference for saving wav
    cache_generator_outputs: True # whether to cache generator outputs in the training
    use_alignment_module: False   # whether to use alignment module

###########################################################
#                      LOSS SETTING                       #
###########################################################
# loss function related
generator_adv_loss_params:
    average_by_discriminators: False # whether to average loss value by #discriminators
    loss_type: mse                   # loss type, "mse" or "hinge"
discriminator_adv_loss_params:
    average_by_discriminators: False # whether to average loss value by #discriminators
    loss_type: mse                   # loss type, "mse" or "hinge"
feat_match_loss_params:
    average_by_discriminators: False # whether to average loss value by #discriminators
    average_by_layers: False         # whether to average loss value by #layers of each discriminator
    include_final_outputs: True      # whether to include final outputs for loss calculation
mel_loss_params:
    fs: 22050        # must be the same as the training data
    fft_size: 1024   # fft points
    hop_size: 256    # hop size
    win_length: null # window length
    window: hann     # window type
    num_mels: 80     # number of Mel basis
    fmin: 0          # minimum frequency for Mel basis
    fmax: null       # maximum frequency for Mel basis
    log_base: null   # null represents natural log

###########################################################
#                ADVERSARIAL LOSS SETTING                 #
###########################################################
lambda_adv: 1.0        # loss scaling coefficient for adversarial loss
lambda_mel: 45.0       # loss scaling coefficient for Mel loss
lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss
lambda_var: 1.0        # loss scaling coefficient for variance (duration/pitch/energy) loss
lambda_align: 2.0      # loss scaling coefficient for alignment loss
# others
sampling_rate: 22050          # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training


# extra module for additional inputs
pitch_extract: dio           # pitch extractor type
pitch_extract_conf:
    reduction_factor: 1
    use_token_averaged_f0: false
pitch_normalize: global_mvn  # normalizer for the pitch feature
energy_extract: energy       # energy extractor type
energy_extract_conf:
    reduction_factor: 1
    use_token_averaged_energy: false
energy_normalize: global_mvn # normalizer for the energy feature


###########################################################
#                   DATA LOADER SETTING                   #
###########################################################
batch_size: 32  # Batch size.
num_workers: 4  # Number of workers in DataLoader.

##########################################################
#             OPTIMIZER & SCHEDULER SETTING              #
##########################################################
# optimizer setting for generator
generator_optimizer_params:
    beta1: 0.8
    beta2: 0.99
    epsilon: 1.0e-9
    weight_decay: 0.0
generator_scheduler: exponential_decay
generator_scheduler_params:
    learning_rate: 2.0e-4
    gamma: 0.999875

# optimizer setting for discriminator
discriminator_optimizer_params:
    beta1: 0.8
    beta2: 0.99
    epsilon: 1.0e-9
    weight_decay: 0.0
discriminator_scheduler: exponential_decay
discriminator_scheduler_params:
    learning_rate: 2.0e-4
    gamma: 0.999875
generator_first: True # whether to start updating generator first

##########################################################
#                 OTHER TRAINING SETTING                 #
##########################################################
num_snapshots: 10         # max number of snapshots to keep while training
train_max_steps: 350000   # Number of training steps. == total_iters / ngpus, total_iters = 1000000
save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 250  # Interval steps to evaluate the network.
seed: 777                 # random seed number
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#!/bin/bash
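# Synthesize wavs from a text file with the exported inference model.
# Usage (assumed): ./local/inference.sh <train_output_path>
# Note: expects ${BIN_DIR} to be exported by the calling environment.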

train_output_path=$1

stage=0
stop_stage=0

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=jets_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt
fi
