162 changes: 162 additions & 0 deletions fadtk/model_loader.py
@@ -463,6 +463,158 @@ def float32_to_int16(self, x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)


class W2V2Model(ModelLoader):
"""
W2V2 model from https://huggingface.co/facebook/wav2vec2-base-960h, https://huggingface.co/facebook/wav2vec2-large-960h

Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'large'], layer: int, limit_minutes=6):
model_dim = 768 if size == 'base' else 1024
model_identifier = f"w2v2-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
self.huggingface_id = f"facebook/wav2vec2-{size}-960h"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr

def load_model(self):
from transformers import AutoProcessor, Wav2Vec2Model

self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id)
self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
self.model.to(self.device)

def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
# Limit to specified minutes
if audio.shape[0] > self.limit:
log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
audio = audio[:self.limit]

inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
with torch.no_grad():
out = self.model(**inputs, output_hidden_states=True)
out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024]
out = out[self.layer] # [timeframes, 768 or 1024]

return out
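A minimal usage sketch for this wrapper (the same pattern applies to HuBERTModel and WavLMModel below). It assumes a mono float32 array already resampled to the model's 16 kHz rate, and that ModelLoader's base class sets self.device as used in load_model; _get_embedding is an internal hook, called directly here only for illustration:

```python
import numpy as np

# hypothetical 10-second clip, already resampled to 16 kHz mono
audio = np.random.randn(10 * 16000).astype(np.float32)

model = W2V2Model('base', layer=6)  # embeddings from transformer block 6
model.load_model()
emb = model._get_embedding(audio)   # tensor of shape [timeframes, 768]
```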


class HuBERTModel(ModelLoader):
"""
HuBERT model from https://huggingface.co/facebook/hubert-base-ls960, https://huggingface.co/facebook/hubert-large-ls960

Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'large'], layer: int, limit_minutes=6):
model_dim = 768 if size == 'base' else 1024
model_identifier = f"hubert-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
self.huggingface_id = f"facebook/hubert-{size}-ls960"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr

def load_model(self):
from transformers import AutoProcessor, HubertModel

self.model = HubertModel.from_pretrained(self.huggingface_id)
self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
self.model.to(self.device)

def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
# Limit to specified minutes
if audio.shape[0] > self.limit:
log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
audio = audio[:self.limit]

inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
with torch.no_grad():
out = self.model(**inputs, output_hidden_states=True)
out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024]
out = out[self.layer] # [timeframes, 768 or 1024]

return out
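One detail worth noting in the indexing above: with output_hidden_states=True, transformers returns the initial feature-projection output at index 0 followed by one entry per transformer block, so layer=1..12 (or 1..24) selects transformer blocks directly. A small sanity check, reusing the hypothetical audio array from the earlier sketch:

```python
import torch

model = HuBERTModel('base', layer=12)
model.load_model()

inputs = model.processor(audio, sampling_rate=model.sr, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.model(**inputs, output_hidden_states=True)

# num_hidden_layers transformer blocks plus the initial feature projection
assert len(out.hidden_states) == model.model.config.num_hidden_layers + 1
```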


class WavLMModel(ModelLoader):
"""
WavLM model from https://huggingface.co/microsoft/wavlm-base, https://huggingface.co/microsoft/wavlm-base-plus, https://huggingface.co/microsoft/wavlm-large

Please specify the model size ('base', 'base-plus', or 'large') and the layer to use (1-12 for 'base' or 'base-plus' and 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'base-plus', 'large'], layer: int, limit_minutes=6):
model_dim = 768 if size in ['base', 'base-plus'] else 1024
model_identifier = f"wavlm-{size}" + ("" if (layer == 12 and size in ['base', 'base-plus']) or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
self.huggingface_id = f"patrickvonplaten/wavlm-libri-clean-100h-{size}"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr

def load_model(self):
from transformers import AutoProcessor, WavLMModel

self.model = WavLMModel.from_pretrained(self.huggingface_id)
self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
self.model.to(self.device)

def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
# Limit to specified minutes
if audio.shape[0] > self.limit:
log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
audio = audio[:self.limit]

inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
with torch.no_grad():
out = self.model(**inputs, output_hidden_states=True)
out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024]
out = out[self.layer] # [timeframes, 768 or 1024]

return out


class WhisperModel(ModelLoader):
"""
Whisper model from https://huggingface.co/openai/whisper-base

Please specify the model size ('tiny', 'base', 'small', 'medium', or 'large').
"""
def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large']):
dimensions = {
'tiny': 384,
'base': 512,
'small': 768,
'medium': 1024,
'large': 1280
}
model_dim = dimensions.get(size)
model_identifier = f"whisper-{size}"

super().__init__(model_identifier, model_dim, 16000)
self.huggingface_id = f"openai/whisper-large"

def load_model(self):
from transformers import AutoFeatureExtractor
from transformers import WhisperModel

self.model = WhisperModel.from_pretrained(self.huggingface_id)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.huggingface_id)
self.model.to(self.device)

def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
inputs = self.feature_extractor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
input_features = inputs.input_features
decoder_input_ids = (torch.tensor([[1, 1]]) * self.model.config.decoder_start_token_id).to(self.device)
with torch.no_grad():
out = self.model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state # decoder output: [1, 2, model dim]
out = out.squeeze() # [2, 384 or 512 or 768 or 1024 or 1280]

return out
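Because Whisper is an encoder-decoder, its forward pass needs decoder inputs; the code above feeds two copies of the decoder start token and takes the decoder's last hidden state, so each clip yields a two-frame embedding. A short sketch, assuming the same hypothetical audio array as above:

```python
model = WhisperModel('base')
model.load_model()
emb = model._get_embedding(audio)  # shape [2, 512] for the 'base' size
```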



def get_all_models() -> list[ModelLoader]:
ms = [
CLAPModel('2023'),
@@ -472,6 +624,16 @@ def get_all_models() -> list[ModelLoader]:
EncodecEmbModel('24k'), EncodecEmbModel('48k'),
# DACModel(),
# CdpamModel('acoustic'), CdpamModel('content'),
*(W2V2Model('base', layer=v) for v in range(1, 13)),
*(W2V2Model('large', layer=v) for v in range(1, 25)),
*(HuBERTModel('base', layer=v) for v in range(1, 13)),
*(HuBERTModel('large', layer=v) for v in range(1, 25)),
*(WavLMModel('base', layer=v) for v in range(1, 13)),
*(WavLMModel('base-plus', layer=v) for v in range(1, 13)),
*(WavLMModel('large', layer=v) for v in range(1, 25)),
WhisperModel('tiny'), WhisperModel('small'),
WhisperModel('base'), WhisperModel('medium'),
WhisperModel('large'),
]
if importlib.util.find_spec("dac") is not None:
ms.append(DACModel())