|
17 | 17 |
|
18 | 18 |
|
19 | 19 | class Hubert(nn.Module): |
20 | | - def __init__(self, num_label_embeddings: int = 100, mask=True): |
| 20 | + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): |
21 | 21 | super().__init__() |
22 | 22 | self._mask = mask |
23 | 23 | self.feature_extractor = FeatureExtractor() |
@@ -69,6 +69,28 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
69 | 69 | return logits, mask |
70 | 70 |
|
71 | 71 |
|
class HubertSoft(Hubert):
    """HuBERT-Soft content encoder: emits continuous (soft) speech units."""

    def __init__(self):
        super().__init__()

    def units(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract soft speech units from a waveform.

        Args:
            wav: input waveform tensor — presumably (batch, 1, samples); TODO confirm.

        Returns:
            Projected encoder features (soft units).
        """
        # Pad so each 400-sample analysis window (hop 320) is centered on its frame.
        halo = (400 - 320) // 2
        padded = F.pad(wav, (halo, halo))
        features, _ = self.encode(padded)
        return self.proj(features)
| 81 | + |
class HubertDiscrete(Hubert):
    """HuBERT-Discrete content encoder: emits discrete (cluster-id) speech units."""

    def __init__(self, kmeans):
        super().__init__()
        # Fitted clustering model exposing ``predict`` over 2-D feature arrays
        # (e.g. an sklearn KMeans); maps frame features to discrete unit ids.
        self.kmeans = kmeans

    def units(self, wav: torch.Tensor) -> torch.LongTensor:
        """Extract discrete speech units from a waveform.

        Args:
            wav: input waveform tensor — presumably (batch=1, 1, samples); TODO confirm.

        Returns:
            1-D ``LongTensor`` of cluster ids, one per encoder frame, on
            ``wav``'s device.
        """
        # Pad so each 400-sample analysis window (hop 320) is centered on its frame.
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        # layer=7 intermediate features are the clustering targets used here.
        x, _ = self.encode(wav, layer=7)
        # Squeeze ONLY the batch dim: a bare squeeze() would also drop the time
        # dim for a single-frame input, handing kmeans.predict a 1-D array and
        # failing; squeeze(0) is identical for all multi-frame inputs.
        x = self.kmeans.predict(x.squeeze(0).cpu().numpy())
        return torch.tensor(x, dtype=torch.long, device=wav.device)
| 93 | + |
72 | 94 | class FeatureExtractor(nn.Module): |
73 | 95 | def __init__(self): |
74 | 96 | super().__init__() |
@@ -204,43 +226,45 @@ def _compute_mask( |
204 | 226 | return mask |
205 | 227 |
|
206 | 228 |
|
207 | | -def _hubert( |
208 | | - name: str, |
209 | | - num_label_embeddings: int, |
210 | | - pretrained: bool = True, |
211 | | - progress: bool = True, |
212 | | -) -> Hubert: |
213 | | - hubert = Hubert(num_label_embeddings) |
214 | | - if pretrained: |
215 | | - checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress) |
216 | | - consume_prefix_in_state_dict_if_present(checkpoint, "module.") |
217 | | - hubert.load_state_dict(checkpoint) |
218 | | - hubert.eval() |
219 | | - return hubert |
220 | | - |
221 | | - |
def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertDiscrete:
    r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    # The discrete model needs its companion k-means codebook either way.
    model = HubertDiscrete(kmeans100(pretrained=pretrained, progress=progress))
    if not pretrained:
        return model
    state_dict = torch.hub.load_state_dict_from_url(
        URLS["hubert-discrete"], progress=progress
    )
    # Checkpoints saved from DataParallel carry a "module." prefix; strip it.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    model.load_state_dict(state_dict)
    model.eval()
    return model
232 | 248 |
|
233 | 249 |
|
def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    model = HubertSoft()
    if not pretrained:
        return model
    state_dict = torch.hub.load_state_dict_from_url(
        URLS["hubert-soft"], progress=progress
    )
    # Checkpoints saved from DataParallel carry a "module." prefix; strip it.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    model.load_state_dict(state_dict)
    model.eval()
    return model
244 | 268 |
|
245 | 269 |
|
246 | 270 | def _kmeans( |
|
0 commit comments