-
Notifications
You must be signed in to change notification settings - Fork 814
Text classification datasets with new torchtext dataset abstraction #701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
9d8aa61
1031efb
371c23e
29e5d0a
8187b84
6272f57
15e2a94
6611a09
7c31d2e
c5d487a
7c9c969
4689b29
fe90c51
d9ef2ee
e98ae46
8ce6779
8291ebc
1864e7d
64cbde6
51d1b8e
855e701
a955579
fa3565b
94870df
74f50b6
55e4848
db66774
5a20115
650928a
2447837
be20884
a6bc30a
3821282
9b97ac2
b565565
ebe87f7
e382503
c711c34
9a0c3ac
c6f6a42
f1d394c
7404519
bc2c83a
644b759
2f93dec
aa15019
793349c
33053e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
[flake8] | ||
ignore = E402,E722,W503,W504 | ||
ignore = E402,E722,W503,W504, F821 | ||
max-line-length = 90 | ||
exclude = docs/source |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
|
||
def _split_tokenizer(x):
    # type: (str) -> List[str]
    """Tokenize ``x`` by splitting it on runs of whitespace."""
    tokens = x.split()
    return tokens
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
from . import datasets | ||
|
||
__all__ = ['datasets'] | ||
from . import transforms | ||
from . import functional | ||
__all__ = ['datasets', | ||
'transforms', | ||
'functional'] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,19 @@ | ||
from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA | ||
from .text_classification import IMDB | ||
from .new_text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \ | ||
YelpReviewFull, YahooAnswers, \ | ||
AmazonReviewPolarity, AmazonReviewFull | ||
|
||
__all__ = ['LanguageModelingDataset', | ||
'WikiText2', | ||
'WikiText103', | ||
'PennTreebank', | ||
'IMDB'] | ||
'IMDB', | ||
'AG_NEWS', | ||
'SogouNews', | ||
'DBpedia', | ||
'YelpReviewPolarity', | ||
'YelpReviewFull', | ||
'YahooAnswers', | ||
'AmazonReviewPolarity', | ||
'AmazonReviewFull'] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,297 @@ | ||
import torch | ||
import io | ||
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader | ||
from torchtext.vocab import build_vocab_from_iterator | ||
|
||
# Download link of the archive for every supported dataset, keyed by dataset name.
# All links share one Google Drive prefix and differ only in the trailing file id.
URLS = dict(
    (name, 'https://drive.google.com/uc?export=download&id=' + file_id)
    for name, file_id in (
        ('AG_NEWS', '0Bz8a_Dbh9QhbUDNpeUdjb0wxRms'),
        ('SogouNews', '0Bz8a_Dbh9QhbUkVqNEszd0pHaFE'),
        ('DBpedia', '0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k'),
        ('YelpReviewPolarity', '0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg'),
        ('YelpReviewFull', '0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0'),
        ('YahooAnswers', '0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU'),
        ('AmazonReviewPolarity', '0Bz8a_Dbh9QhbaW12WVVZS2drcnM'),
        ('AmazonReviewFull', '0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'),
    )
)
|
||
|
||
def _create_data_from_csv(data_path):
    """Read the CSV file at ``data_path`` into a list of ``(label, text)`` pairs.

    For every row, the first column is kept as the label exactly as the reader
    yields it (no conversion happens here) and all remaining columns are joined
    with a single space into one text string.
    """
    with io.open(data_path, encoding="utf8") as csv_file:
        return [(fields[0], ' '.join(fields[1:]))
                for fields in unicode_csv_reader(csv_file)]
|
||
|
||
def build_vocab(dataset, transform):
    """Build a vocabulary from the texts of a text classification dataset.

    Arguments:
        dataset: a TextClassificationDataset; its ``data`` entries are read
            in the form ``(label, text_string)``.
        transform: callable applied to every raw text string. Its results are
            handed to ``build_vocab_from_iterator`` (presumably each result is a
            sequence of tokens -- confirm against that helper).

    Returns:
        The object returned by ``build_vocab_from_iterator``.

    Raises:
        TypeError: if ``dataset`` is not a TextClassificationDataset.
    """
    # NOTE(review): during review it was asked whether this function should take
    # **kwargs and forward them to the vocab constructor; that is not done here.
    if not isinstance(dataset, TextClassificationDataset):
        raise TypeError('Passed dataset is not TextClassificationDataset')

    # data are saved in the form of (label, text_string).
    # Feed the transformed texts lazily so the tokenized copy of the whole
    # dataset is never held in memory at once.
    token_stream = (transform(entry[1]) for entry in dataset.data)
    return build_vocab_from_iterator(token_stream)
|
||
|
||
class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification dataset.

    Currently, we only support the following datasets:

        - AG_NEWS
        - SogouNews
        - DBpedia
        - YelpReviewPolarity
        - YelpReviewFull
        - YahooAnswers
        - AmazonReviewPolarity
        - AmazonReviewFull
    """

    def __init__(self, data, transforms):
        """Initiate text-classification dataset.

        Arguments:
            data: indexable collection of ``(label, text)`` entries. The raw
                entries are stored untouched; transforms run on access.
            transforms: pair ``(label_transform, text_transform)`` of callables.
                The first is applied to an entry's label, the second to its text.

        Examples:
            >>> data = [('1', 'hello world'), ('2', 'good bye')]
            >>> dataset = TextClassificationDataset(data, (int, str.split))
            >>> dataset[0]
            (1, ['hello', 'world'])
        """

        super(TextClassificationDataset, self).__init__()
        self.data = data
        self.transforms = transforms  # (label_transforms, tokens_transforms)

    def __getitem__(self, i):
        """Return ``(transformed_label, transformed_text)`` for entry ``i``."""
        entry = self.data[i]
        return (self.transforms[0](entry[0]),
                self.transforms[1](entry[1]))

    def __len__(self):
        """Return the number of entries in the dataset."""
        return len(self.data)

    def get_labels(self):
        """Return the set of distinct transformed labels found in the dataset."""
        # Look the label transform up once and build the set directly,
        # without an intermediate list.
        label_transform = self.transforms[0]
        return {label_transform(item[0]) for item in self.data}
|
||
|
||
def _setup_datasets(dataset_name, root='.data', transforms=None):
    """Download a dataset archive and build its train and test datasets.

    The archive registered under ``dataset_name`` in ``URLS`` is fetched with
    ``download_from_url`` and unpacked with ``extract_archive``; the extracted
    files whose names end in ``train.csv`` and ``test.csv`` are then loaded.

    Arguments:
        dataset_name: key of the dataset in ``URLS``.
        root: Directory where the datasets are saved. Default: ".data"
        transforms: pair ``(label_transform, text_transform)`` of callables
            passed to both datasets. Default: convert the label with ``int``
            and leave the text unchanged.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of TextClassificationDataset.

    Raises:
        KeyError: if ``dataset_name`` is not a key of ``URLS``.
        RuntimeError: if the archive contains no ``train.csv`` or ``test.csv``.
    """
    if transforms is None:
        # Build the default per call: a list literal in the signature would be
        # one shared object that every dataset ever created could mutate.
        transforms = [int, lambda x: x]

    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    train_csv_path = None
    test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    # Fail with an explicit message instead of an UnboundLocalError further down.
    if train_csv_path is None or test_csv_path is None:
        raise RuntimeError(
            'Could not find train.csv and test.csv in the archive extracted '
            'for dataset {}'.format(dataset_name))

    train_data = _create_data_from_csv(train_csv_path)
    test_data = _create_data_from_csv(test_csv_path)
    return (TextClassificationDataset(train_data, transforms),
            TextClassificationDataset(test_data, transforms))
|
||
|
||
def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.

    The labels include:

        - 1 : World
        - 2 : Sports
        - 3 : Business
        - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = AG_NEWS(root='.data')
    """

    return _setup_datasets("AG_NEWS", *args, **kwargs)
|
||
|
||
def SogouNews(*args, **kwargs):
    """ Defines SogouNews datasets.

    The labels include:

        - 1 : Sports
        - 2 : Finance
        - 3 : Entertainment
        - 4 : Automobile
        - 5 : Technology

    Create supervised learning dataset: SogouNews

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = SogouNews(root='.data')
    """

    return _setup_datasets("SogouNews", *args, **kwargs)
|
||
|
||
def DBpedia(*args, **kwargs):
    """ Defines DBpedia datasets.

    The labels include:

        - 1 : Company
        - 2 : EducationalInstitution
        - 3 : Artist
        - 4 : Athlete
        - 5 : OfficeHolder
        - 6 : MeanOfTransportation
        - 7 : Building
        - 8 : NaturalPlace
        - 9 : Village
        - 10 : Animal
        - 11 : Plant
        - 12 : Album
        - 13 : Film
        - 14 : WrittenWork

    Create supervised learning dataset: DBpedia

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = DBpedia(root='.data')
    """

    return _setup_datasets("DBpedia", *args, **kwargs)
|
||
|
||
def YelpReviewPolarity(*args, **kwargs):
    """ Defines YelpReviewPolarity datasets.

    The labels include:

        - 1 : Negative polarity.
        - 2 : Positive polarity.

    Create supervised learning dataset: YelpReviewPolarity

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = YelpReviewPolarity(root='.data')
    """

    return _setup_datasets("YelpReviewPolarity", *args, **kwargs)
|
||
|
||
def YelpReviewFull(*args, **kwargs):
    """ Defines YelpReviewFull datasets.

    The labels include:

        1 - 5 : rating classes (5 is highly recommended).

    Create supervised learning dataset: YelpReviewFull

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = YelpReviewFull(root='.data')
    """

    return _setup_datasets("YelpReviewFull", *args, **kwargs)
|
||
|
||
def YahooAnswers(*args, **kwargs):
    """ Defines YahooAnswers datasets.

    The labels include:

        - 1 : Society & Culture
        - 2 : Science & Mathematics
        - 3 : Health
        - 4 : Education & Reference
        - 5 : Computers & Internet
        - 6 : Sports
        - 7 : Business & Finance
        - 8 : Entertainment & Music
        - 9 : Family & Relationships
        - 10 : Politics & Government

    Create supervised learning dataset: YahooAnswers

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = YahooAnswers(root='.data')
    """

    return _setup_datasets("YahooAnswers", *args, **kwargs)
|
||
|
||
def AmazonReviewPolarity(*args, **kwargs):
    """ Defines AmazonReviewPolarity datasets.

    The labels include:

        - 1 : Negative polarity
        - 2 : Positive polarity

    Create supervised learning dataset: AmazonReviewPolarity

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = AmazonReviewPolarity(root='.data')
    """

    return _setup_datasets("AmazonReviewPolarity", *args, **kwargs)
|
||
|
||
def AmazonReviewFull(*args, **kwargs):
    """ Defines AmazonReviewFull datasets.

    The labels include:

        1 - 5 : rating classes (5 is highly recommended)

    Create supervised learning dataset: AmazonReviewFull

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        transforms: Pair of callables (label_transform, text_transform) applied
            to each item. By default the label is converted with int and the
            text is left unchanged.

    Examples:
        >>> train_dataset, test_dataset = AmazonReviewFull(root='.data')
    """

    return _setup_datasets("AmazonReviewFull", *args, **kwargs)
|
||
|
||
# Lookup table from dataset name to the factory function that builds its
# (train, test) dataset pair.
DATASETS = dict(
    AG_NEWS=AG_NEWS,
    SogouNews=SogouNews,
    DBpedia=DBpedia,
    YelpReviewPolarity=YelpReviewPolarity,
    YelpReviewFull=YelpReviewFull,
    YahooAnswers=YahooAnswers,
    AmazonReviewPolarity=AmazonReviewPolarity,
    AmazonReviewFull=AmazonReviewFull,
)
|
||
|
||
# Human-readable class names per dataset. Each inner dict maps the integer label
# (numbered from 1, in the order listed) to its name.
LABELS = {
    'AG_NEWS': dict(enumerate([
        'World',
        'Sports',
        'Business',
        'Sci/Tech',
    ], 1)),
    'SogouNews': dict(enumerate([
        'Sports',
        'Finance',
        'Entertainment',
        'Automobile',
        'Technology',
    ], 1)),
    'DBpedia': dict(enumerate([
        'Company',
        'EducationalInstitution',
        'Artist',
        'Athlete',
        'OfficeHolder',
        'MeanOfTransportation',
        'Building',
        'NaturalPlace',
        'Village',
        'Animal',
        'Plant',
        'Album',
        'Film',
        'WrittenWork',
    ], 1)),
    'YelpReviewPolarity': dict(enumerate([
        'Negative polarity',
        'Positive polarity',
    ], 1)),
    'YelpReviewFull': dict(enumerate([
        'score 1',
        'score 2',
        'score 3',
        'score 4',
        'score 5',
    ], 1)),
    'YahooAnswers': dict(enumerate([
        'Society & Culture',
        'Science & Mathematics',
        'Health',
        'Education & Reference',
        'Computers & Internet',
        'Sports',
        'Business & Finance',
        'Entertainment & Music',
        'Family & Relationships',
        'Politics & Government',
    ], 1)),
    'AmazonReviewPolarity': dict(enumerate([
        'Negative polarity',
        'Positive polarity',
    ], 1)),
    'AmazonReviewFull': dict(enumerate([
        'score 1',
        'score 2',
        'score 3',
        'score 4',
        'score 5',
    ], 1)),
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
def vocab_transform(vocab, tok):
    # type: (Dict[str, int], str) -> int
    """Look up ``tok`` in ``vocab`` and return the value stored for it."""
    index = vocab[tok]
    return index
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: extra space