Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions deep_speech_2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import multiprocessing


def ctc_best_path_decoder(probs_seq, vocabulary):
"""Best path decoder, also called argmax decoder or greedy decoder.
def ctc_greedy_decoder(probs_seq, vocabulary):
"""CTC greedy (best path) decoder.

Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.

Expand Down Expand Up @@ -45,10 +46,12 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
"""Beam search decoder for CTC-trained network. It utilizes beam search
to approximately select top best decoding labels and returning results
in the descending order. The implementation is based on Prefix
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
"""CTC Beam search decoder.

It utilizes beam search to approximately select top best decoding
labels and returning results in the descending order.
The implementation is based on Prefix Beam Search
(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
Expand Down
166 changes: 53 additions & 113 deletions deep_speech_2/demo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,123 +3,63 @@
import time
import random
import argparse
import distutils.util
import functools
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from data_utils.utils import read_manifest
from utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--host_ip",
default="localhost",
type=str,
help="Server IP address. (default: %(default)s)")
parser.add_argument(
"--host_port",
default=8086,
type=int,
help="Server Port. (default: %(default)s)")
parser.add_argument(
"--speech_save_dir",
default="demo_cache",
type=str,
help="Directory for saving demo speech. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--warmup_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument(
"--specgram_type",
default='linear',
type=str,
help="Feature type of audio data: 'linear' (power spectrum)"
" or 'mfcc'. (default: %(default)s)")
parser.add_argument(
"--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
"--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument(
"--rnn_layer_size",
default=2048,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
"--share_rnn_weights",
default=True,
type=distutils.util.strtobool,
help="Whether to share input-hidden weights between forword and backward "
"directional simple RNNs. Only available when use_gru=False. "
"(default: %(default)s)")
parser.add_argument(
"--use_gru",
default=False,
type=distutils.util.strtobool,
help="Use GRU or simple RNN. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search',
type=str,
help="Method for ctc decoding: best_path or beam_search. "
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=100,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.36,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.25,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of Simple RNNs.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simple -> simple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
"bi-directional RNNs. Not for GRU.")
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
add_arg('warmup_manifest', str,
'datasets/manifest.test',
"Filepath of manifest to warm up.")
add_arg('mean_std_path', str,
'mean_std.npz',
"Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
'datasets/vocab/eng_vocab.txt',
"Filepath of vocabulary.")
add_arg('model_path', str,
'./checkpoints/params.latest.tar.gz',
"If None, the training starts from scratch, "
"otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
'lm/data/common_crawl_00.prune01111.trie.klm',
"Filepath for language model.")
add_arg('decoder_method', str,
'ctc_beam_search',
"Decoder method. Options: ctc_beam_search, ctc_greedy",
choices = ['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type', str,
'linear',
"Audio feature type. Options: linear, mfcc.",
choices=['linear', 'mfcc'])
# yapf: disable
args = parser.parse_args()


Expand Down Expand Up @@ -200,8 +140,8 @@ def start_server():
"""Start the ASR server"""
# prepare data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1)
Expand All @@ -212,21 +152,21 @@ def start_server():
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
use_gru=args.use_gru,
pretrained_model_path=args.model_filepath,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)

# prepare ASR inference handler
def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "")
result_transcript = ds2_model.infer_batch(
infer_data=[feature],
decode_method=args.decode_method,
decoder_method=args.decoder_method,
beam_alpha=args.alpha,
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
language_model_path=args.language_model_path,
language_model_path=args.lang_model_path,
num_processes=1)
return result_transcript[0]

Expand All @@ -235,7 +175,7 @@ def file_to_transcript(filename):
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest_path,
manifest_path=args.warmup_manifest,
num_test_cases=3)
print('-----------------------------------------------------------')

Expand Down
Loading