Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a0975d3

Browse files
add __next__ method to RawTextIterableDataset (#1141)
1 parent 911744e commit a0975d3

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

test/data/test_builtin_datasets.py

Lines changed: 32 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,9 @@ def test_wikitext2(self):
5858

5959
# Add test for the subset of the standard datasets
6060
train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'valid', 'test'))
61-
self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n')
62-
self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n')
63-
self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n')
61+
self._helper_test_func(len(train_iter), 36718, next(train_iter), ' \n')
62+
self._helper_test_func(len(valid_iter), 3760, next(valid_iter), ' \n')
63+
self._helper_test_func(len(test_iter), 4358, next(test_iter), ' \n')
6464
del train_iter, valid_iter, test_iter
6565
train_dataset, test_dataset = WikiText2(split=('train', 'test'))
6666
train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
@@ -113,8 +113,8 @@ def test_penntreebank(self):
113113
self._helper_test_func(len(test_data), 82114, test_data[30:35],
114114
[397, 93, 4, 16, 7])
115115
train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(split=('train', 'test'))
116-
self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
117-
self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
116+
self._helper_test_func(len(train_iter), 42068, next(train_iter)[:15], ' aer banknote b')
117+
self._helper_test_func(len(test_iter), 3761, next(test_iter)[:25], " no it was n't black mond")
118118
del train_iter, test_iter
119119

120120
def test_text_classification(self):
@@ -134,8 +134,8 @@ def test_text_classification(self):
134134
self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10],
135135
[2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405])
136136
train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
137-
self._helper_test_func(len(train_iter), 120000, next(iter(train_iter))[1][:25], 'Wall St. Bears Claw Back ')
138-
self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft')
137+
self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ')
138+
self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
139139
del train_iter, test_iter
140140

141141
def test_num_lines_of_dataset(self):
@@ -151,6 +151,19 @@ def test_offset_dataset(self):
151151
'Non-OPEC Nations Sho', 'Google IPO Auction O',
152152
'Dollar Falls Broadly'])
153153

154+
def test_next_method_dataset(self):
    """Interleave ``for`` iteration and ``next()`` on the same raw dataset.

    Each pass of the ``for`` loop takes one item and then an explicit
    ``next()`` call takes another from the same ``train_iter`` object.
    The expected counts (60000 each) assume both paths draw from one
    shared underlying iterator over the AG_NEWS train split.
    """
    train_iter, _test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
    for_count = 0
    next_count = 0
    for _line in train_iter:
        for_count += 1
        try:
            next(train_iter)
        except StopIteration:
            # Only exhaustion may end the loop early. Any other exception is
            # a real failure and must surface instead of being swallowed.
            break
        next_count += 1
    self.assertEqual((for_count, next_count), (60000, 60000))
166+
154167
def test_imdb(self):
155168
from torchtext.experimental.datasets import IMDB
156169
from torchtext.vocab import Vocab
@@ -171,8 +184,8 @@ def test_imdb(self):
171184
self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
172185
[13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
173186
train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB()
174-
self._helper_test_func(len(train_iter), 25000, next(iter(train_iter))[1][:25], 'I rented I AM CURIOUS-YEL')
175-
self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
187+
self._helper_test_func(len(train_iter), 25000, next(train_iter)[1][:25], 'I rented I AM CURIOUS-YEL')
188+
self._helper_test_func(len(test_iter), 25000, next(test_iter)[1][:25], 'I love sci-fi and am will')
176189
del train_iter, test_iter
177190

178191
def test_iwslt(self):
@@ -248,10 +261,10 @@ def test_multi30k(self):
248261

249262
# Add test for the subset of the standard datasets
250263
train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(split=('train', 'valid'))
251-
self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
264+
self._helper_test_func(len(train_iter), 29000, ' '.join(next(train_iter)),
252265
' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
253266
'Two young, White males are outside near many bushes.\n']))
254-
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
267+
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(valid_iter)),
255268
' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
256269
'A group of men are loading cotton onto a truck\n']))
257270
del train_iter, valid_iter
@@ -323,9 +336,9 @@ def test_udpos_sequence_tagging(self):
323336
([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585],
324337
[6, 20, 8, 10, 8, 8, 24, 13, 8, 15]))
325338
train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(split=('train', 'valid'))
326-
self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]),
339+
self._helper_test_func(len(train_iter), 12543, ' '.join(next(train_iter)[0][:5]),
327340
' '.join(['Al', '-', 'Zaman', ':', 'American']))
328-
self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]),
341+
self._helper_test_func(len(valid_iter), 2002, ' '.join(next(valid_iter)[0][:5]),
329342
' '.join(['From', 'the', 'AP', 'comes', 'this']))
330343
del train_iter, valid_iter
331344

@@ -376,9 +389,9 @@ def test_conll_sequence_tagging(self):
376389
[18, 17, 12, 19, 10, 6, 3, 3, 4, 4],
377390
[3, 5, 7, 7, 3, 2, 6, 6, 3, 2]))
378391
train_iter, test_iter = torchtext.experimental.datasets.raw.CoNLL2000Chunking()
379-
self._helper_test_func(len(train_iter), 8936, ' '.join(next(iter(train_iter))[0][:5]),
392+
self._helper_test_func(len(train_iter), 8936, ' '.join(next(train_iter)[0][:5]),
380393
' '.join(['Confidence', 'in', 'the', 'pound', 'is']))
381-
self._helper_test_func(len(test_iter), 2012, ' '.join(next(iter(test_iter))[0][:5]),
394+
self._helper_test_func(len(test_iter), 2012, ' '.join(next(test_iter)[0][:5]),
382395
' '.join(['Rockwell', 'International', 'Corp.', "'s", 'Tulsa']))
383396
del train_iter, test_iter
384397

@@ -405,9 +418,9 @@ def test_squad1(self):
405418
self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
406419
([7, 24, 86, 52, 2], [72, 72]))
407420
train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD1()
408-
self._helper_test_func(len(train_iter), 87599, next(iter(train_iter))[0][:50],
421+
self._helper_test_func(len(train_iter), 87599, next(train_iter)[0][:50],
409422
'Architecturally, the school has a Catholic charact')
410-
self._helper_test_func(len(dev_iter), 10570, next(iter(dev_iter))[0][:50],
423+
self._helper_test_func(len(dev_iter), 10570, next(dev_iter)[0][:50],
411424
'Super Bowl 50 was an American football game to det')
412425
del train_iter, dev_iter
413426

@@ -434,8 +447,8 @@ def test_squad2(self):
434447
self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
435448
([84, 50, 1421, 12, 5439], [9, 9]))
436449
train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD2()
437-
self._helper_test_func(len(train_iter), 130319, next(iter(train_iter))[0][:50],
450+
self._helper_test_func(len(train_iter), 130319, next(train_iter)[0][:50],
438451
'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-Y')
439-
self._helper_test_func(len(dev_iter), 11873, next(iter(dev_iter))[0][:50],
452+
self._helper_test_func(len(dev_iter), 11873, next(dev_iter)[0][:50],
440453
'The Normans (Norman: Nourmands; French: Normands; ')
441454
del train_iter, dev_iter

torchtext/experimental/datasets/raw/common.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,10 @@ def __iter__(self):
3232
break
3333
yield item
3434

35+
def __next__(self):
    """Return the next item pulled from ``self._iterator``.

    Whatever ``next()`` raises on that iterator (for example
    ``StopIteration`` once it is exhausted) propagates to the caller
    unchanged.
    """
    return next(self._iterator)
38+
3539
def __len__(self):
3640
return self.num_lines
3741

0 commit comments

Comments
 (0)