# coding: utf-8
"""
Generate ground truth-aligned predictions.

usage: generate_aligned_predictions.py [options] <checkpoint> <in_dir> <out_dir>

options:
    --hparams=<params>        Hyper parameters [default: ].
    --preset=<json>           Path of preset parameters (json).
    --overwrite               Overwrite audio and mel outputs.
    -h, --help                Show help message.
"""

from docopt import docopt

import os
from os.path import join
import sys
import importlib
from warnings import warn

from tqdm import tqdm
import numpy as np

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams

use_cuda = torch.cuda.is_available()
_frontend = None  # to be set later


def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
               p=0, speaker_id=None, fast=False):
    """Generate a ground truth-aligned prediction.

    The output of the network and the corresponding audio are saved after
    time-resolution adjustment.
    """
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step

    if use_cuda:
        model = model.cuda()
    model.eval()
    if fast:
        model.make_generation_fast_()

    mel_org = np.load(join(in_dir, mel_filename))
    # Zero-pad so that the length is a multiple of r
    b_pad = r  # imitates initial decoder state
    e_pad = r - len(mel_org) % r if len(mel_org) % r > 0 else 0
    mel = np.pad(mel_org, [(b_pad, e_pad), (0, 0)],
                 mode="constant", constant_values=0)

    mel = Variable(torch.from_numpy(mel)).unsqueeze(0).contiguous()

    # Downsample mel spectrogram
    if downsample_step > 1:
        mel = mel[:, 0::downsample_step, :].contiguous()

    decoder_target_len = mel.shape[1] // r
    s, e = 1, decoder_target_len + 1
    frame_positions = torch.arange(s, e).long().unsqueeze(0)
    frame_positions = Variable(frame_positions)

    sequence = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
    text_positions = Variable(text_positions)
    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
    if use_cuda:
        sequence = sequence.cuda()
        text_positions = text_positions.cuda()
        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
        mel = mel.cuda()
        frame_positions = frame_positions.cuda()

    # **Teacher forcing** decoding
    mel_outputs, _, _, _ = model(
        sequence, mel, text_positions=text_positions,
        frame_positions=frame_positions, speaker_ids=speaker_ids)

    mel_output = mel_outputs[0].data.cpu().numpy()
    # **Time resolution adjustment**: strip the frames added by padding
    mel_output = mel_output[:-(b_pad + e_pad)]

    wav = np.load(join(in_dir, audio_filename))
    assert len(wav) % hparams.hop_size == 0

    # Coarse upsample just for convenience, so that conditional features can
    # be upsampled by hop_size in WaveNet
    if downsample_step > 0:
        mel_output = np.repeat(mel_output, downsample_step, axis=0)
    # After downsampling -> upsampling, the length should be equal to or
    # larger than the original mel length
    assert mel_output.shape[0] >= mel_org.shape[0]

    # Make sure we have correct lengths
    assert mel_output.shape[0] * hparams.hop_size == len(wav)

    timesteps = len(wav)

    # Save
    np.save(join(out_dir, audio_filename), wav, allow_pickle=False)
    np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
            allow_pickle=False)

    if speaker_id is None:
        return (audio_filename, mel_filename, timesteps, text)
    else:
        return (audio_filename, mel_filename, timesteps, text, speaker_id)
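

# Each tuple returned by preprocess() is serialized by write_metadata() below
# as one "|"-separated line of train.txt. Schematically (field names are
# descriptive, not literal values):
#   <audio .npy>|<mel .npy>|<num audio samples>|<transcript>[|<speaker_id>]
# The trailing speaker_id field is only present for multi-speaker datasets.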
def write_metadata(metadata, out_dir):
    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
        for m in metadata:
            f.write('|'.join([str(x) for x in m]) + '\n')
    frames = sum([m[2] for m in metadata])
    sr = hparams.sample_rate
    hours = frames / sr / 3600
    print('Wrote %d utterances, %d time steps (%.2f hours)' % (
        len(metadata), frames, hours))
    print('Max input length: %d' % max(len(m[3]) for m in metadata))
    print('Max output length: %d' % max(m[2] for m in metadata))


if __name__ == "__main__":
    args = docopt(__doc__)
    checkpoint_path = args["<checkpoint>"]
    in_dir = args["<in_dir>"]
    out_dir = args["<out_dir>"]
    preset = args["--preset"]

    # Load preset if specified
    if preset is not None:
        with open(preset) as f:
            hparams.parse_json(f.read())
    # Override hyper parameters
    hparams.parse(args["--hparams"])
    assert hparams.name == "deepvoice3"

    _frontend = getattr(frontend, hparams.frontend)
    import train
    train._frontend = _frontend
    from train import build_model

    model = build_model()

    # Load checkpoint
    print("Load checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["state_dict"])

    os.makedirs(out_dir, exist_ok=True)
    results = []
    with open(os.path.join(in_dir, "train.txt")) as f:
        lines = f.readlines()

    for idx in tqdm(range(len(lines))):
        l = lines[idx]
        l = l[:-1].split("|")
        audio_filename, mel_filename, _, text = l[:4]
        speaker_id = int(l[4]) if len(l) > 4 else None
        if text == "N/A":
            raise RuntimeError("No transcription available")

        result = preprocess(model, in_dir, out_dir, text, audio_filename,
                            mel_filename, p=0,
                            speaker_id=speaker_id, fast=True)
        results.append(result)

    write_metadata(results, out_dir)

    sys.exit(0)