Skip to content

Commit 85bc5fd

Browse files
committed
use decompressor for extracting bz2
1 parent f7d0a50 commit 85bc5fd

File tree

1 file changed

+15
-26
lines changed
  • torchvision/prototype/datasets/_builtin

1 file changed

+15
-26
lines changed
Lines changed: 15 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,13 @@
1-
import bz2
2-
import functools
3-
from typing import Any, Dict, List, Tuple, BinaryIO, Iterator
1+
from typing import Any, Dict, List, Tuple
42

53
import numpy as np
64
import torch
7-
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper, LineReader, Mapper
5+
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
86
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
97
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
108
from torchvision.prototype.features import Image, Label
119

1210

13-
class USPSFileReader(IterDataPipe[torch.Tensor]):
14-
def __init__(self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]]) -> None:
15-
self.datapipe = datapipe
16-
17-
def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
18-
for path, _ in self.datapipe:
19-
with bz2.open(path) as fp:
20-
datapipe = IterableWrapper([(path, fp)])
21-
line_reader = LineReader(datapipe, decode=True)
22-
for _, line in line_reader:
23-
raw_data = line.split()
24-
tmp_list = [x.split(":")[-1] for x in raw_data[1:]]
25-
img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
26-
img = ((img + 1) / 2 * 255).astype(dtype=np.uint8)
27-
target = int(raw_data[0]) - 1
28-
yield torch.from_numpy(img), torch.tensor(target)
29-
30-
3111
class USPS(Dataset):
3212
def _make_info(self) -> DatasetInfo:
3313
return DatasetInfo(
@@ -54,10 +34,18 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5434
return [USPS._RESOURCES[config.split]]
5535

5636
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
57-
image, label = data
37+
_filename, line = data
38+
39+
raw_data = line.split()
40+
tmp_list = [x.split(":")[-1] for x in raw_data[1:]]
41+
img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
42+
img = ((img + 1) / 2 * 255).astype(dtype=np.uint8)
43+
img = torch.from_numpy(img)
44+
target = int(raw_data[0]) - 1
45+
5846
return dict(
59-
image=Image(image),
60-
label=Label(label, dtype=torch.int64, categories=self.categories),
47+
image=Image(img),
48+
label=Label(target, dtype=torch.int64, categories=self.categories),
6149
)
6250

6351
def _make_datapipe(
@@ -66,7 +54,8 @@ def _make_datapipe(
6654
*,
6755
config: DatasetConfig,
6856
) -> IterDataPipe[Dict[str, Any]]:
69-
dp = USPSFileReader(resource_dps[0])
57+
dp = Decompressor(resource_dps[0])
58+
dp = LineReader(dp, decode=True)
7059
dp = hint_sharding(dp)
7160
dp = hint_shuffling(dp)
7261
return Mapper(dp, self._prepare_sample)

0 commit comments

Comments
 (0)