diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index 6b59a71ae64..6bfd56d66c1 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -210,22 +210,73 @@ def forward(self, x): def mimi_decode( - mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set + args, mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set ) -> torch.Tensor: - class MimiDecode(nn.Module): - def __init__(self, mimi: nn.Module): - super().__init__() - self.mimi_model = mimi - + from pathlib import Path + from safetensors.torch import load_model + def _is_safetensors(path: Path | str) -> bool: + return Path(path).suffix in (".safetensors", ".sft", ".sfts") + from moshi.models.compression import MimiModel + from moshi.modules.seanet import SEANetEncoder, SEANetDecoder + from moshi.modules import transformer + from moshi.models.loaders import _seanet_kwargs, _quantizer_kwargs, _transformer_kwargs + from moshi.quantization.vq import SplitResidualVectorQuantizer + + class MimiDecode(MimiModel): def forward(self, x): - return self.mimi_model.decode(x) + return super().decode(x) - mimi_decode_model = MimiDecode(mimi) - decode_inputs, decode_input_list = [], "" - for index, encoder_res in enumerate(encode_res_list): - decode_inputs.append((encoder_res.to(torch.int32),)) - decode_input_list += f"input_{index}_0.raw\n" + encoder = SEANetEncoder(**_seanet_kwargs) + decoder = SEANetDecoder(**_seanet_kwargs) + encoder_transformer = transformer.ProjectedTransformer( + device='cpu', **_transformer_kwargs + ) + decoder_transformer = transformer.ProjectedTransformer( + device='cpu', **_transformer_kwargs + ) + quantizer = SplitResidualVectorQuantizer( + **_quantizer_kwargs, + ) + mimi_decode_model = MimiDecode( + encoder, + decoder, + quantizer, + channels=1, + sample_rate=24000, + frame_rate=12.5, + encoder_frame_rate=24000 / encoder.hop_length, + causal=True, + resample_method="conv", + encoder_transformer=encoder_transformer, + decoder_transformer=decoder_transformer,) + mimi_decode_model.eval() + if _is_safetensors(args.mimi_weight): + load_model(mimi_decode_model, args.mimi_weight, strict=False) + + decode_inputs, decode_input_list = [], "" + + + all_codes = [] + sample_input = encode_res_list[..., 0 : 1] + with mimi_decode_model.streaming(1): + #---------------------------------------------Works fine below with nn.Module--------------------------------------------- + # for i in range(encode_res_list.shape[-1]): + # codes = encode_res_list[..., i : i + 1] + # pcm = mimi_decode_model(codes) + # all_codes.append(pcm) + #---------------------------------------------SQNR drops to 8.5 after export--------------------------------------------- + captured_model = torch.export.export(mimi_decode_model, (sample_input,), strict=False).module() + for i in range(encode_res_list.shape[-1]): + codes = encode_res_list[..., i : i + 1] + pcm = captured_model(codes) + all_codes.append(pcm) + + + + cpu_decode_res = torch.cat(all_codes, dim=-1) + return cpu_decode_res + pte_filename = "mimi_decoder_qnn" quantizer = make_quantizer( @@ -314,14 +365,14 @@ def export_mimi(mimi, args, max_duration_sec=10.0): print("streaming encoding...") cpu_encode_res = mimi.encode(sample_pcm) - htp_encode_res = mimi_encode( - mimi, - encoder_inputs, - encoder_input_list, - pcm_chunk_size, - skip_node_id_set, - skip_node_op_set, - ) + # htp_encode_res = mimi_encode( + # mimi, + # encoder_inputs, + # encoder_input_list, + # pcm_chunk_size, + # skip_node_id_set, + # skip_node_op_set, + # ) # Leave it here for now, uncomment this to check htp_encoder with cpu_decoder # htp_res = torch.cat(htp_encode_res, dim=-1) @@ -332,10 +383,10 @@ def export_mimi(mimi, args, max_duration_sec=10.0): cpu_decode_res = mimi.decode(cpu_encode_res) # TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time. # with mimi.streaming(1): - htp_decode_res = mimi_decode( - mimi, htp_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set + cpu_streaming_decode_res = mimi_decode( + args, mimi, cpu_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set ) - compute_scores(cpu_decode_res, htp_decode_res) + compute_scores(cpu_decode_res, cpu_streaming_decode_res) sphn.write_wav( f"{args.artifact}/cpu_decode_res.wav", @@ -343,8 +394,8 @@ def export_mimi(mimi, args, max_duration_sec=10.0): sample_rate, ) sphn.write_wav( - f"{args.artifact}/htp_decode_res.wav", - htp_decode_res[0, 0].cpu().numpy(), + f"{args.artifact}/cpu_streaming_decode_res.wav", + cpu_streaming_decode_res[0, 0].cpu().numpy(), sample_rate, )