Update model_loader.py #7
Conversation
Add wav2vec 2.0 [base, large], HuBERT [base, large], WavLM [base, base+, large], and Whisper [base] so that fadtk can be used for speech analysis.
@microsoft-github-policy-service agree
Thank you so much for adding support for more embedding models! I will review the code changes in a moment.
fadtk/model_loader.py (Outdated)
class W2V2baseModel(ModelLoader):
    """
    W2V2base model from https://huggingface.co/facebook/wav2vec2-base-960h
    Please specify the layer to use (1-12).
    """
    def __init__(self, size='960h', layer=12, limit_minutes=6):
        super().__init__("w2v2base" + ("" if layer == 12 else f"-{layer}"), 768, 16000)
        self.huggingface_id = f"facebook/wav2vec2-base-{size}"
        self.layer = layer
        self.limit = limit_minutes * 60 * self.sr

    def load_model(self):
        from transformers import AutoProcessor, Wav2Vec2Model

        self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id)
        self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        # Truncate audio that exceeds the configured limit (default: 6 minutes)
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [13 layers, timeframes, 768]
            out = out[self.layer]  # [timeframes, 768]

        return out


class W2V2largeModel(ModelLoader):
    """
    W2V2large model from https://huggingface.co/facebook/wav2vec2-large-960h
    Please specify the layer to use (1-24).
    """
    def __init__(self, size='960h', layer=24, limit_minutes=6):
        super().__init__("w2v2large" + ("" if layer == 24 else f"-{layer}"), 1024, 16000)
        self.huggingface_id = f"facebook/wav2vec2-large-{size}"
        self.layer = layer
        self.limit = limit_minutes * 60 * self.sr

    def load_model(self):
        from transformers import AutoProcessor, Wav2Vec2Model

        self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id)
        self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        # Truncate audio that exceeds the configured limit (default: 6 minutes)
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [25 layers, timeframes, 1024]
            out = out[self.layer]  # [timeframes, 1024]

        return out
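As a quick smoke test for a loader like this, something along the following lines should work (a minimal sketch, not part of the PR; it assumes fadtk's ModelLoader base class exposes self.sr and self.device, as the constructor and methods above imply):

import numpy as np

model = W2V2baseModel(layer=12)
model.load_model()
audio = np.zeros(10 * model.sr, dtype=np.float32)  # 10 seconds of silence at 16 kHz
emb = model._get_embedding(audio)
print(emb.shape)  # expected: [timeframes, 768] for the base model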
I see that the code for W2V2 base and W2V2 large is mostly identical. Would it be better to reuse the duplicated parts through an abstraction, e.g. by defining a base class for each model family that contains the shared functions and extending from it (see the sketch below)?
This also applies to the base and large variants of HuBERT and WavLM.
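To illustrate the suggestion, one possible shape for such an abstraction (a hypothetical sketch; the names W2V2Model, variant, num_features, and num_layers are illustrative, not the code that was eventually merged):

class W2V2Model(ModelLoader):
    """
    Shared loader for the wav2vec 2.0 family; `variant` selects base or large.
    """
    def __init__(self, variant: str, num_features: int, num_layers: int,
                 size='960h', layer=None, limit_minutes=6):
        layer = num_layers if layer is None else layer
        super().__init__(f"w2v2{variant}" + ("" if layer == num_layers else f"-{layer}"),
                         num_features, 16000)
        self.huggingface_id = f"facebook/wav2vec2-{variant}-{size}"
        self.layer = layer
        self.limit = limit_minutes * 60 * self.sr

    def load_model(self):
        from transformers import AutoProcessor, Wav2Vec2Model
        self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id)
        self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        # Truncate audio that exceeds the configured limit
        if audio.shape[0] > self.limit:
            log.warning("Audio is too long. Truncating.")
            audio = audio[:self.limit]
        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [num_layers + 1, timeframes, num_features]
        return out[self.layer]  # [timeframes, num_features]

class W2V2baseModel(W2V2Model):
    def __init__(self, **kwargs):
        super().__init__("base", 768, 12, **kwargs)

class W2V2largeModel(W2V2Model):
    def __init__(self, **kwargs):
        super().__init__("large", 1024, 24, **kwargs)

The same base-class pattern would collapse the duplicated HuBERT and WavLM pairs as well.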
Updated!
I also integrated more Whisper model types. The most widely used speech models, i.e., wav2vec 2.0, HuBERT, WavLM, and Whisper, are now all included, which should be sufficient for the majority of speech tasks.
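For reference, a Whisper loader in the same style might look roughly like this (my sketch, not necessarily the merged code; the class name WhisperBaseModel and the 512-dimensional encoder width of whisper-base are assumptions):

class WhisperBaseModel(ModelLoader):
    """
    Hypothetical sketch of a Whisper encoder loader (openai/whisper-base).
    """
    def __init__(self):
        super().__init__("whisper-base", 512, 16000)
        self.huggingface_id = "openai/whisper-base"

    def load_model(self):
        from transformers import AutoFeatureExtractor, WhisperModel
        self.model = WhisperModel.from_pretrained(self.huggingface_id)
        self.processor = AutoFeatureExtractor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # Only the encoder is needed for embeddings; the decoder is unused.
            out = self.model.encoder(inputs.input_features)
        return out.last_hidden_state.squeeze()  # [timeframes, 512]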
Thanks! I just merged the pull request. By the way, are there any extra dependencies required for these models? From what I see in the code changes, the only conditional import you used is transformers.

Yep. No other dependencies required. Thanks for merging!