This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3de3fcf

Switch data_select in dataset signature to split (#1143)
1 parent 1ac252b commit 3de3fcf
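This is a pure keyword rename: the dataset constructors (and their raw counterparts) now take `split` instead of `data_select`, with defaults and behavior unchanged. A minimal before/after sketch of a call site, assuming torchtext at this commit:

    # Before this commit:
    # train_dataset, test_dataset = WikiText2(data_select=('train', 'test'))

    # After this commit:
    from torchtext.experimental.datasets import WikiText2
    train_dataset, test_dataset = WikiText2(split=('train', 'test'))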

12 files changed: +216 -216 lines


test/data/test_builtin_datasets.py

Lines changed: 13 additions & 13 deletions
@@ -57,12 +57,12 @@ def test_wikitext2(self):
         self.assertEqual(tokens_ids, [2, 286, 503, 700])
 
         # Add test for the subset of the standard datasets
-        train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(data_select=('train', 'valid', 'test'))
+        train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'valid', 'test'))
         self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n')
         self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n')
         self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n')
         del train_iter, valid_iter, test_iter
-        train_dataset, test_dataset = WikiText2(data_select=('train', 'test'))
+        train_dataset, test_dataset = WikiText2(split=('train', 'test'))
         train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
         test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
         self._helper_test_func(len(train_data), 2049990, train_data[20:25],
@@ -105,14 +105,14 @@ def test_penntreebank(self):
         self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
 
         # Add test for the subset of the standard datasets
-        train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
+        train_dataset, test_dataset = PennTreebank(split=('train', 'test'))
         train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
         test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
         self._helper_test_func(len(train_data), 924412, train_data[20:25],
                                [9919, 9920, 9921, 9922, 9188])
         self._helper_test_func(len(test_data), 82114, test_data[30:35],
                                [397, 93, 4, 16, 7])
-        train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(data_select=('train', 'test'))
+        train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(split=('train', 'test'))
         self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
         self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
         del train_iter, test_iter
@@ -130,7 +130,7 @@ def test_text_classification(self):
                                [2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786])
 
         # Add test for the subset of the standard datasets
-        train_dataset, = AG_NEWS(data_select=('train'))
+        train_dataset, = AG_NEWS(split=('train'))
         self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10],
                                [2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405])
         train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
@@ -160,7 +160,7 @@ def test_imdb(self):
         new_train_data, new_test_data = IMDB(vocab=new_vocab)
 
         # Add test for the subset of the standard datasets
-        train_dataset, = IMDB(data_select=('train'))
+        train_dataset, = IMDB(split=('train'))
         self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
                                [13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
         train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB()
@@ -240,15 +240,15 @@ def test_multi30k(self):
                                [18, 24, 1168, 807, 16, 56, 83, 335, 1338])
 
         # Add test for the subset of the standard datasets
-        train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(data_select=('train', 'valid'))
+        train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(split=('train', 'valid'))
         self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
                                ' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
                                          'Two young, White males are outside near many bushes.\n']))
         self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
                                ' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
                                          'A group of men are loading cotton onto a truck\n']))
         del train_iter, valid_iter
-        train_dataset, = Multi30k(data_select=('train'))
+        train_dataset, = Multi30k(split=('train'))
 
         # This change is due to the BC breaking in spacy 3.0
         self._helper_test_func(len(train_dataset), 29000, train_dataset[20],
@@ -311,11 +311,11 @@ def test_udpos_sequence_tagging(self):
         self.assertEqual(tokens_ids, [1206, 8, 69, 60, 157, 452])
 
         # Add test for the subset of the standard datasets
-        train_dataset, = UDPOS(data_select=('train'))
+        train_dataset, = UDPOS(split=('train'))
         self._helper_test_func(len(train_dataset), 12543, (train_dataset[0][0][:10], train_dataset[-1][2][:10]),
                                ([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585],
                                 [6, 20, 8, 10, 8, 8, 24, 13, 8, 15]))
-        train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(data_select=('train', 'valid'))
+        train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(split=('train', 'valid'))
         self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]),
                                ' '.join(['Al', '-', 'Zaman', ':', 'American']))
         self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]),
@@ -358,7 +358,7 @@ def test_conll_sequence_tagging(self):
         self.assertEqual(tokens_ids, [970, 5, 135, 43, 214, 690])
 
         # Add test for the subset of the standard datasets
-        train_dataset, = CoNLL2000Chunking(data_select=('train'))
+        train_dataset, = CoNLL2000Chunking(split=('train'))
         self._helper_test_func(len(train_dataset), 8936, (train_dataset[0][0][:10], train_dataset[0][1][:10],
                                                           train_dataset[0][2][:10], train_dataset[-1][0][:10],
                                                           train_dataset[-1][1][:10], train_dataset[-1][2][:10]),
@@ -393,7 +393,7 @@ def test_squad1(self):
         new_train_data, new_test_data = SQuAD1(vocab=new_vocab)
 
         # Add test for the subset of the standard datasets
-        train_dataset, = SQuAD1(data_select=('train'))
+        train_dataset, = SQuAD1(split=('train'))
         context, question, answers, ans_pos = train_dataset[100]
         self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
                                ([7, 24, 86, 52, 2], [72, 72]))
@@ -422,7 +422,7 @@ def test_squad2(self):
         new_train_data, new_test_data = SQuAD2(vocab=new_vocab)
 
         # Add test for the subset of the standard datasets
-        train_dataset, = SQuAD2(data_select=('train'))
+        train_dataset, = SQuAD2(split=('train'))
         context, question, answers, ans_pos = train_dataset[200]
         self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
                                ([84, 50, 1421, 12, 5439], [9, 9]))
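The tests above exercise both layers touched by the rename: the raw datasets (iterators over untokenized text) and the tensor-producing datasets built on top of them. A usage sketch of the recurring pattern, assuming this commit's API; note the trailing comma, which unpacks the one-element tuple returned for a single requested split:

    import torchtext
    from torchtext.experimental.datasets import IMDB

    # Raw layer: an iterator over raw examples.
    train_iter, = torchtext.experimental.datasets.raw.IMDB(split=('train',))

    # Dataset layer: tokenized, numericalized tensors.
    train_dataset, = IMDB(split='train')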

test/experimental/test_with_asset.py

Lines changed: 2 additions & 2 deletions
@@ -62,10 +62,10 @@ def test_wikitext103(self):
         self.assertEqual(tokens_ids, [2, 320, 437, 687])
 
         # Add test for the subset of the standard datasets
-        train_dataset, test_dataset = torchtext.experimental.datasets.raw.WikiText103(data_select=('train', 'test'))
+        train_dataset, test_dataset = torchtext.experimental.datasets.raw.WikiText103(split=('train', 'test'))
         self._helper_test_func(len(train_dataset), 1801350, next(iter(train_dataset)), ' \n')
         self._helper_test_func(len(test_dataset), 4358, next(iter(test_dataset)), ' \n')
-        train_dataset, test_dataset = WikiText103(vocab=builtin_vocab, data_select=('train', 'test'))
+        train_dataset, test_dataset = WikiText103(vocab=builtin_vocab, split=('train', 'test'))
         self._helper_test_func(len(train_dataset), 1801350, train_dataset[10][:5],
                                [2, 69, 12, 14, 265])
         self._helper_test_func(len(test_dataset), 4358, test_dataset[28][:5],
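The second call above passes a prebuilt vocab together with `split`; per the docstrings in language_modeling.py below, requesting 'train' is only mandatory when no vocab is supplied. A sketch, reusing builtin_vocab from the test asset:

    # With a vocab provided, the train split need not be requested at all:
    test_dataset, = WikiText103(vocab=builtin_vocab, split='test')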

torchtext/experimental/datasets/language_modeling.py

Lines changed: 26 additions & 26 deletions
@@ -58,19 +58,19 @@ def get_vocab(self):
         return self.vocab
 
 
-def _setup_datasets(dataset_name, tokenizer, root, vocab, data_select, year, language):
+def _setup_datasets(dataset_name, tokenizer, root, vocab, split, year, language):
     if tokenizer is None:
         tokenizer = get_tokenizer('basic_english')
 
-    data_select = check_default_set(data_select, ('train', 'test', 'valid'))
+    split = check_default_set(split, ('train', 'test', 'valid'))
 
     if vocab is None:
-        if 'train' not in data_select:
+        if 'train' not in split:
             raise TypeError("Must pass a vocab if train is not selected.")
         if dataset_name == 'WMTNewsCrawl':
-            raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',), year=year, language=language)
+            raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',), year=year, language=language)
         else:
-            raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',))
+            raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',))
         logger_.info('Building Vocab based on train data')
         vocab = build_vocab(raw_train, tokenizer)
         logger_.info('Vocab has %d entries', len(vocab))
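check_default_set itself is outside this diff; judging from how it is called here, it presumably normalizes a bare string to a tuple and validates each entry against the supported splits. A hypothetical reconstruction for illustration only, not the actual torchtext helper:

    def check_default_set(split, target_select):
        # Hypothetical sketch: normalize a bare string like 'train' to ('train',) ...
        if isinstance(split, str):
            split = (split,)
        # ... and reject anything outside the supported splits.
        for item in split:
            if item not in target_select:
                raise TypeError('{} is not a supported split in {}'.format(item, target_select))
        return split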
@@ -79,16 +79,16 @@ def text_transform(line):
         return torch.tensor([vocab[token] for token in tokenizer(line)], dtype=torch.long)
 
     if dataset_name == 'WMTNewsCrawl':
-        raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select, year=year, language=language)
+        raw_datasets = raw.DATASETS[dataset_name](root=root, split=split, year=year, language=language)
     else:
-        raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select)
-    raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(data_select, raw_datasets)}
-    logger_.info('Building datasets for {}'.format(data_select))
+        raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
+    raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(split, raw_datasets)}
+    logger_.info('Building datasets for {}'.format(split))
     return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform)
-                 for item in data_select)
+                 for item in split)
 
 
-def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def WikiText2(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines WikiText2 datasets.
 
     Create language modeling dataset: WikiText2
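Because raw_data is keyed by zip(split, raw_datasets) and the return value iterates `for item in split`, the datasets come back in exactly the order the caller lists them. A small usage sketch:

    # The result order tracks the order of `split`; here (valid, train):
    valid_dataset, train_dataset = WikiText2(split=('valid', 'train'))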
@@ -102,7 +102,7 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid','test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid','test')
             By default, all the three datasets (train, test, valid) are generated. Users
             could also choose any one or two of them, for example ('train', 'test') or
             just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -116,13 +116,13 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
         >>> train_dataset, valid_dataset, test_dataset = WikiText2(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab,
-                                       data_select='valid')
+                                       split='valid')
 
     """
-    return _setup_datasets("WikiText2", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("WikiText2", tokenizer, root, vocab, split, None, None)
 
 
-def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def WikiText103(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines WikiText103 datasets.
 
     Create language modeling dataset: WikiText103
@@ -136,7 +136,7 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
             By default, all the three datasets (train, test, valid) are generated. Users
             could also choose any one or two of them, for example ('train', 'test') or
             just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -150,14 +150,14 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
         >>> train_dataset, valid_dataset, test_dataset = WikiText103(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = WikiText103(tokenizer=tokenizer, vocab=vocab,
-                                         data_select='valid')
+                                         split='valid')
 
     """
 
-    return _setup_datasets("WikiText103", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("WikiText103", tokenizer, root, vocab, split, None, None)
 
 
-def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def PennTreebank(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines PennTreebank datasets.
 
     Create language modeling dataset: PennTreebank
@@ -171,7 +171,7 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
            By default, all the three datasets (train, test, valid) are generated. Users
            could also choose any one or two of them, for example ('train', 'test') or
            just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -185,14 +185,14 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
         >>> train_dataset, valid_dataset, test_dataset = PennTreebank(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = PennTreebank(tokenizer=tokenizer, vocab=vocab,
-                                          data_select='valid')
+                                          split='valid')
 
     """
 
-    return _setup_datasets("PennTreebank", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("PennTreebank", tokenizer, root, vocab, split, None, None)
 
 
-def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train'), year=2010, language='en'):
+def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, split=('train'), year=2010, language='en'):
     """ Defines WMTNewsCrawl datasets.
 
     Create language modeling dataset: WMTNewsCrawl
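One subtlety the rename carries over unchanged: the default `split=('train')` in the signature above is a parenthesized string, not a one-element tuple. Plain Python semantics, shown for clarity:

    # ('train') is just the string 'train'; a one-element tuple needs the comma.
    assert ('train') == 'train'
    assert ('train',) != 'train'
    # The docstring advertises "a string or tuple", and check_default_set
    # presumably normalizes both forms, so the default still works.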
@@ -206,7 +206,7 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tuple for the returned datasets
+        split: a string or tuple for the returned datasets
             (Default: ('train',))
         year: the year of the dataset (Default: 2010)
         language: the language of the dataset (Default: 'en')
@@ -215,12 +215,12 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
         >>> from torchtext.experimental.datasets import WMTNewsCrawl
         >>> from torchtext.data.utils import get_tokenizer
         >>> tokenizer = get_tokenizer("spacy")
-        >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, data_select='train')
+        >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, split='train')
 
     Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
     """
 
-    return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, data_select, year, language)
+    return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, split, year, language)
 
 
 DATASETS = {
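The trailing context line opens the module's DATASETS registry, the non-raw counterpart of the raw.DATASETS[dataset_name] lookups in _setup_datasets. Its body lies outside this diff; presumably it maps names to the constructors defined in this file, roughly:

    # Presumed shape only; the dict body is not part of this diff.
    DATASETS = {
        'WikiText2': WikiText2,
        'WikiText103': WikiText103,
        'PennTreebank': PennTreebank,
        'WMTNewsCrawl': WMTNewsCrawl,
    }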
