Skip to content

Text classification datasets with new torchtext dataset abstraction #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 48 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
9d8aa61
new dataset design
Feb 27, 2020
1031efb
remove doc
Feb 28, 2020
371c23e
minor
Feb 28, 2020
29e5d0a
Merge remote-tracking branch 'upstream/master' into new_dataset_design
Mar 2, 2020
8187b84
revise build_vocab func in torchtext.experimental.datasets.new_text_c…
Mar 2, 2020
6272f57
flake8
Mar 2, 2020
15e2a94
docs
Mar 2, 2020
6611a09
switch transforms to torch.nn.Module
Mar 3, 2020
7c31d2e
add default None to tokenizer_name in TokenizerTransform
Mar 3, 2020
c5d487a
jit support for Dict[str, int] vocab in VocabTransform
Mar 3, 2020
7c9c969
remove F821
Mar 3, 2020
4689b29
add functional.py file
Mar 3, 2020
fe90c51
minor fix to have split tokenizer scriptable
Mar 3, 2020
d9ef2ee
add functional.py file
Mar 3, 2020
e98ae46
add a wrapper to support one-command data loading
Mar 12, 2020
8ce6779
add raw file
Mar 12, 2020
8291ebc
flake8
Mar 13, 2020
1864e7d
update raw text classification dataset docs
Mar 19, 2020
64cbde6
minor docs
Mar 19, 2020
51d1b8e
add ngrams
Mar 19, 2020
855e701
add label transform
Mar 19, 2020
a955579
combine imdb and text classification datasets
Mar 20, 2020
fa3565b
add more attributes to dataset API
Mar 20, 2020
94870df
update text classification datasets docs
Mar 20, 2020
74f50b6
remove two transforms
Mar 20, 2020
55e4848
add get_vocab in text_classification
Mar 20, 2020
db66774
minor fix
Mar 20, 2020
5a20115
Add TextSequential
Mar 20, 2020
650928a
switch text classification to TextSequential
Mar 20, 2020
2447837
fix flake8 error
Mar 23, 2020
be20884
add vocab to dataset
Mar 23, 2020
a6bc30a
add docs strings for transforms.
Mar 23, 2020
3821282
move raw datasets to a separate folder
Mar 23, 2020
9b97ac2
.flake8 file
Mar 23, 2020
b565565
move raw text folder
Mar 23, 2020
ebe87f7
move transforms to experimental.datasets.text_classification
Mar 23, 2020
e382503
Fix IMDB
Mar 23, 2020
c711c34
remove some transforms in experimental text classification
Apr 1, 2020
9a0c3ac
switch raw dataset to iterable style
Apr 9, 2020
c6f6a42
add sequential_transforms
Apr 9, 2020
f1d394c
Merge branch 'master' into new_dataset_design
Apr 9, 2020
7404519
add get_iterator func
Apr 9, 2020
bc2c83a
flake8
Apr 9, 2020
644b759
support partial cache for raw text classification dataset
Apr 9, 2020
2f93dec
Merge branch 'master' into new_dataset_design
Apr 13, 2020
aa15019
change None arguments
Apr 14, 2020
793349c
change import raw path
Apr 14, 2020
33053e8
Merge branch 'master' into new_dataset_design
zhangguanheng66 Apr 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
ignore = E402,E722,W503,W504
ignore = E402,E722,W503,W504, F821
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: extra space

max-line-length = 90
exclude = docs/source
Empty file modified examples/vocab/vocab.py
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions torchtext/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


def _split_tokenizer(x):
# type: (str) -> List[str]
return x.split()


Expand Down
7 changes: 5 additions & 2 deletions torchtext/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from . import datasets

__all__ = ['datasets']
from . import transforms
from . import functional
__all__ = ['datasets',
'transforms',
'functional']
13 changes: 12 additions & 1 deletion torchtext/experimental/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA
from .text_classification import IMDB
from .new_text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
YelpReviewFull, YahooAnswers, \
AmazonReviewPolarity, AmazonReviewFull

__all__ = ['LanguageModelingDataset',
'WikiText2',
'WikiText103',
'PennTreebank',
'IMDB']
'IMDB',
'AG_NEWS',
'SogouNews',
'DBpedia',
'YelpReviewPolarity',
'YelpReviewFull',
'YahooAnswers',
'AmazonReviewPolarity',
'AmazonReviewFull']
297 changes: 297 additions & 0 deletions torchtext/experimental/datasets/new_text_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
import torch
import io
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
from torchtext.vocab import build_vocab_from_iterator

# Direct-download links (Google Drive) for the archive of each supported
# text-classification dataset.
URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}


def _create_data_from_csv(data_path):
    """Read a label/text CSV file into a list of (label, text) tuples.

    The first column of each row is kept as the label; all remaining
    columns are joined with a single space to form the text string.
    """
    with io.open(data_path, encoding="utf8") as f:
        return [(row[0], ' '.join(row[1:]))
                for row in unicode_csv_reader(f)]


def build_vocab(dataset, transform):
    """Build a vocabulary from the raw text of a TextClassificationDataset.

    Arguments:
        dataset: a TextClassificationDataset instance; its ``data``
            attribute holds (label, text_string) tuples.
        transform: a callable mapping a text string to a list of tokens.

    Raises:
        TypeError: if ``dataset`` is not a TextClassificationDataset.
    """
    if not isinstance(dataset, TextClassificationDataset):
        raise TypeError('Passed dataset is not TextClassificationDataset')

    # data are saved in the form of (label, text_string).  Feed the token
    # lists to build_vocab_from_iterator lazily instead of materializing
    # every tokenized sequence in memory at once.
    return build_vocab_from_iterator(transform(seq[1]) for seq in dataset.data)


class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification dataset.
    Currently, we only support the following datasets:
        - AG_NEWS
        - SogouNews
        - DBpedia
        - YelpReviewPolarity
        - YelpReviewFull
        - YahooAnswers
        - AmazonReviewPolarity
        - AmazonReviewFull
    """

    def __init__(self, data, transforms):
        """Initiate text-classification dataset.
        Arguments:
            data: a list of (label, text_string) tuples.
            transforms: a (label_transform, text_transform) pair; both are
                applied lazily in ``__getitem__``.
        Examples:
            >>> data = [('1', 'a b'), ('2', 'c')]
            >>> dataset = TextClassificationDataset(data, (int, str.split))
        """

        super(TextClassificationDataset, self).__init__()
        self.data = data
        self.transforms = transforms  # (label_transforms, tokens_transforms)

    def __getitem__(self, i):
        # Transforms run on access, so the raw data stays untouched.
        return (self.transforms[0](self.data[i][0]),
                self.transforms[1](self.data[i][1]))

    def __len__(self):
        return len(self.data)

    def get_labels(self):
        """Return the set of transformed labels present in the dataset."""
        return {self.transforms[0](item[0]) for item in self.data}


def _setup_datasets(dataset_name, root='.data', transforms=(int, lambda x: x)):
    """Download, extract, and wrap one of the supported datasets.

    Arguments:
        dataset_name: key into URLS identifying the dataset archive.
        root: directory where the downloaded archive is saved. Default: ".data"
        transforms: a (label_transform, text_transform) pair forwarded to
            TextClassificationDataset. Default converts the label to int and
            leaves the text string unchanged. Note: a tuple default (rather
            than a mutable list) avoids the shared mutable-default pitfall.

    Returns:
        A (train_dataset, test_dataset) tuple of TextClassificationDataset.

    Raises:
        RuntimeError: if the archive does not contain train.csv and test.csv.
    """
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    train_csv_path = None
    test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname
    # Fail with a clear message instead of an UnboundLocalError below.
    if train_csv_path is None or test_csv_path is None:
        raise RuntimeError(
            'Could not find train.csv/test.csv in the {} archive'.format(dataset_name))

    train_data = _create_data_from_csv(train_csv_path)
    test_data = _create_data_from_csv(test_csv_path)
    return (TextClassificationDataset(train_data, transforms),
            TextClassificationDataset(test_data, transforms))


def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
    The labels include:
        - 1 : World
        - 2 : Sports
        - 3 : Business
        - 4 : Sci/Tech
    Create supervised learning dataset: AG_NEWS
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = AG_NEWS()
    """

    return _setup_datasets("AG_NEWS", *args, **kwargs)


def SogouNews(*args, **kwargs):
    """ Defines SogouNews datasets.
    The labels include:
        - 1 : Sports
        - 2 : Finance
        - 3 : Entertainment
        - 4 : Automobile
        - 5 : Technology
    Create supervised learning dataset: SogouNews
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = SogouNews()
    """

    return _setup_datasets("SogouNews", *args, **kwargs)


def DBpedia(*args, **kwargs):
    """ Defines DBpedia datasets.
    The labels include:
        - 1 : Company
        - 2 : EducationalInstitution
        - 3 : Artist
        - 4 : Athlete
        - 5 : OfficeHolder
        - 6 : MeanOfTransportation
        - 7 : Building
        - 8 : NaturalPlace
        - 9 : Village
        - 10 : Animal
        - 11 : Plant
        - 12 : Album
        - 13 : Film
        - 14 : WrittenWork
    Create supervised learning dataset: DBpedia
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = DBpedia()
    """

    return _setup_datasets("DBpedia", *args, **kwargs)


def YelpReviewPolarity(*args, **kwargs):
    """ Defines YelpReviewPolarity datasets.
    The labels include:
        - 1 : Negative polarity.
        - 2 : Positive polarity.
    Create supervised learning dataset: YelpReviewPolarity
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = YelpReviewPolarity()
    """

    return _setup_datasets("YelpReviewPolarity", *args, **kwargs)


def YelpReviewFull(*args, **kwargs):
    """ Defines YelpReviewFull datasets.
    The labels include:
        1 - 5 : rating classes (5 is highly recommended).
    Create supervised learning dataset: YelpReviewFull
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = YelpReviewFull()
    """

    return _setup_datasets("YelpReviewFull", *args, **kwargs)


def YahooAnswers(*args, **kwargs):
    """ Defines YahooAnswers datasets.
    The labels include:
        - 1 : Society & Culture
        - 2 : Science & Mathematics
        - 3 : Health
        - 4 : Education & Reference
        - 5 : Computers & Internet
        - 6 : Sports
        - 7 : Business & Finance
        - 8 : Entertainment & Music
        - 9 : Family & Relationships
        - 10 : Politics & Government
    Create supervised learning dataset: YahooAnswers
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = YahooAnswers()
    """

    return _setup_datasets("YahooAnswers", *args, **kwargs)


def AmazonReviewPolarity(*args, **kwargs):
    """ Defines AmazonReviewPolarity datasets.
    The labels include:
        - 1 : Negative polarity
        - 2 : Positive polarity
    Create supervised learning dataset: AmazonReviewPolarity
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = AmazonReviewPolarity()
    """

    return _setup_datasets("AmazonReviewPolarity", *args, **kwargs)


def AmazonReviewFull(*args, **kwargs):
    """ Defines AmazonReviewFull datasets.
    The labels include:
        1 - 5 : rating classes (5 is highly recommended)
    Create supervised learning dataset: AmazonReviewFull
    Separately returns the training and test dataset
    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
    Examples:
        >>> train, test = AmazonReviewFull()
    """

    return _setup_datasets("AmazonReviewFull", *args, **kwargs)


# Registry mapping each dataset name to its constructor function, so a
# dataset can be selected by string key.
DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}


# Human-readable names for the integer class labels of each dataset.
LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
    'SogouNews': {1: 'Sports',
                  2: 'Finance',
                  3: 'Entertainment',
                  4: 'Automobile',
                  5: 'Technology'},
    'DBpedia': {1: 'Company',
                2: 'EducationalInstitution',
                3: 'Artist',
                4: 'Athlete',
                5: 'OfficeHolder',
                6: 'MeanOfTransportation',
                7: 'Building',
                8: 'NaturalPlace',
                9: 'Village',
                10: 'Animal',
                11: 'Plant',
                12: 'Album',
                13: 'Film',
                14: 'WrittenWork'},
    'YelpReviewPolarity': {1: 'Negative polarity',
                           2: 'Positive polarity'},
    'YelpReviewFull': {1: 'score 1',
                       2: 'score 2',
                       3: 'score 3',
                       4: 'score 4',
                       5: 'score 5'},
    'YahooAnswers': {1: 'Society & Culture',
                     2: 'Science & Mathematics',
                     3: 'Health',
                     4: 'Education & Reference',
                     5: 'Computers & Internet',
                     6: 'Sports',
                     7: 'Business & Finance',
                     8: 'Entertainment & Music',
                     9: 'Family & Relationships',
                     10: 'Politics & Government'},
    'AmazonReviewPolarity': {1: 'Negative polarity',
                             2: 'Positive polarity'},
    'AmazonReviewFull': {1: 'score 1',
                         2: 'score 2',
                         3: 'score 3',
                         4: 'score 4',
                         5: 'score 5'}
}
3 changes: 3 additions & 0 deletions torchtext/experimental/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def vocab_transform(vocab, tok):
    # type: (Dict[str, int], str) -> int
    """Return the integer id assigned to token ``tok`` by ``vocab``.

    Raises KeyError if ``tok`` is not present in ``vocab``.
    """
    token_id = vocab[tok]
    return token_id
Loading