Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 90fd332

Browse files
authored
Modified experimental vocab factory functions API (#1286)
1 parent fab63ed commit 90fd332

File tree

3 files changed

+36
-44
lines changed

3 files changed

+36
-44
lines changed

test/experimental/test_vocab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_vocab_load_and_save(self):
219219

220220
def test_build_vocab_iterator(self):
221221
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
222-
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']]
222+
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']]
223223
v = build_vocab_from_iterator(iterator)
224224
expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
225225
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

test/experimental/test_with_asset.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,12 @@ class TestTransformsWithAsset(TorchtextTestCase):
7878
def test_vocab_transform(self):
7979
asset_name = 'vocab_test2.txt'
8080
asset_path = get_asset_path(asset_name)
81-
with open(asset_path, 'r') as f:
82-
vocab_transform = VocabTransform(load_vocab_from_file(f))
83-
self.assertEqual(vocab_transform(['of', 'that', 'new']),
84-
[7, 18, 24])
85-
jit_vocab_transform = torch.jit.script(vocab_transform)
86-
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
87-
[7, 18, 24, 18])
81+
vocab_transform = VocabTransform(load_vocab_from_file(asset_path))
82+
self.assertEqual(vocab_transform(['of', 'that', 'new']),
83+
[7, 18, 24])
84+
jit_vocab_transform = torch.jit.script(vocab_transform)
85+
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
86+
[7, 18, 24, 18])
8887

8988
def test_errors_vectors_python(self):
9089
tokens = []
@@ -179,27 +178,25 @@ def test_glove_different_dims(self):
179178
def test_vocab_from_file(self):
180179
asset_name = 'vocab_test.txt'
181180
asset_path = get_asset_path(asset_name)
182-
with open(asset_path, 'r') as f:
183-
v = load_vocab_from_file(f, unk_token='<new_unk>')
184-
expected_itos = ['<new_unk>', 'b', 'a', 'c']
185-
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
186-
self.assertEqual(v.get_itos(), expected_itos)
187-
self.assertEqual(dict(v.get_stoi()), expected_stoi)
181+
v = load_vocab_from_file(asset_path, unk_token='<new_unk>')
182+
expected_itos = ['<new_unk>', 'b', 'a', 'c']
183+
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
184+
self.assertEqual(v.get_itos(), expected_itos)
185+
self.assertEqual(dict(v.get_stoi()), expected_stoi)
188186

189187
def test_vocab_from_raw_text_file(self):
190188
asset_name = 'vocab_raw_text_test.txt'
191189
asset_path = get_asset_path(asset_name)
192-
with open(asset_path, 'r') as f:
193-
tokenizer = basic_english_normalize()
194-
jit_tokenizer = torch.jit.script(tokenizer)
195-
v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
196-
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
197-
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
198-
'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
199-
'unions', 'with', 'workers']
200-
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
201-
self.assertEqual(v.get_itos(), expected_itos)
202-
self.assertEqual(dict(v.get_stoi()), expected_stoi)
190+
tokenizer = basic_english_normalize()
191+
jit_tokenizer = torch.jit.script(tokenizer)
192+
v = build_vocab_from_text_file(asset_path, jit_tokenizer, unk_token='<new_unk>')
193+
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
194+
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
195+
'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
196+
'unions', 'with', 'workers']
197+
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
198+
self.assertEqual(v.get_itos(), expected_itos)
199+
self.assertEqual(dict(v.get_stoi()), expected_stoi)
203200

204201
def test_builtin_pretrained_sentencepiece_processor(self):
205202
sp_model_path = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
@@ -241,11 +238,10 @@ def batch_func(data):
241238
def test_text_sequential_transform(self):
242239
asset_name = 'vocab_test2.txt'
243240
asset_path = get_asset_path(asset_name)
244-
with open(asset_path, 'r') as f:
245-
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f))
246-
jit_pipeline = torch.jit.script(pipeline)
247-
self.assertEqual(pipeline('of that new'), [7, 18, 24])
248-
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
241+
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(asset_path))
242+
jit_pipeline = torch.jit.script(pipeline)
243+
self.assertEqual(pipeline('of that new'), [7, 18, 24])
244+
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
249245

250246
def test_vectors_from_file(self):
251247
asset_name = 'vectors_test.csv'

torchtext/experimental/vocab.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22-
def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_token='<unk>', num_cpus=4):
22+
def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token='<unk>', num_cpus=4):
2323
r"""Create a `Vocab` object from a raw text file.
2424
25-
The `file_object` can contain any raw text. This function applies a generic JITed tokenizer in
26-
parallel to the text. Note that the vocab will be created in the order that the tokens first appear
27-
in the file (and not by the frequency of tokens).
25+
The file at `file_path` can contain any raw text. This function applies a generic JITed tokenizer in
26+
parallel to the text.
2827
2928
Args:
3029
file_path (str): path to the raw text file to read data from.
@@ -40,20 +39,18 @@ def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_tok
4039
Examples:
4140
>>> from torchtext.experimental.vocab import build_vocab_from_text_file
4241
>>> from torchtext.experimental.transforms import basic_english_normalize
43-
>>> f = open('vocab.txt', 'r')
44-
>>> tokenizer = basic_english_normalize()
42+
>>> tokenizer = basic_english_normalize()
4543
>>> tokenizer = basic_english_normalize()
4644
>>> jit_tokenizer = torch.jit.script(tokenizer)
47-
>>> v = build_vocab_from_text_file(f, jit_tokenizer)
45+
>>> v = build_vocab_from_text_file('vocab.txt', jit_tokenizer)
4846
"""
49-
vocab_obj = _build_vocab_from_text_file(file_object.name, unk_token, min_freq, num_cpus, jited_tokenizer)
47+
vocab_obj = _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, jited_tokenizer)
5048
return Vocab(vocab_obj)
5149

5250

53-
def load_vocab_from_file(file_object, min_freq=1, unk_token='<unk>', num_cpus=4):
51+
def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
5452
r"""Create a `Vocab` object from a text file.
55-
The `file_object` should contain tokens separated by new lines. Note that the vocab
56-
will be created in the order that the tokens first appear in the file (and not by the frequency of tokens).
53+
The file at `file_path` should contain tokens separated by new lines.
5754
Format for txt file:
5855
5956
token1
@@ -73,11 +70,10 @@ def load_vocab_from_file(file_object, min_freq=1, unk_token='<unk>', num_cpus=4)
7370
7471
Examples:
7572
>>> from torchtext.experimental.vocab import load_vocab_from_file
76-
>>> f = open('vocab.txt', 'r')
77-
>>> v = load_vocab_from_file(f)
73+
>>> v = load_vocab_from_file('vocab.txt')
7874
"""
7975

80-
vocab_obj = _load_vocab_from_file(file_object.name, unk_token, min_freq, num_cpus)
76+
vocab_obj = _load_vocab_from_file(file_path, unk_token, min_freq, num_cpus)
8177
return Vocab(vocab_obj)
8278

8379

0 commit comments

Comments
 (0)