1
- import bz2
2
- import functools
3
- from typing import Any , Dict , List , Tuple , BinaryIO , Iterator
1
+ from typing import Any , Dict , List , Tuple
4
2
5
3
import numpy as np
6
4
import torch
7
- from torchdata .datapipes .iter import IterDataPipe , IterableWrapper , LineReader , Mapper
5
+ from torchdata .datapipes .iter import IterDataPipe , LineReader , Mapper , Decompressor
8
6
from torchvision .prototype .datasets .utils import Dataset , DatasetInfo , DatasetConfig , OnlineResource , HttpResource
9
7
from torchvision .prototype .datasets .utils ._internal import hint_sharding , hint_shuffling
10
8
from torchvision .prototype .features import Image , Label
11
9
12
10
13
- class USPSFileReader (IterDataPipe [torch .Tensor ]):
14
- def __init__ (self , datapipe : IterDataPipe [Tuple [Any , BinaryIO ]]) -> None :
15
- self .datapipe = datapipe
16
-
17
- def __iter__ (self ) -> Iterator [Tuple [torch .Tensor , torch .Tensor ]]:
18
- for path , _ in self .datapipe :
19
- with bz2 .open (path ) as fp :
20
- datapipe = IterableWrapper ([(path , fp )])
21
- line_reader = LineReader (datapipe , decode = True )
22
- for _ , line in line_reader :
23
- raw_data = line .split ()
24
- tmp_list = [x .split (":" )[- 1 ] for x in raw_data [1 :]]
25
- img = np .asarray (tmp_list , dtype = np .float32 ).reshape ((- 1 , 16 , 16 ))
26
- img = ((img + 1 ) / 2 * 255 ).astype (dtype = np .uint8 )
27
- target = int (raw_data [0 ]) - 1
28
- yield torch .from_numpy (img ), torch .tensor (target )
29
-
30
-
31
11
class USPS (Dataset ):
32
12
def _make_info (self ) -> DatasetInfo :
33
13
return DatasetInfo (
@@ -54,10 +34,18 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
54
34
return [USPS ._RESOURCES [config .split ]]
55
35
56
36
def _prepare_sample(self, data: Tuple[str, str]) -> Dict[str, Any]:
    """Turn one decoded text line of the USPS file into a sample dict.

    Args:
        data: A ``(filename, line)`` pair. The filename is ignored; ``line``
            is a whitespace-separated record whose first token is the
            integer class label and whose remaining tokens each end in
            ``:<value>``.

    Returns:
        A dict with the keys ``image`` (an ``Image`` wrapping a ``uint8``
        tensor reshaped to ``(-1, 16, 16)``) and ``label`` (a ``Label``
        built from the label token minus one).

    Raises:
        ValueError: If the label or a pixel value cannot be parsed as a
            number, or the pixel count is not a multiple of ``16 * 16``.
        IndexError: If ``line`` contains no tokens at all.
    """
    _filename, line = data

    tokens = line.split()
    # Keep only the text after the last ":" of every feature token; the
    # part in front of it is not used.
    pixel_strings = [token.split(":")[-1] for token in tokens[1:]]
    pixels = np.asarray(pixel_strings, dtype=np.float32).reshape((-1, 16, 16))
    # NOTE(review): this rescaling presumably expects values in [-1, 1] and
    # maps them onto [0, 255] -- confirm against the data file.
    pixels = ((pixels + 1) / 2 * 255).astype(dtype=np.uint8)
    image = torch.from_numpy(pixels)
    # The label token is shifted down by one (presumably 1-based in the
    # file) before being handed to ``Label``.
    target = int(tokens[0]) - 1

    return dict(
        image=Image(image),
        label=Label(target, dtype=torch.int64, categories=self.categories),
    )
62
50
63
51
def _make_datapipe (
@@ -66,7 +54,8 @@ def _make_datapipe(
66
54
* ,
67
55
config : DatasetConfig ,
68
56
) -> IterDataPipe [Dict [str , Any ]]:
69
- dp = USPSFileReader (resource_dps [0 ])
57
+ dp = Decompressor (resource_dps [0 ])
58
+ dp = LineReader (dp , decode = True )
70
59
dp = hint_sharding (dp )
71
60
dp = hint_shuffling (dp )
72
61
return Mapper (dp , self ._prepare_sample )
0 commit comments