Skip to content

Commit d032bfa

Browse files
committed
add support for large-v3 and distil-whisper
1 parent 88eec90 commit d032bfa

File tree

6 files changed

+318
-14
lines changed

6 files changed

+318
-14
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# WhisperS2T ⚡
22

3-
WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to HuggingFace Pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.
3+
WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to HuggingFace Pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.
44

55
[**Whisper**](https://github.com/openai/whisper) is a general-purpose speech recognition model developed by OpenAI. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
66

@@ -10,7 +10,8 @@ Stay tuned for a technical report comparing WhisperS2T against other whisper pip
1010

1111
![A30 Benchmark](files/benchmarks.png)
1212

13-
**NOTE:** I ran all the benchmarks with `without_timestamps` parameter as `True`. Setting `without_timestamps` as `False` may improve the WER of the HuggingFace pipeline at the expense of additional inference time.
13+
**NOTE:** I conducted all the benchmarks using the `without_timestamps` parameter set as `True`. Adjusting this parameter to `False` may enhance the Word Error Rate (WER) of the HuggingFace pipeline but at the expense of increased inference time. Notably, the improvements in inference speed were achieved solely through a **superior pipeline design**, without any specific optimization made to the backend inference engines (such as CTranslate2, FlashAttention2, etc.). For instance, WhisperS2T (utilizing FlashAttention2) demonstrates significantly superior inference speed compared to the HuggingFace pipeline (also using FlashAttention2), despite both leveraging the same inference engine—HuggingFace whisper model with FlashAttention2. Additionally, there is a noticeable difference in the WER as well.
14+
1415

1516
## Features
1617

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import argparse
2+
from rich.console import Console
3+
console = Console()
4+
5+
def parse_arguments():
    """Parse benchmark CLI flags.

    Boolean-ish options are plain "yes"/"no" strings; the caller is
    responsible for decoding them into real booleans.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_path', default="", type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    # yes/no string switches share one declaration pattern.
    for flag, default in (('--flash_attention', "yes"),
                          ('--better_transformer', "no"),
                          ('--eval_mp3', "no"),
                          ('--eval_multilingual', "no")):
        parser.add_argument(flag, default=default, type=str)
    return parser.parse_args()
15+
16+
17+
def run(repo_path, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
    """Benchmark the HuggingFace ``distil-whisper/distil-large-v2`` pipeline.

    Transcribes KINCAID46 WAV (always), KINCAID46 MP3 and the
    MultiLingualLongform set (optionally), writing predicted transcripts and
    wall-clock timings as TSV files under ``{repo_path}/results/``.

    Args:
        repo_path: Root of the benchmark repo (contains ``data/`` and ``results/``).
        flash_attention: Load the model with FlashAttention2. Takes precedence
            over ``better_transformer``.
        better_transformer: Convert the model via BetterTransformer when
            ``flash_attention`` is off.
        batch_size: Pipeline batch size (also encoded in the results dir name).
        eval_mp3: Also benchmark the MP3 copies of KINCAID46.
        eval_multilingual: Also benchmark the multilingual long-form set.
    """
    import torch
    import time, os
    import pandas as pd
    from transformers import pipeline

    # Load Model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    model_kwargs = {
        "use_safetensors": True,
        "low_cpu_mem_usage": True
    }

    results_dir = f"{repo_path}/results/HuggingFaceDistilWhisper-bs_{batch_size}"

    if flash_attention:
        results_dir = f"{results_dir}-fa"
        model_kwargs["use_flash_attention_2"] = True

    # Fixed model id (was a placeholder-less f-string).
    ASR = pipeline("automatic-speech-recognition",
                   "distil-whisper/distil-large-v2",
                   num_workers=1,
                   torch_dtype=torch.float16,
                   device="cuda",
                   model_kwargs=model_kwargs)

    # BetterTransformer only applies when FlashAttention2 is not in use.
    if (not flash_attention) and better_transformer:
        ASR.model = ASR.model.to_bettertransformer()
        results_dir = f"{results_dir}-bt"

    os.makedirs(results_dir, exist_ok=True)

    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

    # Warm-up pass so CUDA init / first-batch compilation doesn't pollute timing.
    with console.status("Warming"):
        st = time.time()
        _ = ASR(files,
                batch_size=batch_size,
                chunk_length_s=15,
                generate_kwargs={'num_beams': 1, 'language': 'en'},
                return_timestamps=False)

    print(f"[Warming Time]: {time.time()-st}")

    with console.status("KINCAID WAV"):
        st = time.time()
        outputs = ASR(files,
                      batch_size=batch_size,
                      chunk_length_s=15,
                      generate_kwargs={'num_beams': 1, 'language': 'en'},
                      return_timestamps=False)
        time_kincaid46_wav = time.time()-st

    print(f"[KINCAID WAV Time]: {time_kincaid46_wav}")

    data['pred_text'] = [item['text'].strip() for item in outputs]
    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)

    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_mp3:
        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

        with console.status("KINCAID MP3"):
            st = time.time()
            outputs = ASR(files,
                          batch_size=batch_size,
                          chunk_length_s=30,
                          generate_kwargs={'num_beams': 1, 'language': 'en'},
                          return_timestamps=False)
            time_kincaid46_mp3 = time.time()-st

        print(f"[KINCAID MP3 Time]: {time_kincaid46_mp3}")

        data['pred_text'] = [item['text'].strip() for item in outputs]
        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
    else:
        time_kincaid46_mp3 = 0.0

    # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_multilingual:
        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
        lang_codes = data['lang_code'].to_list()

        with console.status("MultiLingualLongform"):
            st = time.time()

            # The pipeline accepts one language per call, so transcribe
            # consecutive same-language runs of files as single batches.
            curr_files = [files[0]]
            curr_lang = lang_codes[0]
            outputs = []
            for fn, lang in zip(files[1:], lang_codes[1:]):
                if lang != curr_lang:
                    _outputs = ASR(curr_files,
                                   batch_size=batch_size,
                                   chunk_length_s=30,
                                   generate_kwargs={'num_beams': 1, 'language': curr_lang},
                                   return_timestamps=False)
                    outputs.extend(_outputs)

                    curr_files = [fn]
                    curr_lang = lang
                else:
                    curr_files.append(fn)

            # Flush the final language group.
            _outputs = ASR(curr_files,
                           batch_size=batch_size,
                           chunk_length_s=30,
                           generate_kwargs={'num_beams': 1, 'language': curr_lang},
                           return_timestamps=False)
            outputs.extend(_outputs)

            time_multilingual = time.time()-st

        print(f"[MultiLingualLongform Time]: {time_multilingual}")

        data['pred_text'] = [item['text'].strip() for item in outputs]
        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
    else:
        time_multilingual = 0.0

    # Persist per-dataset wall-clock timings.
    infer_time = pd.DataFrame(
        [["KINCAID46 WAV", time_kincaid46_wav],
         ["KINCAID46 MP3", time_kincaid46_mp3],
         ["MultiLingualLongform", time_multilingual]],
        columns=["Dataset", "Time"])
    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)
151+
152+
if __name__ == '__main__':
    args = parse_arguments()
    # argparse flags arrive as "yes"/"no" strings; decode to booleans directly
    # (the `True if x == "yes" else False` form is a redundant ternary).
    run(args.repo_path,
        flash_attention=(args.flash_attention == "yes"),
        better_transformer=(args.better_transformer == "yes"),
        batch_size=args.batch_size,
        eval_mp3=(args.eval_mp3 == "yes"),
        eval_multilingual=(args.eval_multilingual == "yes"))
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import argparse
2+
3+
def parse_arguments():
    """Parse benchmark CLI flags for the WhisperS2T runner.

    "yes"/"no" string switches are decoded into booleans by the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_path', default="", type=str)
    parser.add_argument('--backend', default="HuggingFace", type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    # yes/no string switches share one declaration pattern.
    for flag, default in (('--flash_attention', "yes"),
                          ('--better_transformer', "no"),
                          ('--eval_mp3', "no"),
                          ('--eval_multilingual', "no")):
        parser.add_argument(flag, default=default, type=str)
    return parser.parse_args()
14+
15+
def run(repo_path, backend, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
    """Benchmark WhisperS2T (distil-large-v2) with the requested backend.

    Transcribes KINCAID46 WAV (always), KINCAID46 MP3 and the
    MultiLingualLongform set (optionally), writing predicted transcripts and
    wall-clock timings as TSV files under ``{repo_path}/results/``.

    Args:
        repo_path: Root of the benchmark repo; appended to ``sys.path`` so the
            local ``whisper_s2t`` package is importable.
        backend: WhisperS2T backend name (e.g. "HuggingFace", "CTranslate2").
        flash_attention: HF-backend only — enable FlashAttention2.
        better_transformer: HF-backend only — enable BetterTransformer.
        batch_size: Transcription batch size (also encoded in the results dir name).
        eval_mp3: Also benchmark the MP3 copies of KINCAID46.
        eval_multilingual: Also benchmark the multilingual long-form set.
    """
    import sys, time, os

    if len(repo_path):
        sys.path.append(repo_path)

    import whisper_s2t
    import pandas as pd

    if backend.lower() in ["huggingface", "hf"]:
        asr_options = {
            "use_flash_attention": flash_attention,
            "use_better_transformer": better_transformer
        }

        # Results dir suffix mirrors the attention optimization in use.
        if flash_attention:
            suffix = "-fa"
        elif better_transformer:
            suffix = "-bt"
        else:
            suffix = ""
        results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}{suffix}"
    else:
        asr_options = {}
        results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}"

    os.makedirs(results_dir, exist_ok=True)

    model = whisper_s2t.load_model("distil-large-v2", backend=backend, asr_options=asr_options)

    def _transcribe(file_list, lang_codes):
        # One batched VAD transcription pass: every file uses task 'transcribe'
        # and no initial prompt (deduplicates the boilerplate list-building).
        return model.transcribe_with_vad(file_list,
                                         lang_codes=lang_codes,
                                         tasks=len(file_list)*['transcribe'],
                                         initial_prompts=len(file_list)*[None],
                                         batch_size=batch_size)

    def _join_segments(out):
        # Collapse per-segment dicts into one transcript string per file.
        return [" ".join(seg['text'] for seg in transcript).strip() for transcript in out]

    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

    # Warm-up pass so model load / backend init doesn't pollute the timing.
    _ = _transcribe(files, len(files)*['en'])

    st = time.time()
    out = _transcribe(files, len(files)*['en'])
    time_kincaid46_wav = time.time()-st

    data['pred_text'] = _join_segments(out)
    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)

    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_mp3:
        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

        st = time.time()
        out = _transcribe(files, len(files)*['en'])
        time_kincaid46_mp3 = time.time()-st

        data['pred_text'] = _join_segments(out)
        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
    else:
        time_kincaid46_mp3 = 0.0

    # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_multilingual:
        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

        st = time.time()
        # Per-file language codes come from the manifest here.
        out = _transcribe(files, data['lang_code'].to_list())
        time_multilingual = time.time()-st

        data['pred_text'] = _join_segments(out)
        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
    else:
        time_multilingual = 0.0

    # Persist per-dataset wall-clock timings.
    infer_time = pd.DataFrame(
        [["KINCAID46 WAV", time_kincaid46_wav],
         ["KINCAID46 MP3", time_kincaid46_mp3],
         ["MultiLingualLongform", time_multilingual]],
        columns=["Dataset", "Time"])
    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)
120+
121+
122+
if __name__ == '__main__':
    args = parse_arguments()
    # argparse flags arrive as "yes"/"no" strings; decode to booleans directly
    # (the `True if x == "yes" else False` form is a redundant ternary).
    run(args.repo_path,
        args.backend,
        flash_attention=(args.flash_attention == "yes"),
        better_transformer=(args.better_transformer == "yes"),
        batch_size=args.batch_size,
        eval_mp3=(args.eval_mp3 == "yes"),
        eval_multilingual=(args.eval_multilingual == "yes"))

whisper_s2t/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,25 @@ def load_model(model_identifier="large-v2",
1212
backend='CTranslate2',
1313
**model_kwargs):
1414

15+
if model_identifier in ['large-v3']:
16+
model_kwargs['n_mels'] = 128
17+
elif (model_identifier in ['distil-large-v2']) and (backend.lower() not in ["huggingface", "hf"]):
18+
print(f"Switching backend to HuggingFace. Distill whisper is only supported with HuggingFace backend.")
19+
backend = "huggingface"
20+
21+
model_kwargs['max_speech_len'] = 15.0
22+
model_kwargs['max_text_token_len'] = 128
23+
1524
if backend.lower() in ["ctranslate2", "ct2"]:
1625
from .backends.ctranslate2.model import WhisperModelCT2 as WhisperModel
1726

1827
elif backend.lower() in ["huggingface", "hf"]:
1928
from .backends.huggingface.model import WhisperModelHF as WhisperModel
20-
model_identifier = f"openai/whisper-{model_identifier}"
29+
30+
if 'distil' in model_identifier:
31+
model_identifier = f"distil-whisper/{model_identifier}"
32+
else:
33+
model_identifier = f"openai/whisper-{model_identifier}"
2134

2235
elif backend.lower() in ["openai", "oai"]:
2336
from .backends.openai.model import WhisperModelOAI as WhisperModel

whisper_s2t/backends/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self,
5555

5656
self.vad_model = vad_model
5757
self.speech_segmenter_options = speech_segmenter_options
58+
self.speech_segmenter_options['max_seg_len'] = self.max_speech_len
5859

5960
# Tokenizer
6061
if tokenizer is None:

whisper_s2t/backends/ctranslate2/hf_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414

1515

1616
_MODELS = {
17-
"tiny.en": "guillaumekln/faster-whisper-tiny.en",
18-
"tiny": "guillaumekln/faster-whisper-tiny",
19-
"base.en": "guillaumekln/faster-whisper-base.en",
20-
"base": "guillaumekln/faster-whisper-base",
21-
"small.en": "guillaumekln/faster-whisper-small.en",
22-
"small": "guillaumekln/faster-whisper-small",
23-
"medium.en": "guillaumekln/faster-whisper-medium.en",
24-
"medium": "guillaumekln/faster-whisper-medium",
25-
"large-v1": "guillaumekln/faster-whisper-large-v1",
26-
"large-v2": "guillaumekln/faster-whisper-large-v2",
27-
"large": "guillaumekln/faster-whisper-large-v2",
17+
"tiny.en": "Systran/faster-whisper-tiny.en",
18+
"tiny": "Systran/faster-whisper-tiny",
19+
"base.en": "Systran/faster-whisper-base.en",
20+
"base": "Systran/faster-whisper-base",
21+
"small.en": "Systran/faster-whisper-small.en",
22+
"small": "Systran/faster-whisper-small",
23+
"medium.en": "Systran/faster-whisper-medium.en",
24+
"medium": "Systran/faster-whisper-medium",
25+
"large-v1": "Systran/faster-whisper-large-v1",
26+
"large-v2": "Systran/faster-whisper-large-v2",
27+
"large-v3": "Systran/faster-whisper-large-v3",
28+
"large": "Systran/faster-whisper-large-v3",
2829
}
2930

3031

0 commit comments

Comments
 (0)