Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 0930843

Browse files
Nayef211 and nayef211 authored
Added new SST2 dataset class (#1410)
* added new SST2 dataset class based on sst2 functional dataset in torchdata * Reset sp submodule to previous commit * Updated function name. Added torchdata as a dep * Added torchdata as a dep to setup.py * Updated unit test to check hash of first line in dataset * Fixed dependency_link url for torchdata * Added torchdata install to circleci config * Updated commit id for torchdata install. Specified torchdata as an optional dependency * Removed additional hash checks during dataset construction * Removed new line from config.yml * Removed changes from config.yml, requirements.txt, and setup.py. Updated unittests to be skipped if module is not available * Incorporated review feedback * Added torchdata installation for unittests * Removed newline changes Co-authored-by: nayef211 <[email protected]>
1 parent 7c5f083 commit 0930843

File tree

10 files changed

+152
-3
lines changed

10 files changed

+152
-3
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ jobs:
497497
- v1-windows-dataset-vector-{{ checksum ".cachekey" }}
498498
- v1-windows-dataset-{{ checksum ".cachekey" }}
499499

500-
500+
501501
- run:
502502
name: Run tests
503503
# Downloading embedding vector takes long time.

.circleci/config.yml.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ jobs:
497497
- v1-windows-dataset-vector-{{ checksum ".cachekey" }}
498498
- v1-windows-dataset-{{ checksum ".cachekey" }}
499499
{% endraw %}
500-
500+
501501
- run:
502502
name: Run tests
503503
# Downloading embedding vector takes long time.

.circleci/unittest/linux/scripts/install.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ conda activate ./env
1313
printf "* Installing PyTorch\n"
1414
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly
1515

16+
printf "Installing torchdata from source\n"
17+
pip install git+https://github.com/pytorch/data.git
18+
1619
printf "* Installing torchtext\n"
1720
git submodule update --init --recursive
1821
python setup.py develop

.circleci/unittest/windows/scripts/install.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ conda activate ./env
1818
printf "* Installing PyTorch\n"
1919
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly
2020

21+
printf "Installing torchdata from source\n"
22+
pip install git+https://github.com/pytorch/data.git
23+
2124
printf "* Installing torchtext\n"
2225
git submodule update --init --recursive
2326
"$root_dir/packaging/vc_env_helper.bat" python setup.py develop

test/common/case_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import unittest
2+
from torchtext._internal.module_utils import is_module_available
3+
4+
5+
def skipIfNoModule(module, display_name=None):
    """Decorator that skips a test when ``module`` cannot be imported.

    Args:
        module: top-level module name whose availability gates the test.
        display_name: optional human-readable name used in the skip message;
            falls back to ``module`` when not given (or falsy).
    """
    shown_name = display_name or module
    unavailable = not is_module_available(module)
    return unittest.skipIf(unavailable, f'"{shown_name}" is not available')

test/experimental/test_datasets.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import hashlib
2+
import json
3+
4+
from torchtext.experimental.datasets import sst2
5+
6+
from ..common.case_utils import skipIfNoModule
7+
from ..common.torchtext_test_case import TorchtextTestCase
8+
9+
10+
class TestDataset(TorchtextTestCase):
    @skipIfNoModule("torchdata")
    def test_sst2_dataset(self):
        """Smoke-test the SST2 dataset: the first sample of every split must
        hash to the known MD5 recorded in ``sst2._FIRST_LINE_MD5``.
        """
        split = ("train", "dev", "test")
        datapipes = sst2.SST2(split=split)

        # Verify the hash of the first line of each split's datapipe.
        # (One loop instead of three copy-pasted assertion stanzas.)
        for split_name, dp in zip(split, datapipes):
            first_sample = next(iter(dp))
            digest = hashlib.md5(
                json.dumps(first_sample, sort_keys=True).encode("utf-8")
            ).hexdigest()
            self.assertEqual(digest, sst2._FIRST_LINE_MD5[split_name])

torchtext/_internal/__init__.py

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import importlib.util
2+
3+
4+
def is_module_available(*modules: str) -> bool:
    r"""Return whether every given top-level module exists, *without*
    importing it.

    This is generally safer than a try/except block around ``import X``:
    it avoids third-party libraries breaking assumptions of some of our
    tests, e.g. setting the multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).

    Args:
        *modules: top-level module names to look up.

    Returns:
        True if a spec is found for every named module, False otherwise
        (vacuously True when called with no arguments).
    """
    return all(importlib.util.find_spec(m) is not None for m in modules)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import raw
2+
from . import sst2
23

3-
__all__ = ['raw']
4+
__all__ = ["raw", "sst2"]
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import logging
3+
import os
4+
5+
from torchtext._internal.module_utils import is_module_available
6+
from torchtext.data.datasets_utils import (
7+
_add_docstring_header,
8+
_create_dataset_directory,
9+
_wrap_split_argument,
10+
)
11+
12+
# Module-level logger for import-time diagnostics.
logger = logging.getLogger(__name__)

# `torchdata` is an optional dependency: import its datapipes when present,
# otherwise warn at import time (any later use of the missing names will
# fail naturally with a NameError).
if is_module_available("torchdata"):
    from torchdata.datapipes.iter import (
        HttpReader,
        IterableWrapper,
    )
else:
    # NOTE: trailing space added after "dataset." — the implicit string
    # concatenation previously produced "dataset.Please refer".
    logger.warning(
        "Package `torchdata` is required to be installed to use this dataset. "
        "Please refer to https://github.com/pytorch/data for instructions on "
        "how to install the package."
    )


# Number of samples in each SST-2 split (fed to _add_docstring_header below).
NUM_LINES = {
    "train": 67349,
    "dev": 872,
    "test": 1821,
}

# Checksum of the SST-2.zip download at URL (not verified in this module —
# presumably consumed by shared dataset utilities; confirm against callers).
MD5 = "9f81648d4199384278b86e315dac217c"
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"

# Per-split TSV paths relative to the extraction root.
# os.path.join replaces the previous f"{os.sep}".join round-trip.
_EXTRACTED_FILES = {
    "train": os.path.join("SST-2", "train.tsv"),
    "dev": os.path.join("SST-2", "dev.tsv"),
    "test": os.path.join("SST-2", "test.tsv"),
}

# Checksums of the extracted per-split TSV files (not verified here).
_EXTRACTED_FILES_MD5 = {
    "train": "da409a0a939379ed32a470bc0f7fe99a",
    "dev": "268856b487b2a31a28c0a93daaff7288",
    "test": "3230e4efec76488b87877a56ae49675a",
}

# MD5 of the JSON-serialized first sample of each split; checked by the
# unit tests (test/experimental/test_datasets.py).
_FIRST_LINE_MD5 = {
    "train": "2552b8cecd57b2e022ef23411c688fa8",
    "dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07",
    "test": "f838c81fe40bfcd7e42e9ffc4dd004f7",
}

DATASET_NAME = "SST2"
55+
56+
57+
@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    # Thin public entry point: delegate datapipe construction to SST2Dataset.
    # `root` is the directory used for the on-disk download cache;
    # `split` is "train"/"dev"/"test" — presumably normalized (string or
    # tuple of strings) by _wrap_split_argument; confirm in datasets_utils.
    return SST2Dataset(root, split).get_datapipe()
62+
63+
64+
class SST2Dataset:
    """End-to-end torchdata-datapipe implementation of the SST2 dataset.

    The zip download is cached on disk (so repeated epochs do not re-download),
    the archive is read in place, and the split's TSV is parsed into
    (sentence, label) samples.
    """

    def __init__(self, root, split):
        # Directory for the cached download, and the split name to serve.
        self.root = root
        self.split = split

    def get_datapipe(self):
        # Wrap the single download URL, caching the HTTP response bytes on
        # disk under `self.root` keyed by the URL's basename.
        url_dp = IterableWrapper([URL])
        cached_dp = url_dp.on_disk_cache(
            HttpReader,
            op_map=lambda x: (x[0], x[1].read()),
            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
        )

        # Read member files out of the cached zip archive.
        archive_members = cached_dp.read_from_zip()

        # Keep only the member whose path contains this split's name,
        # parse its tab-separated rows (skipping the header line), and
        # emit (first column, second column) pairs.
        split_members = archive_members.filter(lambda item: self.split in item[0])
        rows = split_members.parse_csv(skip_lines=1, delimiter="\t")
        return rows.map(lambda row: (row[0], row[1]))

0 commit comments

Comments
 (0)