Skip to content

Commit 096afd6

Browse files
committed
vad : fix handling of stft tensor
Still trying to figure out what the issue with the incorrect probabilities are. I'll go through the conversion script later to clean it up as it contains leftover code from previous interations.
1 parent fbed604 commit 096afd6

File tree

2 files changed

+152
-155
lines changed

2 files changed

+152
-155
lines changed

models/convert-silero-vad-to-ggml.py

Lines changed: 146 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -23,159 +23,155 @@ def convert_silero_vad(output_path, print_tensors=True):
2323
output_file = f"{base}-v{silero_version}-ggml{ext}"
2424
print(f"Saving GGML Silero-VAD model to {output_file}")
2525

26-
# Create a lookup to identify which tensors might need special handling
27-
op_tensors = {
28-
'encoder': [],
29-
'decoder': [],
30-
'stft': []
31-
}
32-
33-
for key in cleaned_dict.keys():
34-
if 'encoder' in key:
35-
op_tensors['encoder'].append(key)
36-
elif 'decoder' in key:
37-
op_tensors['decoder'].append(key)
38-
elif 'stft' in key:
39-
op_tensors['stft'].append(key)
40-
41-
print("\nTensor groups for debugging:")
42-
for group, tensors in op_tensors.items():
43-
print(f"{group}: {len(tensors)} tensors")
44-
for t in tensors:
45-
print(f" - {t}: {cleaned_dict[t].shape} ({cleaned_dict[t].dtype})")
26+
print("\nTensor info for debugging:")
27+
for key, tensor in cleaned_dict.items():
28+
print(f" - {key}: {tensor.shape} ({tensor.dtype})")
4629
print()
4730

48-
fout = open(output_file, "wb")
49-
50-
# Write magic and version
51-
fout.write(struct.pack("i", 0x67676d6c)) # "ggml" in hex
52-
53-
# Write a flag to indicate we're preserving original tensor types
54-
fout.write(struct.pack("i", 2)) # 2 = preserve original type
55-
56-
n_encoder_layers = 4
57-
fout.write(struct.pack("i", n_encoder_layers))
58-
59-
# Write encoder dimensions
60-
input_channels = 129
61-
encoder_in_channels = [input_channels, 128, 64, 64]
62-
encoder_out_channels = [128, 64, 64, 128]
63-
kernel_size = 3
64-
65-
for i in range(n_encoder_layers):
66-
fout.write(struct.pack("i", encoder_in_channels[i]))
67-
fout.write(struct.pack("i", encoder_out_channels[i]))
68-
fout.write(struct.pack("i", kernel_size))
69-
70-
# Write LSTM dimensions
71-
lstm_input_size = 128
72-
lstm_hidden_size = 128
73-
fout.write(struct.pack("i", lstm_input_size))
74-
fout.write(struct.pack("i", lstm_hidden_size))
75-
76-
# Write final conv dimensions
77-
final_conv_in = 128
78-
final_conv_out = 1
79-
fout.write(struct.pack("i", final_conv_in))
80-
fout.write(struct.pack("i", final_conv_out))
81-
82-
print("Writing model weights:")
83-
84-
tensor_keys_to_write = []
85-
86-
for i in range(n_encoder_layers):
87-
weight_key = f"_model.encoder.{i}.reparam_conv.weight"
88-
bias_key = f"_model.encoder.{i}.reparam_conv.bias"
89-
if weight_key in cleaned_dict and bias_key in cleaned_dict:
90-
tensor_keys_to_write.append(weight_key)
91-
tensor_keys_to_write.append(bias_key)
92-
93-
lstm_keys = [
94-
"_model.decoder.rnn.weight_ih",
95-
"_model.decoder.rnn.weight_hh",
96-
"_model.decoder.rnn.bias_ih",
97-
"_model.decoder.rnn.bias_hh"
98-
]
99-
tensor_keys_to_write.extend([k for k in lstm_keys if k in cleaned_dict])
100-
101-
final_keys = [
102-
"_model.decoder.decoder.2.weight",
103-
"_model.decoder.decoder.2.bias"
104-
]
105-
tensor_keys_to_write.extend([k for k in final_keys if k in cleaned_dict])
106-
107-
stft_tensor = "_model.stft.forward_basis_buffer"
108-
tensor_keys_to_write.extend([stft_tensor])
109-
110-
for name in tensor_keys_to_write:
111-
if name not in cleaned_dict:
112-
print(f"Warning: Missing tensor {name}, skipping")
113-
continue
114-
115-
tensor = cleaned_dict[name]
116-
data = tensor.squeeze().numpy()
117-
print(f"Processing variable: {name} with shape: {data.shape}")
118-
119-
# Print values of the tensor (original values)
120-
if print_tensors:
121-
if name == "_model.stft.forward_basis_buffer":
122-
first_values = tensor.flatten()[:258].tolist()
123-
print(f" First 258 values for {name}:")
124-
for i, val in enumerate(first_values):
125-
print(f" [{i}]: {val}")
31+
with open(output_file, "wb") as fout:
32+
# Write magic and version
33+
fout.write(struct.pack("i", 0x67676d6c))
34+
35+
# Write model version - Try version 0 for simplicity
36+
fout.write(struct.pack("i", 0))
37+
38+
# Write model architecture parameters
39+
n_encoder_layers = 4
40+
fout.write(struct.pack("i", n_encoder_layers))
41+
42+
# Write encoder dimensions
43+
input_channels = 129
44+
encoder_in_channels = [input_channels, 128, 64, 64]
45+
encoder_out_channels = [128, 64, 64, 128]
46+
kernel_size = 3
47+
48+
for i in range(n_encoder_layers):
49+
fout.write(struct.pack("i", encoder_in_channels[i]))
50+
fout.write(struct.pack("i", encoder_out_channels[i]))
51+
fout.write(struct.pack("i", kernel_size))
52+
53+
# Write LSTM dimensions
54+
lstm_input_size = 128
55+
lstm_hidden_size = 128
56+
fout.write(struct.pack("i", lstm_input_size))
57+
fout.write(struct.pack("i", lstm_hidden_size))
58+
59+
# Write final conv dimensions
60+
final_conv_in = 128
61+
final_conv_out = 1
62+
fout.write(struct.pack("i", final_conv_in))
63+
fout.write(struct.pack("i", final_conv_out))
64+
65+
# Define tensor keys to write
66+
tensor_keys = []
67+
68+
# Encoder weights
69+
for i in range(n_encoder_layers):
70+
weight_key = f"_model.encoder.{i}.reparam_conv.weight"
71+
bias_key = f"_model.encoder.{i}.reparam_conv.bias"
72+
if weight_key in cleaned_dict and bias_key in cleaned_dict:
73+
tensor_keys.append(weight_key)
74+
tensor_keys.append(bias_key)
75+
76+
# LSTM weights
77+
lstm_keys = [
78+
"_model.decoder.rnn.weight_ih",
79+
"_model.decoder.rnn.weight_hh",
80+
"_model.decoder.rnn.bias_ih",
81+
"_model.decoder.rnn.bias_hh"
82+
]
83+
tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])
84+
85+
# Final conv weights
86+
final_keys = [
87+
"_model.decoder.decoder.2.weight",
88+
"_model.decoder.decoder.2.bias"
89+
]
90+
tensor_keys.extend([k for k in final_keys if k in cleaned_dict])
91+
92+
# STFT basis - add this last
93+
stft_tensor = "_model.stft.forward_basis_buffer"
94+
tensor_keys.append(stft_tensor)
95+
96+
print(f"Writing {len(tensor_keys)} tensors:")
97+
for key in tensor_keys:
98+
if key in cleaned_dict:
99+
print(f" - {key}: {cleaned_dict[key].shape}")
126100
else:
127-
first_values = tensor.flatten()[:10].tolist()
128-
print(f" First 10 values for {name}:")
129-
for i, val in enumerate(first_values):
130-
print(f" [{i}]: {val}")
131-
132-
if name.endswith(".reparam_conv.weight") and len(data.shape) == 3:
133-
print(f" Keeping original convolution weight shape: {data.shape}")
134-
135-
# Get original dtype
136-
orig_dtype = tensor.dtype
137-
print(f" Original tensor dtype: {orig_dtype}")
138-
139-
# Check if this is an encoder convolution weight that needs to be F16
140-
force_f16 = False
141-
if "encoder" in name and "weight" in name:
142-
print(f" This tensor will be forced to F16 for GGML im2col compatibility")
143-
force_f16 = True
144-
if "_model.decoder.decoder.2.weight" in name:
145-
print(f" This tensor will be forced to F16 for GGML im2col compatibility")
146-
force_f16 = True
147-
if "_model.stft.forward_basis_buffer" in name:
148-
print(f" This tensor will be forced to F16 for GGML im2col compatibility")
149-
force_f16 = True
150-
151-
# Set ftype based on the original dtype or force to F16 for certain tensors
152-
if force_f16:
153-
ftype = 1 # float16
154-
data = data.astype(np.float16)
155-
elif orig_dtype == torch.float16:
156-
ftype = 1 # float16
157-
else:
158-
ftype = 0 # float32
159-
160-
# Ensure data has the same type as the original tensor
161-
if ftype == 1 and not np.issubdtype(data.dtype, np.float16):
162-
data = data.astype(np.float16)
163-
164-
n_dims = len(data.shape)
165-
166-
# Write header
167-
str_bytes = name.encode('utf-8')
168-
fout.write(struct.pack("iii", n_dims, len(str_bytes), ftype))
169-
170-
for i in range(n_dims):
171-
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
172-
173-
fout.write(str_bytes)
174-
175-
data.tofile(fout)
176-
177-
fout.close()
178-
print(f"Done! Model has been converted to GGML format: {output_file}")
101+
print(f" - {key}: MISSING")
102+
103+
# Process each tensor
104+
for key in tensor_keys:
105+
if key not in cleaned_dict:
106+
print(f"Warning: Missing tensor {key}, skipping")
107+
continue
108+
109+
tensor = cleaned_dict[key]
110+
111+
# Special handling for STFT tensor
112+
if key == "_model.stft.forward_basis_buffer":
113+
# Get the original numpy array without squeezing
114+
data = tensor.detach().cpu().numpy()
115+
# Ensure it has the expected shape
116+
print(f"STFT tensor original shape: {data.shape}")
117+
n_dims = 3
118+
tensor_shape = [data.shape[0], data.shape[1], data.shape[2]]
119+
is_conv_weight = True
120+
else:
121+
# For other tensors, we can use standard processing
122+
data = tensor.detach().cpu().squeeze().numpy()
123+
tensor_shape = list(data.shape)
124+
125+
# Ensure we have at most 4 dimensions for GGML
126+
n_dims = min(len(tensor_shape), 4)
127+
128+
# Reverse dimensions for GGML
129+
tensor_shape = tensor_shape[:n_dims]
130+
tensor_shape.reverse()
131+
132+
# Check if this is a convolution weight tensor
133+
is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)
134+
135+
# Convert to float16 for convolution weights
136+
if is_conv_weight:
137+
data = data.astype(np.float16)
138+
ftype = 1 # float16
139+
else:
140+
ftype = 0 # float32
141+
142+
# Debug printing of tensor info
143+
print(f"\nWriting tensor: {key}")
144+
print(f" Original shape: {tensor.shape}")
145+
print(f" Processed shape: {data.shape}")
146+
print(f" GGML dimensions: {n_dims}")
147+
print(f" GGML shape: {tensor_shape}")
148+
print(f" Type: {'float16' if ftype == 1 else 'float32'}")
149+
150+
# Convert tensor name to bytes
151+
name_bytes = key.encode('utf-8')
152+
name_length = len(name_bytes)
153+
154+
# Write tensor header
155+
fout.write(struct.pack("i", n_dims))
156+
fout.write(struct.pack("i", name_length))
157+
fout.write(struct.pack("i", ftype))
158+
159+
# Write tensor dimensions
160+
for i in range(n_dims):
161+
size = tensor_shape[i] if i < len(tensor_shape) else 1
162+
fout.write(struct.pack("i", size))
163+
print(f" Writing dimension {i}: {size}")
164+
165+
# Write tensor name
166+
fout.write(name_bytes)
167+
168+
# Write tensor data
169+
data.tofile(fout)
170+
171+
print(f" Wrote {data.size * (2 if ftype==1 else 4)} bytes")
172+
173+
print(f"\nDone! Model has been converted to GGML format: {output_file}")
174+
print(f"File size: {os.path.getsize(output_file)} bytes")
179175

180176
if __name__ == "__main__":
181177
parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")

src/whisper.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4480,10 +4480,11 @@ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context* ctx0,
44804480
const whisper_vad_model & model, ggml_tensor * cur) {
44814481
struct ggml_tensor * padded_frame = ggml_pad(ctx0, cur, 64, 0, 0, 0);
44824482
struct ggml_tensor * reshaped_frame = ggml_reshape_3d(ctx0, padded_frame, 640, 1, 1);
4483-
struct ggml_tensor * reshaped_basis = ggml_reshape_3d(ctx0, model.stft_forward_basis, 256, 1, 258);
4484-
struct ggml_tensor * permuted_basis = ggml_permute(ctx0, reshaped_basis, 2, 1, 0, 3);
4485-
permuted_basis = ggml_cont(ctx0, permuted_basis);
4486-
cur = ggml_conv_1d(ctx0, permuted_basis, reshaped_frame, 1, 0, 1);
4483+
4484+
// We need the stft tensor to be in {258, 1, 256},
4485+
// that is a kernel size of 258, 1 channel, and 256 frequency bins (output)
4486+
struct ggml_tensor * reshaped_stft = ggml_reshape_3d(ctx0, model.stft_forward_basis, 258, 1, 256);
4487+
cur = ggml_conv_1d(ctx0, reshaped_stft, reshaped_frame, 1, 1, 1);
44874488
ggml_set_name(cur, "stft");
44884489
ggml_set_output(cur);
44894490
return cur;
@@ -4842,7 +4843,7 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
48424843

48434844
// SFTF precomputed basis matrix
48444845
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
4845-
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 256, 258));
4846+
ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 258, 1, 256));
48464847

48474848
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
48484849
ggml_new_tensor_3d(

0 commit comments

Comments
 (0)