@@ -78,13 +78,12 @@ class TestTransformsWithAsset(TorchtextTestCase):
     def test_vocab_transform(self):
         asset_name = 'vocab_test2.txt'
         asset_path = get_asset_path(asset_name)
-        with open(asset_path, 'r') as f:
-            vocab_transform = VocabTransform(load_vocab_from_file(f))
-            self.assertEqual(vocab_transform(['of', 'that', 'new']),
-                             [7, 18, 24])
-            jit_vocab_transform = torch.jit.script(vocab_transform)
-            self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
-                             [7, 18, 24, 18])
+        vocab_transform = VocabTransform(load_vocab_from_file(asset_path))
+        self.assertEqual(vocab_transform(['of', 'that', 'new']),
+                         [7, 18, 24])
+        jit_vocab_transform = torch.jit.script(vocab_transform)
+        self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
+                         [7, 18, 24, 18])
 
     def test_errors_vectors_python(self):
         tokens = []
@@ -179,27 +178,25 @@ def test_glove_different_dims(self):
     def test_vocab_from_file(self):
         asset_name = 'vocab_test.txt'
         asset_path = get_asset_path(asset_name)
-        with open(asset_path, 'r') as f:
-            v = load_vocab_from_file(f, unk_token='<new_unk>')
-            expected_itos = ['<new_unk>', 'b', 'a', 'c']
-            expected_stoi = {x: index for index, x in enumerate(expected_itos)}
-            self.assertEqual(v.get_itos(), expected_itos)
-            self.assertEqual(dict(v.get_stoi()), expected_stoi)
+        v = load_vocab_from_file(asset_path, unk_token='<new_unk>')
+        expected_itos = ['<new_unk>', 'b', 'a', 'c']
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v.get_itos(), expected_itos)
+        self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
     def test_vocab_from_raw_text_file(self):
         asset_name = 'vocab_raw_text_test.txt'
         asset_path = get_asset_path(asset_name)
-        with open(asset_path, 'r') as f:
-            tokenizer = basic_english_normalize()
-            jit_tokenizer = torch.jit.script(tokenizer)
-            v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
-            expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
-                             'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
-                             'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
-                             'unions', 'with', 'workers']
-            expected_stoi = {x: index for index, x in enumerate(expected_itos)}
-            self.assertEqual(v.get_itos(), expected_itos)
-            self.assertEqual(dict(v.get_stoi()), expected_stoi)
+        tokenizer = basic_english_normalize()
+        jit_tokenizer = torch.jit.script(tokenizer)
+        v = build_vocab_from_text_file(asset_path, jit_tokenizer, unk_token='<new_unk>')
+        expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
+                         'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
+                         'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
+                         'unions', 'with', 'workers']
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v.get_itos(), expected_itos)
+        self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
     def test_builtin_pretrained_sentencepiece_processor(self):
         sp_model_path = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
@@ -241,11 +238,10 @@ def batch_func(data):
     def test_text_sequential_transform(self):
         asset_name = 'vocab_test2.txt'
         asset_path = get_asset_path(asset_name)
-        with open(asset_path, 'r') as f:
-            pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f))
-            jit_pipeline = torch.jit.script(pipeline)
-            self.assertEqual(pipeline('of that new'), [7, 18, 24])
-            self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
+        pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(asset_path))
+        jit_pipeline = torch.jit.script(pipeline)
+        self.assertEqual(pipeline('of that new'), [7, 18, 24])
+        self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
 
     def test_vectors_from_file(self):
         asset_name = 'vectors_test.csv'