@@ -58,19 +58,19 @@ def get_vocab(self):
         return self.vocab
 
 
-def _setup_datasets(dataset_name, tokenizer, root, vocab, data_select, year, language):
+def _setup_datasets(dataset_name, tokenizer, root, vocab, split, year, language):
     if tokenizer is None:
         tokenizer = get_tokenizer('basic_english')
 
-    data_select = check_default_set(data_select, ('train', 'test', 'valid'))
+    split = check_default_set(split, ('train', 'test', 'valid'))
 
     if vocab is None:
-        if 'train' not in data_select:
+        if 'train' not in split:
             raise TypeError("Must pass a vocab if train is not selected.")
         if dataset_name == 'WMTNewsCrawl':
-            raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',), year=year, language=language)
+            raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',), year=year, language=language)
         else:
-            raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',))
+            raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',))
         logger_.info('Building Vocab based on train data')
         vocab = build_vocab(raw_train, tokenizer)
         logger_.info('Vocab has %d entries', len(vocab))
@@ -79,16 +79,16 @@ def text_transform(line):
         return torch.tensor([vocab[token] for token in tokenizer(line)], dtype=torch.long)
 
     if dataset_name == 'WMTNewsCrawl':
-        raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select, year=year, language=language)
+        raw_datasets = raw.DATASETS[dataset_name](root=root, split=split, year=year, language=language)
     else:
-        raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select)
-    raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(data_select, raw_datasets)}
-    logger_.info('Building datasets for {}'.format(data_select))
+        raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
+    raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(split, raw_datasets)}
+    logger_.info('Building datasets for {}'.format(split))
     return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform)
-                 for item in data_select)
+                 for item in split)
 
 
-def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def WikiText2(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines WikiText2 datasets.
 
     Create language modeling dataset: WikiText2
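An aside on the `text_transform` closure in the hunk above: it maps one raw line of text to a tensor of token ids. A toy illustration with a made-up three-word vocab (the plain dict here stands in for a real Vocab object):

    >>> import torch
    >>> from torchtext.data.utils import get_tokenizer
    >>> tokenizer = get_tokenizer('basic_english')
    >>> vocab = {'the': 0, 'cat': 1, 'sat': 2}
    >>> torch.tensor([vocab[token] for token in tokenizer('The cat sat')],
    ...              dtype=torch.long)
    tensor([0, 1, 2])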
@@ -102,7 +102,7 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid','test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
             By default, all the three datasets (train, test, valid) are generated. Users
             could also choose any one or two of them, for example ('train', 'test') or
             just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -116,13 +116,13 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
         >>> train_dataset, valid_dataset, test_dataset = WikiText2(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab,
-                                       data_select='valid')
+                                       split='valid')
 
     """
-    return _setup_datasets("WikiText2", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("WikiText2", tokenizer, root, vocab, split, None, None)
 
 
-def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def WikiText103(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines WikiText103 datasets.
 
     Create language modeling dataset: WikiText103
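For callers, every dataset factory in this file changes the same way: the keyword is renamed, the semantics are untouched. An illustrative before/after, using WikiText2:

    >>> # before this change:
    >>> # valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab, data_select='valid')
    >>> # after:
    >>> valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab, split='valid')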
@@ -136,7 +136,7 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
             By default, all the three datasets (train, test, valid) are generated. Users
             could also choose any one or two of them, for example ('train', 'test') or
             just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -150,14 +150,14 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
         >>> train_dataset, valid_dataset, test_dataset = WikiText103(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = WikiText103(tokenizer=tokenizer, vocab=vocab,
-                                         data_select='valid')
+                                         split='valid')
 
     """
 
-    return _setup_datasets("WikiText103", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("WikiText103", tokenizer, root, vocab, split, None, None)
 
 
-def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+def PennTreebank(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
     """ Defines PennTreebank datasets.
 
     Create language modeling dataset: PennTreebank
@@ -171,7 +171,7 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
             By default, all the three datasets (train, test, valid) are generated. Users
             could also choose any one or two of them, for example ('train', 'test') or
             just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -185,14 +185,14 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
         >>> train_dataset, valid_dataset, test_dataset = PennTreebank(tokenizer=tokenizer)
         >>> vocab = train_dataset.get_vocab()
         >>> valid_dataset, = PennTreebank(tokenizer=tokenizer, vocab=vocab,
-                                          data_select='valid')
+                                          split='valid')
 
     """
 
-    return _setup_datasets("PennTreebank", tokenizer, root, vocab, data_select, None, None)
+    return _setup_datasets("PennTreebank", tokenizer, root, vocab, split, None, None)
 
 
-def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train'), year=2010, language='en'):
+def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, split=('train',), year=2010, language='en'):
     """ Defines WMTNewsCrawl datasets.
 
     Create language modeling dataset: WMTNewsCrawl
@@ -206,7 +206,7 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
         root: Directory where the datasets are saved. Default: ".data"
         vocab: Vocabulary used for dataset. If None, it will generate a new
             vocabulary based on the train data set.
-        data_select: a string or tuple for the returned datasets
+        split: a string or tuple for the returned datasets
             (Default: ('train',))
         year: the year of the dataset (Default: 2010)
         language: the language of the dataset (Default: 'en')
@@ -215,12 +215,12 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
         >>> from torchtext.experimental.datasets import WMTNewsCrawl
         >>> from torchtext.data.utils import get_tokenizer
         >>> tokenizer = get_tokenizer("spacy")
-        >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, data_select='train')
+        >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, split='train')
 
     Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
     """
 
-    return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, data_select, year, language)
+    return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, split, year, language)
 
 
 DATASETS = {
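WMTNewsCrawl is the one factory with extra selectors. After this change, a fully explicit call looks like the following, with year and language shown at their documented defaults:

    >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, split='train',
    ...                               year=2010, language='en')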