44from __future__ import print_function
55
66import unittest
7- from decoder import *
7+ from models import decoder
88
99
1010class TestDecoders (unittest .TestCase ):
@@ -53,11 +53,13 @@ def setUp(self):
5353 self .beam_search_result = ['acdc' , "b'a" ]
5454
5555 def test_greedy_decoder_1 (self ):
56- bst_result = ctc_greedy_decoder (self .probs_seq1 , self .vocab_list )
56+ bst_result = decoder .ctc_greedy_decoder (self .probs_seq1 ,
57+ self .vocab_list )
5758 self .assertEqual (bst_result , self .greedy_result [0 ])
5859
5960 def test_greedy_decoder_2 (self ):
60- bst_result = ctc_greedy_decoder (self .probs_seq2 , self .vocab_list )
61+ bst_result = decoder .ctc_greedy_decoder (self .probs_seq2 ,
62+ self .vocab_list )
6163 self .assertEqual (bst_result , self .greedy_result [1 ])
6264
6365 def test_beam_search_decoder_1 (self ):
@@ -69,15 +71,15 @@ def test_beam_search_decoder_1(self):
6971 self .assertEqual (beam_result [0 ][1 ], self .beam_search_result [0 ])
7072
7173 def test_beam_search_decoder_2 (self ):
72- beam_result = ctc_beam_search_decoder (
74+ beam_result = decoder . ctc_beam_search_decoder (
7375 probs_seq = self .probs_seq2 ,
7476 beam_size = self .beam_size ,
7577 vocabulary = self .vocab_list ,
7678 blank_id = len (self .vocab_list ))
7779 self .assertEqual (beam_result [0 ][1 ], self .beam_search_result [1 ])
7880
7981 def test_beam_search_decoder_batch (self ):
80- beam_results = ctc_beam_search_decoder_batch (
82+ beam_results = decoder . ctc_beam_search_decoder_batch (
8183 probs_split = [self .probs_seq1 , self .probs_seq2 ],
8284 beam_size = self .beam_size ,
8385 vocabulary = self .vocab_list ,
0 commit comments