|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +import logging |
| 3 | +import os |
| 4 | + |
| 5 | +from torchtext._internal.module_utils import is_module_available |
| 6 | +from torchtext.data.datasets_utils import ( |
| 7 | + _add_docstring_header, |
| 8 | + _create_dataset_directory, |
| 9 | + _wrap_split_argument, |
| 10 | +) |
| 11 | + |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
# torchdata is an optional dependency: import the datapipes used below only
# when it is installed, otherwise warn at import time instead of raising.
if is_module_available("torchdata"):
    from torchdata.datapipes.iter import (
        HttpReader,
        IterableWrapper,
    )
else:
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing space to keep the sentences separated.
    logger.warning(
        "Package `torchdata` is required to be installed to use this dataset. "
        "Please refer to https://github.com/pytorch/data for instructions on "
        "how to install the package."
    )
| 25 | + |
| 26 | + |
# Line count per split; passed to the docstring decorator on ``SST2``.
NUM_LINES = {"train": 67349, "dev": 872, "test": 1821}

# Download location of the zipped dataset and an MD5 digest
# (presumably of that archive -- it is not checked in the code shown here).
MD5 = "9f81648d4199384278b86e315dac217c"
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"

# Relative path of each split's TSV file, built with the platform separator.
_EXTRACTED_FILES = {
    split: os.path.join("SST-2", split + ".tsv")
    for split in ("train", "dev", "test")
}

# MD5 digests keyed by split (presumably of the extracted TSV files).
_EXTRACTED_FILES_MD5 = dict(
    train="da409a0a939379ed32a470bc0f7fe99a",
    dev="268856b487b2a31a28c0a93daaff7288",
    test="3230e4efec76488b87877a56ae49675a",
)

# MD5 digests keyed by split (presumably of the first line of each TSV file).
_FIRST_LINE_MD5 = dict(
    train="2552b8cecd57b2e022ef23411c688fa8",
    dev="1b0ffd6aa5f2bf0fd9840a5f6f1a9f07",
    test="f838c81fe40bfcd7e42e9ffc4dd004f7",
)

# Name handed to the dataset-directory decorator on ``SST2``.
DATASET_NAME = "SST2"
| 55 | + |
| 56 | + |
# Public entry point for the SST-2 dataset. The decorators come from
# torchtext.data.datasets_utils; judging by their names they build the
# docstring, create the dataset directory and validate ``split`` -- confirm
# against their definitions. No docstring is written here on purpose, so the
# decorator-provided one is left untouched.
@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    # Pipeline construction is delegated to the helper class.
    dataset = SST2Dataset(root, split)
    return dataset.get_datapipe()
| 62 | + |
| 63 | + |
class SST2Dataset:
    """Build the torchdata datapipe for one split of the SST-2 dataset.

    The archive at ``URL`` is cached on disk under ``root`` so it is not
    downloaded again at every epoch; the requested split is then read
    directly from the zip archive and parsed as tab-separated values.

    NOTE(review): this class performs no checksum validation itself, although
    MD5 constants are defined at module level.

    Args:
        root: Directory in which the downloaded archive is cached.
        split: Split name; selects the archive member named ``<split>.tsv``.
    """

    def __init__(self, root, split):
        self.root = root
        self.split = split

    def _cache_path(self, url):
        """Return the on-disk cache location for the file downloaded from *url*."""
        return os.path.join(self.root, os.path.basename(url))

    @staticmethod
    def _read_stream(item):
        """Turn a ``(path, stream)`` pair into ``(path, stream contents)``."""
        return item[0], item[1].read()

    def _is_split_file(self, item):
        """Return True when the ``(path, stream)`` item is this split's TSV file.

        Only the file name is compared. A substring test on the whole path
        would also match directory components, so a ``root`` such as
        ``/data/test`` would wrongly select every archive member.
        """
        return os.path.basename(item[0]) == f"{self.split}.tsv"

    @staticmethod
    def _select_columns(row):
        """Keep the first two columns of a parsed TSV row as a tuple."""
        return row[0], row[1]

    def get_datapipe(self):
        """Return a datapipe yielding 2-tuples built from the split's TSV rows.

        Named callables are used instead of lambdas because lambdas cannot be
        pickled by the standard ``pickle`` module.
        """
        # cache data on-disk
        cache_dp = IterableWrapper([URL]).on_disk_cache(
            HttpReader,
            op_map=self._read_stream,
            filepath_fn=self._cache_path,
        )

        # extract data from zip
        extracted_files = cache_dp.read_from_zip()

        # keep only the requested split, skip the header row, parse the TSV
        split_file = extracted_files.filter(self._is_split_file)
        rows = split_file.parse_csv(skip_lines=1, delimiter="\t")
        return rows.map(self._select_columns)
0 commit comments