1- import bz2
2- import functools
3- from typing import Any , Dict , List , Tuple , BinaryIO , Iterator
1+ from typing import Any , Dict , List , Tuple
42
53import numpy as np
64import torch
7- from torchdata .datapipes .iter import IterDataPipe , IterableWrapper , LineReader , Mapper
5+ from torchdata .datapipes .iter import IterDataPipe , LineReader , Mapper , Decompressor
86from torchvision .prototype .datasets .utils import Dataset , DatasetInfo , DatasetConfig , OnlineResource , HttpResource
97from torchvision .prototype .datasets .utils ._internal import hint_sharding , hint_shuffling
108from torchvision .prototype .features import Image , Label
119
1210
13- class USPSFileReader (IterDataPipe [torch .Tensor ]):
14- def __init__ (self , datapipe : IterDataPipe [Tuple [Any , BinaryIO ]]) -> None :
15- self .datapipe = datapipe
16-
17- def __iter__ (self ) -> Iterator [Tuple [torch .Tensor , torch .Tensor ]]:
18- for path , _ in self .datapipe :
19- with bz2 .open (path ) as fp :
20- datapipe = IterableWrapper ([(path , fp )])
21- line_reader = LineReader (datapipe , decode = True )
22- for _ , line in line_reader :
23- raw_data = line .split ()
24- tmp_list = [x .split (":" )[- 1 ] for x in raw_data [1 :]]
25- img = np .asarray (tmp_list , dtype = np .float32 ).reshape ((- 1 , 16 , 16 ))
26- img = ((img + 1 ) / 2 * 255 ).astype (dtype = np .uint8 )
27- target = int (raw_data [0 ]) - 1
28- yield torch .from_numpy (img ), torch .tensor (target )
29-
30-
3111class USPS (Dataset ):
3212 def _make_info (self ) -> DatasetInfo :
3313 return DatasetInfo (
@@ -54,10 +34,18 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5434 return [USPS ._RESOURCES [config .split ]]
5535
5636 def _prepare_sample (self , data : Tuple [torch .Tensor , torch .Tensor ]) -> Dict [str , Any ]:
57- image , label = data
37+ _filename , line = data
38+
39+ raw_data = line .split ()
40+ tmp_list = [x .split (":" )[- 1 ] for x in raw_data [1 :]]
41+ img = np .asarray (tmp_list , dtype = np .float32 ).reshape ((- 1 , 16 , 16 ))
42+ img = ((img + 1 ) / 2 * 255 ).astype (dtype = np .uint8 )
43+ img = torch .from_numpy (img )
44+ target = int (raw_data [0 ]) - 1
45+
5846 return dict (
59- image = Image (image ),
60- label = Label (label , dtype = torch .int64 , categories = self .categories ),
47+ image = Image (img ),
48+ label = Label (target , dtype = torch .int64 , categories = self .categories ),
6149 )
6250
6351 def _make_datapipe (
@@ -66,7 +54,8 @@ def _make_datapipe(
6654 * ,
6755 config : DatasetConfig ,
6856 ) -> IterDataPipe [Dict [str , Any ]]:
69- dp = USPSFileReader (resource_dps [0 ])
57+ dp = Decompressor (resource_dps [0 ])
58+ dp = LineReader (dp , decode = True )
7059 dp = hint_sharding (dp )
7160 dp = hint_shuffling (dp )
7261 return Mapper (dp , self ._prepare_sample )
0 commit comments