Skip to content

Commit 5dc3d79

Browse files
committed
models : add initial version of convert-solero-vad-to-ggml.py
wip
1 parent e6234cd commit 5dc3d79

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
import struct
3+
import argparse
4+
import torch
5+
import numpy as np
6+
from silero_vad import load_silero_vad, __version__ as silero_version
7+
8+
def convert_silero_vad(output_path, use_f16=True, sample_rate=16000):
9+
model = load_silero_vad()
10+
11+
state_dict = model.state_dict()
12+
13+
if sample_rate == 16000:
14+
model_prefix = "_model"
15+
input_channels = 129
16+
sr_suffix = "16k"
17+
elif sample_rate == 8000:
18+
model_prefix = "_model_8k"
19+
input_channels = 65
20+
sr_suffix = "8k"
21+
else:
22+
raise ValueError(f"Unsupported sample rate: {sample_rate}")
23+
24+
base, ext = os.path.splitext(output_path)
25+
output_file = f"{base}-v{silero_version}_{sr_suffix}-ggml{ext}"
26+
27+
print(f"Converting {sample_rate//1000}kHz model")
28+
print(f"Saving GGML Silero-VAD model to {output_file}")
29+
30+
fout = open(output_file, "wb")
31+
32+
# Write magic and version
33+
fout.write(struct.pack("i", 0x67676d6c)) # "ggml" in hex
34+
fout.write(struct.pack("i", 1)) # Version
35+
36+
# Define and write the model architecture values
37+
fout.write(struct.pack("i", 1 if use_f16 else 0)) # Use f16 flag
38+
fout.write(struct.pack("i", sample_rate)) # Sample rate
39+
40+
# Write dimensions for model
41+
n_encoder_layers = 4
42+
fout.write(struct.pack("i", n_encoder_layers))
43+
44+
# Write encoder dimensions
45+
encoder_in_channels = [input_channels, 128, 64, 64]
46+
encoder_out_channels = [128, 64, 64, 128]
47+
kernel_size = 3
48+
49+
for i in range(n_encoder_layers):
50+
fout.write(struct.pack("i", encoder_in_channels[i]))
51+
fout.write(struct.pack("i", encoder_out_channels[i]))
52+
fout.write(struct.pack("i", kernel_size))
53+
54+
# Write LSTM dimensions
55+
lstm_input_size = 128
56+
lstm_hidden_size = 128
57+
fout.write(struct.pack("i", lstm_input_size))
58+
fout.write(struct.pack("i", lstm_hidden_size))
59+
60+
# Write final conv dimensions
61+
final_conv_in = 128
62+
final_conv_out = 1
63+
fout.write(struct.pack("i", final_conv_in))
64+
fout.write(struct.pack("i", final_conv_out))
65+
66+
# Helper function to write a tensor
67+
def write_tensor(name, tensor, f16=use_f16):
68+
print(f" Writing {name} with shape {tensor.shape}")
69+
70+
# Convert to numpy
71+
data = tensor.detach().cpu().numpy()
72+
73+
# Convert to float16 if requested (and tensor is float32)
74+
if f16 and tensor.dtype == torch.float32:
75+
data = data.astype(np.float16)
76+
77+
# Write tensor data
78+
data.tofile(fout)
79+
80+
print("Writing model weights:")
81+
82+
# 1. Encoder weights
83+
for i in range(n_encoder_layers):
84+
weight_key = f"{model_prefix}.encoder.{i}.reparam_conv.weight"
85+
bias_key = f"{model_prefix}.encoder.{i}.reparam_conv.bias"
86+
87+
# Write conv weights and biases
88+
write_tensor(weight_key, state_dict[weight_key])
89+
write_tensor(bias_key, state_dict[bias_key])
90+
91+
# 2. LSTM weights
92+
write_tensor("lstm_weight_ih", state_dict[f"{model_prefix}.decoder.rnn.weight_ih"])
93+
write_tensor("lstm_weight_hh", state_dict[f"{model_prefix}.decoder.rnn.weight_hh"])
94+
write_tensor("lstm_bias_ih", state_dict[f"{model_prefix}.decoder.rnn.bias_ih"])
95+
write_tensor("lstm_bias_hh", state_dict[f"{model_prefix}.decoder.rnn.bias_hh"])
96+
97+
# 3. Final conv layer
98+
write_tensor("final_conv_weight", state_dict[f"{model_prefix}.decoder.decoder.2.weight"])
99+
write_tensor("final_conv_bias", state_dict[f"{model_prefix}.decoder.decoder.2.bias"])
100+
101+
fout.close()
102+
print(f"Done! {sample_rate//1000}kHz model has been converted to GGML format: {output_file}")
103+
104+
if __name__ == "__main__":
105+
parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
106+
parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
107+
parser.add_argument("--use-f16", action="store_true", help="Use float16 precision")
108+
parser.add_argument("--sample-rate", type=int, choices=[8000, 16000], default=16000,
109+
help="Sample rate: 8000 or 16000")
110+
111+
args = parser.parse_args()
112+
convert_silero_vad(args.output, args.use_f16, args.sample_rate)

0 commit comments

Comments
 (0)