Skip to content

Commit 55eecbf

Browse files
oyilmaz-nvidia and pzelasko
authored and committed
Whisper model support in Lite (#11464)
* Data loader * another dataset * preprocessed audio dataset Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * seq2seq support Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * Apply isort and black reformatting Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> * remove any update Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * Apply isort and black reformatting Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> * fixing validation errors Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * Modify training step and tokenizer to achieve correct Whisper training Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com> * Apply isort and black reformatting Signed-off-by: pzelasko <pzelasko@users.noreply.github.com> * Moved files into speechlm collection Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * revert changes Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * create recipes folder Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * generalize forward Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * Apply isort and black reformatting Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> * example update Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * Apply isort and black reformatting Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> * address codeql reviews Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> * remove examples Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> --------- Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com> Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com> Signed-off-by: pzelasko <pzelasko@users.noreply.github.com> Co-authored-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com> Co-authored-by: Piotr Żelasko <pzelasko@nvidia.com> Co-authored-by: pzelasko <pzelasko@users.noreply.github.com> Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
1 parent c563fce commit 55eecbf

File tree

9 files changed

+826
-0
lines changed

9 files changed

+826
-0
lines changed

examples/speechlm/sft/hf.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import fiddle as fdl
16+
import torch
17+
from lhotse.dataset.collation import collate_matrices, collate_vectors
18+
from omegaconf import OmegaConf
19+
20+
from nemo import lightning as nl
21+
from nemo.collections import speechlm
22+
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
23+
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
24+
from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
25+
26+
torch.set_float32_matmul_precision("medium")
27+
28+
29+
class LhotseHfNeMoDataset(torch.utils.data.Dataset):
30+
def __init__(self, processor, tokenizer, decoder_mask_fill=-100):
31+
super().__init__()
32+
self.processor = processor
33+
self.tokenizer = tokenizer
34+
self.decoder_mask_fill = decoder_mask_fill
35+
36+
def __getitem__(self, cuts):
37+
features = []
38+
for cut in cuts:
39+
audio = cut.load_audio()
40+
features.append(
41+
self.processor(
42+
audio,
43+
sampling_rate=cut.sampling_rate,
44+
return_tensors="pt",
45+
text=cut.supervisions[0].text,
46+
)
47+
)
48+
49+
input_features = collate_matrices(tensors=[f["input_features"].squeeze(0) for f in features])
50+
labels = collate_vectors(tensors=[c.supervisions[0].tokens for c in cuts])
51+
decoder_input_ids = labels[:, :-1]
52+
decoder_input_ids = decoder_input_ids.masked_fill(
53+
decoder_input_ids == self.decoder_mask_fill, self.tokenizer.pad_id
54+
)
55+
labels = labels[:, 1:].reshape(-1)
56+
57+
return {
58+
"input_features": input_features,
59+
"labels": labels,
60+
"decoder_input_ids": decoder_input_ids,
61+
}
62+
63+
64+
if __name__ == '__main__':
65+
import argparse
66+
67+
parser = argparse.ArgumentParser()
68+
69+
# Models can be one of the supported ones by AutoModelForSpeechSeq2Seq such as
70+
# openai/whisper-large-v3 and facebook/s2t-small-librispeech-asr
71+
parser.add_argument('--model', default='openai/whisper-large-v3')
72+
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
73+
parser.add_argument('--devices', default=1)
74+
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
75+
parser.add_argument('--max-steps', type=int, default=100)
76+
parser.add_argument('--model-save-path', type=str, default=None)
77+
args = parser.parse_args()
78+
79+
model = HFAutoModelForSpeechSeq2Seq(model_name=args.model)
80+
model = model.to(torch.float)
81+
processor = model.processor
82+
tokenizer = AutoTokenizer(args.model, include_special_tokens=True)
83+
84+
config = OmegaConf.create(
85+
{
86+
"cuts_path": "/opt/checkpoints/lhotse/libri/libri-train-5.jsonl.gz",
87+
"sample_rate": 16000,
88+
"shuffle": True,
89+
"num_workers": 2,
90+
"batch_size": 4,
91+
"shuffle_buffer_size": 100,
92+
}
93+
)
94+
95+
train_dataloader = get_lhotse_dataloader_from_config(
96+
config,
97+
global_rank=0,
98+
world_size=1,
99+
dataset=LhotseHfNeMoDataset(
100+
processor=processor,
101+
tokenizer=tokenizer,
102+
),
103+
tokenizer=tokenizer,
104+
)
105+
106+
speechlm.api.finetune(
107+
model=model,
108+
data=train_dataloader,
109+
trainer=nl.Trainer(
110+
devices=args.devices,
111+
max_steps=args.max_steps,
112+
accelerator=args.accelerator,
113+
strategy=args.strategy,
114+
precision="bf16-mixed",
115+
log_every_n_steps=1,
116+
limit_val_batches=0.0,
117+
num_sanity_val_steps=0,
118+
accumulate_grad_batches=10,
119+
gradient_clip_val=0.5,
120+
use_distributed_sampler=False,
121+
callbacks=[],
122+
logger=None,
123+
),
124+
optim=fdl.build(speechlm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
125+
log=None,
126+
)
127+
128+
if args.model_save_path is not None:
129+
model.save_pretrained(args.model_save_path)

nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
additional_special_tokens: Optional[List] = [],
4747
use_fast: Optional[bool] = False,
4848
trust_remote_code: Optional[bool] = False,
49+
include_special_tokens: bool = False,
4950
):
5051
"""
5152
Args:
@@ -63,6 +64,7 @@ def __init__(
6364
unk_token: token to use for unknown tokens
6465
additional_special_tokens: list of other tokens beside standard special tokens (bos, eos, pad, etc.). For example, sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.)
6566
use_fast: whether to use fast HuggingFace tokenizer
67+
include_special_tokens: when True, converting text to ids will include special tokens / prompt tokens (if any), yielding self.tokenizer(text).input_ids
6668
"""
6769
try:
6870
# this logic deals with different huggingface tokenizers having different positional args
@@ -92,6 +94,7 @@ def __init__(
9294
f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}'
9395
)
9496

97+
self.include_special_tokens = include_special_tokens
9598
self.original_vocab_size = len(self.tokenizer)
9699
special_tokens_dict = {}
97100

@@ -220,6 +223,8 @@ def ids_to_tokens(self, ids):
220223
return tokens
221224

222225
def text_to_ids(self, text):
226+
if self.include_special_tokens:
227+
return self.tokenizer(text).input_ids
223228
tokens = self.text_to_tokens(text)
224229
ids = self.tokens_to_ids(tokens)
225230
return ids
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
16+
from nemo.utils import logging
17+
18+
__all__ = [
19+
"HFAutoModelForSpeechSeq2Seq",
20+
]
21+
22+
try:
23+
import nemo_run as run
24+
25+
from nemo.collections.llm.recipes import adam
26+
from nemo.collections.speechlm.api import finetune, generate, pretrain, train, validate
27+
28+
__all__.extend(
29+
[
30+
"train",
31+
"pretrain",
32+
"validate",
33+
"finetune",
34+
"generate",
35+
]
36+
)
37+
except ImportError as error:
38+
logging.warning(f"Failed to import nemo.collections.speechlm.[api, recipes]: {error}")

0 commit comments

Comments (0)