
Commit edce489

Initial commit

Author: Benjamin Van Niekerk
1 parent a4674a0 commit edce489

5 files changed

Lines changed: 58 additions & 96 deletions

File tree: README.md · discretize.py (deleted) · encode.py · hubert/__init__.py · hubert/model.py

README.md

Lines changed: 4 additions & 20 deletions
@@ -56,10 +56,10 @@ units = kmeans.predict(x.squeeze().cpu().numpy())
 
 **Step 1**: Download and extract the [LibriSpeech](https://www.openslr.org/12) corpus.
 
-**Step 2**: Encode LibriSpeech using the HuBERT-Discrete model and `encode.py` script (setting `--layer=7`):
+**Step 2**: Encode LibriSpeech using the HuBERT-Discrete model and `encode.py` script:
 
 ```
-usage: encode.py [-h] [--extension EXTENSION] [--model {hubert_soft,hubert_discrete}] [--layer LAYER] in-dir out-dir
+usage: encode.py [-h] [--extension EXTENSION] [--model {hubert_soft,hubert_discrete}] in-dir out-dir
 
 Encode an audio dataset.
 
@@ -73,31 +73,15 @@ optional arguments:
                         extension of the audio files.
   --model {hubert_soft,hubert_discrete}
                         available models
-  --layer LAYER         the selected transformer layer (defaults to the last layer)
 ```
 
 for example:
 
 ```
-python encode.py path/to/LibriSpeech path/to/LibriSpeech/
+python encode.py path/to/LibriSpeech/wavs path/to/LibriSpeech/units --model hubert_discrete
 ```
 
-**Step 3**: Discretize the extracted features using the k-means checkpoint and `discretize.py` script:
-
-```
-usage: discretize.py [-h] in-dir out-dir
-
-Discretize HuBERT features.
-
-positional arguments:
-  in-dir       path to the dataset directory.
-  out-dir      path to the output directory.
-
-optional arguments:
-  -h, --help   show this help message and exit
-```
-
-**Step 5**: Train the HuBERT-Soft model using the `train.py` script:
+**Step 3**: Train the HuBERT-Soft model using the `train.py` script:
 
 ```
 usage: train.py [-h] [--resume RESUME] [--warmstart] [--mask] [--alpha ALPHA] dataset-dir checkpoint-dir
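
The README change reflects the API consolidation in this commit: padding, layer selection, and k-means prediction now all happen inside the model's `units` method, so the separate discretization step disappears. A minimal sketch of what `encode.py` now does per utterance (the input path is hypothetical; a CUDA device is assumed):

```
import torch
import torchaudio
from torchaudio.functional import resample

from hubert import hubert_discrete

hubert = hubert_discrete().cuda()

wav, sr = torchaudio.load("path/to/LibriSpeech/wavs/84-121123-0001.flac")  # hypothetical file
wav = resample(wav, sr, 16000)  # HuBERT expects 16 kHz audio
wav = wav.unsqueeze(0).cuda()   # (1, channels, samples)

with torch.inference_mode():
    units = hubert.units(wav)   # discrete unit ids; k-means runs inside the model
```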

discretize.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

encode.py

Lines changed: 2 additions & 13 deletions
@@ -5,7 +5,6 @@
 from tqdm import tqdm
 
 import torch
-import torch.nn.functional as F
 import torchaudio
 from torchaudio.functional import resample
 
@@ -22,17 +21,13 @@ def encode_dataset(args):
         wav, sr = torchaudio.load(in_path)
         wav = resample(wav, sr, 16000)
         wav = wav.unsqueeze(0).cuda()
-        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
 
-        # Extract hubert features from the args.layer transformer layer
         with torch.inference_mode():
-            x, _ = hubert.encode(wav, layer=args.layer)
-            if args.layer is None:
-                x = hubert.proj(x)
+            units = hubert.units(wav)
 
         out_path = args.out_dir / in_path.relative_to(args.in_dir)
         out_path.parent.mkdir(parents=True, exist_ok=True)
-        np.save(out_path.with_suffix(".npy"), x.squeeze(0).cpu().numpy())
+        np.save(out_path.with_suffix(".npy"), units.squeeze().cpu().numpy())
 
 
 if __name__ == "__main__":
@@ -61,11 +56,5 @@ def encode_dataset(args):
         choices=["hubert_soft", "hubert_discrete"],
         default="hubert_soft",
     )
-    parser.add_argument(
-        "--layer",
-        help="the selected transformer layer (defaults to the last layer)",
-        default=None,
-        type=int,
-    )
     args = parser.parse_args()
     encode_dataset(args)
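
As before, the script writes one `.npy` file per utterance, mirroring the input directory tree. A quick sanity check of the output (the path is illustrative):

```
import numpy as np

# hubert_discrete yields a 1-D array of integer cluster ids;
# hubert_soft yields a 2-D array of float soft units (frames x features).
units = np.load("path/to/LibriSpeech/units/84-121123-0001.npy")
print(units.shape, units.dtype)
```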

hubert/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -1 +1,8 @@
-from .model import Hubert, hubert_discrete, hubert_soft, kmeans100
+from .model import (
+    Hubert,
+    HubertDiscrete,
+    HubertSoft,
+    hubert_discrete,
+    hubert_soft,
+    kmeans100,
+)

hubert/model.py

Lines changed: 44 additions & 20 deletions
@@ -17,7 +17,7 @@
 
 
 class Hubert(nn.Module):
-    def __init__(self, num_label_embeddings: int = 100, mask=True):
+    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
         super().__init__()
         self._mask = mask
         self.feature_extractor = FeatureExtractor()
@@ -69,6 +69,28 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return logits, mask
 
 
+class HubertSoft(Hubert):
+    def __init__(self):
+        super().__init__()
+
+    def units(self, wav: torch.Tensor) -> torch.Tensor:
+        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
+        x, _ = self.encode(wav)
+        return self.proj(x)
+
+
+class HubertDiscrete(Hubert):
+    def __init__(self, kmeans):
+        super().__init__(504)
+        self.kmeans = kmeans
+
+    def units(self, wav: torch.Tensor) -> torch.LongTensor:
+        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
+        x, _ = self.encode(wav, layer=7)
+        x = self.kmeans.predict(x.squeeze().cpu().numpy())
+        return torch.tensor(x, dtype=torch.long, device=wav.device)
+
+
 class FeatureExtractor(nn.Module):
     def __init__(self):
         super().__init__()
@@ -204,43 +226,45 @@ def _compute_mask(
     return mask
 
 
-def _hubert(
-    name: str,
-    num_label_embeddings: int,
-    pretrained: bool = True,
-    progress: bool = True,
-) -> Hubert:
-    hubert = Hubert(num_label_embeddings)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
-        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
-        hubert.load_state_dict(checkpoint)
-        hubert.eval()
-    return hubert
-
-
 def hubert_discrete(
     pretrained: bool = True,
     progress: bool = True,
-) -> Hubert:
+) -> HubertDiscrete:
     r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
     Args:
         pretrained (bool): load pretrained weights into the model
         progress (bool): show progress bar when downloading model
     """
-    return _hubert("hubert-discrete", 504, pretrained, progress)
+    kmeans = kmeans100(pretrained=pretrained, progress=progress)
+    hubert = HubertDiscrete(kmeans)
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            URLS["hubert-discrete"], progress=progress
+        )
+        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
+        hubert.load_state_dict(checkpoint)
+        hubert.eval()
+    return hubert
 
 
 def hubert_soft(
     pretrained: bool = True,
     progress: bool = True,
-) -> Hubert:
+) -> HubertSoft:
     r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
     Args:
         pretrained (bool): load pretrained weights into the model
         progress (bool): show progress bar when downloading model
     """
-    return _hubert("hubert-soft", 100, pretrained, progress)
+    hubert = HubertSoft()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            URLS["hubert-soft"], progress=progress
+        )
+        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
+        hubert.load_state_dict(checkpoint)
+        hubert.eval()
+    return hubert
 
 
 def _kmeans(
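
Both checkpoints now share a single `units` interface: `HubertSoft.units` returns the projected features from the final transformer layer, while `HubertDiscrete.units` maps layer-7 features to k-means cluster ids. The `(400 - 320) // 2` padding compensates for the gap between the feature extractor's 400-sample receptive field and its 320-sample hop, keeping frames centred on the waveform. A short usage sketch of the new classes (random noise stands in for real 16 kHz audio; a CUDA device is assumed):

```
import torch

from hubert import hubert_discrete, hubert_soft

soft = hubert_soft().cuda()
discrete = hubert_discrete().cuda()

wav = torch.randn(1, 1, 16000).cuda()  # stand-in for 1 s of 16 kHz audio

with torch.inference_mode():
    soft_units = soft.units(wav)  # float soft units from the projection head
    ids = discrete.units(wav)     # long tensor of layer-7 k-means ids
```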
