)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

-BASE_DATASET_INPUT = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


def check_valid_data(data: Any) -> None:
-    if not (hasattr(data, '__getitem__') and hasattr(data, '__len__')):
+    if not all(hasattr(data, attr) for attr in ['__getitem__', '__len__']):
        raise ValueError(
-            'The specified Data for Dataset does either not have a __getitem__ or a __len__ attribute.')
+            'The specified Data for Dataset must have both __getitem__ and __len__ attributes.')


-def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DATASET_INPUT] = None) -> None:
+def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+    """To avoid unexpected behavior, we use loops over indices."""
    for i in range(len(train_tensors)):
        check_valid_data(train_tensors[i])
    if val_tensors is not None:
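The rewritten check is equivalent to the old one: any object exposing both __getitem__ and __len__ passes. A minimal usage sketch (only check_valid_data and type_check come from this diff; the data is illustrative):

    import numpy as np

    X, y = np.zeros((5, 3)), np.zeros(5)
    check_valid_data(X)     # ndarray has __getitem__ and __len__, so this passes
    type_check((X, y))      # loops over the tuple and validates each element
    # check_valid_data(42)  # an int has neither attribute -> raises ValueError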
@@ -42,12 +43,20 @@ def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DAT


class TransformSubset(Subset):
-    """
-    Because the BaseDataset contains all the data (train/val/test), the transformations
-    have to be applied with some directions. That is, if yielding train data,
-    we expect to apply train transformation (which have augmentations exclusively).
+    """Wrapper of BaseDataset for split datasets
+
+    Since the BaseDataset contains all the data points (train/val/test),
+    we require a different transformation for each data point.
+    This class helps to take the subset of the dataset
+    with either the training or the validation transformation.

    We achieve so by adding a train flag to the pytorch subset
+
+    Attributes:
+        dataset (BaseDataset/Dataset): Dataset to sample the subset from
+        indices (Sequence[int]): Indices to sample from the dataset
+        train (bool): Whether to apply the train or the validation transformation
+
    """

    def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
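The body of TransformSubset.__getitem__ is not part of this hunk (only its signature shows up in the next hunk header); a plausible sketch of how the train flag is forwarded, assuming BaseDataset.__getitem__ accepts it as shown later in the diff, is:

    from typing import Sequence
    import numpy as np
    from torch.utils.data import Dataset, Subset

    class TransformSubset(Subset):
        def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
            self.dataset = dataset
            self.indices = indices
            self.train = train

        def __getitem__(self, idx: int) -> np.ndarray:
            # Forward the train flag so the wrapped BaseDataset applies the matching transform.
            return self.dataset.__getitem__(self.indices[idx], self.train)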
@@ -62,10 +71,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
class BaseDataset(Dataset, metaclass=ABCMeta):
    def __init__(
        self,
-        train_tensors: BASE_DATASET_INPUT,
+        train_tensors: BaseDatasetType,
        dataset_name: Optional[str] = None,
-        val_tensors: Optional[BASE_DATASET_INPUT] = None,
-        test_tensors: Optional[BASE_DATASET_INPUT] = None,
+        val_tensors: Optional[BaseDatasetType] = None,
+        test_tensors: Optional[BaseDatasetType] = None,
        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        shuffle: Optional[bool] = True,
@@ -97,18 +106,15 @@ def __init__(
            val_transforms (Optional[torchvision.transforms.Compose]):
                Additional Transforms to be applied to the validation/test data
        """
-        if dataset_name is not None:
-            self.dataset_name = dataset_name
-        else:
-            self.dataset_name = hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name if dataset_name is not None \
+            else hash_array_or_matrix(train_tensors[0])
+
        if not hasattr(train_tensors[0], 'shape'):
            type_check(train_tensors, val_tensors)
-        self.train_tensors = train_tensors
-        self.val_tensors = val_tensors
-        self.test_tensors = test_tensors
+        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.rand = np.random.RandomState(seed=seed)
+        self.rng = np.random.RandomState(seed=seed)
        self.shuffle = shuffle
        self.resampling_strategy = resampling_strategy
        self.resampling_strategy_args = resampling_strategy_args
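The hasattr(train_tensors[0], 'shape') guard only runs type_check for inputs that do not already expose a shape attribute, i.e. plain Python containers; numpy arrays and pandas DataFrames are accepted as-is. A quick illustration (the data is made up, the guard is from the diff):

    import numpy as np

    hasattr(np.zeros((2, 2)), 'shape')    # True  -> type_check is skipped
    hasattr([[0, 1], [2, 3]], 'shape')    # False -> type_check validates the list-like input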
@@ -128,16 +134,8 @@ def __init__(
            self.is_small_preprocess = True

        # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(
-            CrossValTypes.stratified_k_fold_cross_validation,
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation,
-            CrossValTypes.stratified_shuffle_split_cross_validation
-        )
-        self.holdout_validators = get_holdout_validators(
-            HoldoutValTypes.holdout_validation,
-            HoldoutValTypes.stratified_holdout_validation
-        )
+        self.cross_validators = get_cross_validators(*CrossValTypes)
+        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
        self.splits = self.get_splits_from_resampling_strategy()

        # We also need to be able to transform the data, be it for pre-processing
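Unpacking an Enum class with * passes every one of its members, so the shortened call registers all cross-validation and holdout types at once and will automatically pick up members added later; the two forms are only equivalent if the old explicit list covered every member. A small self-contained sketch of the mechanism, assuming CrossValTypes/HoldoutValTypes are Enum classes (Color and register are stand-ins, not part of the codebase):

    from enum import Enum

    class Color(Enum):            # stand-in for CrossValTypes / HoldoutValTypes
        RED = 1
        GREEN = 2

    def register(*members):       # stand-in for get_cross_validators
        return [m.name for m in members]

    print(register(*Color))       # ['RED', 'GREEN'] -- iterating an Enum yields its members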
@@ -146,19 +144,19 @@ def __init__(
        self.val_transform = val_transforms

    def update_transform(self, transform: Optional[torchvision.transforms.Compose],
-                         train: bool = True,
-                         ) -> 'BaseDataset':
+                         train: bool = True) -> 'BaseDataset':
        """
        During the pipeline execution, the pipeline object might propose transformations
        as a product of the current pipeline configuration being tested.

-        This utility allows to return a self with the updated transformation, so that
+        This utility allows returning self with the updated transformation, so that
        a dataloader can yield this dataset with the desired transformations

        Args:
-            transform (torchvision.transforms.Compose): The transformations proposed
-                by the current pipeline
-            train (bool): Whether to update the train or validation transform
+            transform (torchvision.transforms.Compose):
+                The transformations proposed by the current pipeline
+            train (bool):
+                Whether to update the train or validation transform

        Returns:
            self: A copy of the update pipeline
@@ -171,9 +169,9 @@ def update_transform(self, transform: Optional[torchvision.transforms.Compose],

    def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
        """
-        The base dataset uses a Subset of the data. Nevertheless, the base dataset expect
-        both validation and test data to be present in the same dataset, which motivated the
-        need to dynamically give train/test data with the __getitem__ command.
+        The base dataset uses a Subset of the data. Nevertheless, the base dataset expects
+        both validation and test data to be present in the same dataset, which motivates
+        the need to dynamically give train/test data with the __getitem__ command.

        This method yields a datapoint of the whole data (after a Subset has selected a given
        item, based on the resampling strategy) and applies a train/testing transformation, if any.
@@ -186,34 +184,24 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
            A transformed single point prediction
        """

-        if hasattr(self.train_tensors[0], 'loc'):
-            X = self.train_tensors[0].iloc[[index]]
-        else:
-            X = self.train_tensors[0][index]
+        X = self.train_tensors[0].iloc[[index]] if hasattr(self.train_tensors[0], 'loc') \
+            else self.train_tensors[0][index]

        if self.train_transform is not None and train:
            X = self.train_transform(X)
        elif self.val_transform is not None and not train:
            X = self.val_transform(X)

        # In case of prediction, the targets are not provided
-        Y = self.train_tensors[1]
-        if Y is not None:
-            Y = Y[index]
-        else:
-            Y = None
+        Y = self.train_tensors[1][index] if self.train_tensors[1] is not None else None

        return X, Y

    def __len__(self) -> int:
        return self.train_tensors[0].shape[0]

    def _get_indices(self) -> np.ndarray:
-        if self.shuffle:
-            indices = self.rand.permutation(len(self))
-        else:
-            indices = np.arange(len(self))
-        return indices
+        return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))

    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
        """
@@ -333,7 +321,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
        return (TransformSubset(self, self.splits[split_id][0], train=True),
                TransformSubset(self, self.splits[split_id][1], train=False))

-    def replace_data(self, X_train: BASE_DATASET_INPUT, X_test: Optional[BASE_DATASET_INPUT]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
        """
        To speed up the training of small dataset, early pre-processing of the data
        can be made on the fly by the pipeline.
@@ -361,7 +349,8 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
            contain.

        Returns:
-
+            dataset_properties (Dict[str, Any]):
+                Dict of the dataset properties.
        """
        dataset_properties = dict()
        for dataset_requirement in dataset_requirements:
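The loop body is cut off by the hunk boundary; a plausible sketch of how each FitRequirement could be turned into an entry of the returned dict (the getattr lookup is an assumption, not shown in this diff):

    dataset_properties = dict()
    for dataset_requirement in dataset_requirements:
        # Assumed: each requirement's name maps to an attribute of the dataset itself.
        dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
    return dataset_properties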