|
3 | 3 | #import struct
|
4 | 4 | import requests
|
5 | 5 | import re
|
| 6 | +import struct |
| 7 | +import numpy as np |
| 8 | +from concurrent.futures import ThreadPoolExecutor |
| 9 | + |
| 10 | + |
def fill_hann_window(size, periodic=True):
    """Return a Hann window of ``size`` samples as a NumPy array.

    When ``periodic`` is true, a symmetric window of ``size + 1`` points is
    generated and its final sample is dropped; otherwise the symmetric
    ``size``-point window from ``np.hanning`` is returned as-is.
    """
    if not periodic:
        return np.hanning(size)

    # One extra point, then discard the last one to get the periodic form.
    symmetric = np.hanning(size + 1)
    return symmetric[:-1]
| 15 | + |
| 16 | + |
def irfft(n_fft, complex_input):
    """Return the inverse real FFT of ``complex_input`` with ``n_fft`` samples.

    Delegates to ``np.fft.irfft``, passing ``n_fft`` as the output length.
    """
    time_domain = np.fft.irfft(complex_input, n_fft)
    return time_domain
| 19 | + |
| 20 | + |
def fold(buffer, n_out, n_win, n_hop, n_pad):
    """Overlap-add consecutive ``n_win``-sample frames of ``buffer``.

    ``buffer`` is read as back-to-back frames of ``n_win`` samples. Frame
    ``i`` is added into an ``n_out``-sample accumulator starting at offset
    ``i * n_hop``. When ``n_pad`` is positive, ``n_pad`` samples are trimmed
    from each end of the accumulated signal before it is returned.
    """
    folded = np.zeros(n_out)
    frame_count = len(buffer) // n_win

    for index in range(frame_count):
        src = index * n_win
        dst = index * n_hop
        folded[dst:dst + n_win] += buffer[src:src + n_win]

    if n_pad > 0:
        return folded[n_pad:-n_pad]
    return folded
| 31 | + |
| 32 | + |
def process_frame(args):
    """Build one windowed time-domain frame from a packed argument tuple.

    ``args`` is ``(index, n_fft, spectra, window)``. The spectrum at
    ``spectra[index]`` is inverse-transformed with ``irfft`` and multiplied
    by ``window``. Returns ``(windowed_frame, window * window)``.
    """
    frame_index, n_fft, spectra, window = args
    windowed = irfft(n_fft, spectra[frame_index]) * window
    return windowed, window * window
| 39 | + |
| 40 | + |
def embd_to_audio(embd, n_codes, n_embd, n_thread=4):
    """Convert a flat spectrogram embedding into a mono audio signal.

    ``embd`` is reshaped to ``(n_codes, n_embd)``. In every row, the first
    ``n_embd // 2`` values are used as log-magnitudes and the following
    ``n_embd // 2`` values as phases. Each row becomes a complex spectrum
    that is inverse-transformed to ``n_fft`` samples, Hann-windowed and
    overlap-added with a hop of ``n_hop``. The sum is then divided by the
    overlap-added squared window wherever that envelope exceeds 1e-10.

    Args:
        embd: Array-like holding ``n_codes * n_embd`` values.
        n_codes: Number of frames (rows) in ``embd``.
        n_embd: Number of values per frame.
        n_thread: Maximum number of worker threads used for the per-frame
            inverse FFTs.

    Returns:
        A 1-D NumPy array of ``n_codes * n_hop`` samples (the overlap-added
        signal with ``n_pad`` samples trimmed from each end).
    """
    embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd)

    n_fft = 1280
    n_hop = 320
    n_win = 1280
    n_pad = (n_win - n_hop) // 2
    n_out = (n_codes - 1) * n_hop + n_win

    hann = fill_hann_window(n_fft, True)

    # Split each frame into its log-magnitude and phase halves with plain
    # slicing instead of element-by-element Python loops.
    half_embd = n_embd // 2
    log_mag = embd[:, :half_embd]
    phase = embd[:, half_embd:2 * half_embd]

    # Magnitudes are exponentiated and capped at 1e2.
    mag = np.clip(np.exp(log_mag), 0, 1e2)

    # One spectrum per frame; the extra trailing bin stays zero.
    S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64)
    S[:, :half_embd] = mag * np.exp(1j * phase)

    res = np.zeros(n_codes * n_fft)
    hann2_buffer = np.zeros(n_codes * n_fft)

    with ThreadPoolExecutor(max_workers=n_thread) as executor:
        args = [(l, n_fft, S, hann) for l in range(n_codes)]
        results = list(executor.map(process_frame, args))

    # Lay the frames (and matching squared windows) out back to back so
    # that fold() can overlap-add them.
    for l, (frame, hann2) in enumerate(results):
        res[l * n_fft:(l + 1) * n_fft] = frame
        hann2_buffer[l * n_fft:(l + 1) * n_fft] = hann2

    audio = fold(res, n_out, n_win, n_hop, n_pad)
    env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad)

    # Normalise by the window envelope, skipping near-zero entries to avoid
    # dividing by (almost) zero.
    mask = env > 1e-10
    audio[mask] /= env[mask]

    return audio
| 86 | + |
| 87 | + |
def save_wav(filename, audio_data, sample_rate):
    """Write mono audio samples to ``filename`` as a 16-bit PCM WAV file.

    Samples are scaled by 32767, clipped to the int16 range and converted
    to 16-bit integers (fractional parts are discarded by the cast).

    Args:
        filename: Path of the file to create or overwrite.
        audio_data: Array-like of samples; it is converted to a flat
            float64 array before scaling.
        sample_rate: Sampling rate in Hz written into the header.
    """
    num_channels = 1
    bits_per_sample = 16
    bytes_per_sample = bits_per_sample // 8

    # Coerce to a flat float array first: multiplying a plain Python list by
    # 32767 would repeat the list instead of scaling its values.
    samples = np.asarray(audio_data, dtype=np.float64).reshape(-1)
    scaled = np.clip(samples * 32767, -32768, 32767)

    # The header below is packed little-endian, so force little-endian
    # samples rather than relying on the machine's native byte order.
    pcm_bytes = scaled.astype('<i2').tobytes()

    # Sizes are derived from the bytes actually written.
    data_size = len(pcm_bytes)
    byte_rate = sample_rate * num_channels * bytes_per_sample
    block_align = num_channels * bytes_per_sample
    chunk_size = 36 + data_size  # 36 = size of header minus first 8 bytes

    header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        chunk_size,
        b'WAVE',
        b'fmt ',
        16,  # fmt chunk size
        1,   # audio format (PCM)
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )

    with open(filename, 'wb') as f:
        f.write(header)
        f.write(pcm_bytes)
| 120 | + |
6 | 121 |
|
7 | 122 | def process_text(text: str):
|
8 | 123 | text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
|
@@ -170,6 +285,15 @@ def process_text(text: str):
|
170 | 285 | print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))
|
171 | 286 |
|
172 | 287 | # post-process the spectrogram to convert to audio
|
173 |
| -# TODO: see the tts.cpp:embd_to_audio() and implement it in Python |
174 | 288 | print('converting to audio ...')
|
175 |
| -print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python') |
# Turn the spectrogram embedding into audio samples.
audio = embd_to_audio(embd, n_codes, n_embd)
print('audio generated: %d samples' % len(audio))

filename = "output.wav"
sample_rate = 24000  # sampling rate

# zero out first 0.25 seconds (derived from sample_rate so the two stay in sync)
audio[:sample_rate // 4] = 0.0

save_wav(filename, audio, sample_rate)
print('audio written to file "%s"' % filename)
0 commit comments