From 3fed7fbb6bb4bc5640485f502f01cd601b4d73d2 Mon Sep 17 00:00:00 2001 From: Anjos Date: Mon, 20 Mar 2023 19:54:48 +0800 Subject: [PATCH 1/2] add use gpu(torch.jit.script to zipformer encoder) --- sherpa/cpp_api/bin/online-recognizer.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sherpa/cpp_api/bin/online-recognizer.cc b/sherpa/cpp_api/bin/online-recognizer.cc index ffbd26920..c9b6551f8 100644 --- a/sherpa/cpp_api/bin/online-recognizer.cc +++ b/sherpa/cpp_api/bin/online-recognizer.cc @@ -132,6 +132,12 @@ int32_t main(int32_t argc, char *argv[]) { } config.Validate(); + + if ( config.use_gpu ) { + config.feat_config.fbank_opts.device = torch::Device("cuda:0"); + } else { + config.feat_config.fbank_opts.device = torch::Device("cpu"); + } SHERPA_CHECK_EQ(config.feat_config.fbank_opts.frame_opts.samp_freq, expected_sample_rate) @@ -147,6 +153,8 @@ int32_t main(int32_t argc, char *argv[]) { torch::Tensor tail_padding = torch::zeros( {static_cast(padding_seconds * expected_sample_rate)}, torch::kFloat); + + tail_padding = tail_padding.to( config.feat_config.fbank_opts.device ); sherpa::OnlineRecognizer recognizer(config); if (use_wav_scp) { @@ -193,6 +201,7 @@ int32_t main(int32_t argc, char *argv[]) { {d.NumCols()}, torch::kFloat) / 32768; auto s = recognizer.CreateStream(); + tensor = tensor.to( config.feat_config.fbank_opts.device ); s->AcceptWaveform(expected_sample_rate, tensor); s->AcceptWaveform(expected_sample_rate, tail_padding); s->InputFinished(); @@ -227,6 +236,7 @@ int32_t main(int32_t argc, char *argv[]) { wave.index({torch::indexing::Slice(start, end)}); start = end; + samples = samples.to( config.feat_config.fbank_opts.device ); s->AcceptWaveform(expected_sample_rate, samples); while (recognizer.IsReady(s.get())) { @@ -265,7 +275,7 @@ int32_t main(int32_t argc, char *argv[]) { torch::Tensor wave = sherpa::ReadWave(po.GetArg(i), expected_sample_rate).first; - + wave = wave.to( config.feat_config.fbank_opts.device ); s->AcceptWaveform(expected_sample_rate, wave); s->AcceptWaveform(expected_sample_rate, tail_padding); From 4c788a05465c4137319d800f1cb18e2e0abae2f5 Mon Sep 17 00:00:00 2001 From: Anjos Date: Tue, 21 Mar 2023 09:50:03 +0800 Subject: [PATCH 2/2] fix the code style --- sherpa/cpp_api/bin/online-recognizer.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sherpa/cpp_api/bin/online-recognizer.cc b/sherpa/cpp_api/bin/online-recognizer.cc index c9b6551f8..b55a369ae 100644 --- a/sherpa/cpp_api/bin/online-recognizer.cc +++ b/sherpa/cpp_api/bin/online-recognizer.cc @@ -133,7 +133,7 @@ int32_t main(int32_t argc, char *argv[]) { config.Validate(); - if ( config.use_gpu ) { + if (config.use_gpu) { config.feat_config.fbank_opts.device = torch::Device("cuda:0"); } else { config.feat_config.fbank_opts.device = torch::Device("cpu"); @@ -154,7 +154,7 @@ int32_t main(int32_t argc, char *argv[]) { {static_cast(padding_seconds * expected_sample_rate)}, torch::kFloat); - tail_padding = tail_padding.to( config.feat_config.fbank_opts.device ); + tail_padding = tail_padding.to(config.feat_config.fbank_opts.device); sherpa::OnlineRecognizer recognizer(config); if (use_wav_scp) { @@ -201,7 +201,7 @@ int32_t main(int32_t argc, char *argv[]) { {d.NumCols()}, torch::kFloat) / 32768; auto s = recognizer.CreateStream(); - tensor = tensor.to( config.feat_config.fbank_opts.device ); + tensor = tensor.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, tensor); s->AcceptWaveform(expected_sample_rate, tail_padding); s->InputFinished(); @@ -236,7 +236,7 @@ int32_t main(int32_t argc, char *argv[]) { wave.index({torch::indexing::Slice(start, end)}); start = end; - samples = samples.to( config.feat_config.fbank_opts.device ); + samples = samples.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, samples); while (recognizer.IsReady(s.get())) { @@ -275,7 +275,7 @@ int32_t main(int32_t argc, char *argv[]) { torch::Tensor wave = sherpa::ReadWave(po.GetArg(i), expected_sample_rate).first; - wave = wave.to( config.feat_config.fbank_opts.device ); + wave = wave.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, wave); s->AcceptWaveform(expected_sample_rate, tail_padding);