
Commit 8440695

Merge pull request #105 from nabenabe0928/refactoring-base-dataset
Refactoring base dataset
2 parents b5d1c8b + eac426d commit 8440695

File tree

2 files changed: +43 -55 lines changed


autoPyTorch/datasets/base_dataset.py

Lines changed: 43 additions & 54 deletions
@@ -24,16 +24,17 @@
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
 
-BASE_DATASET_INPUT = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
 
 
 def check_valid_data(data: Any) -> None:
-    if not (hasattr(data, '__getitem__') and hasattr(data, '__len__')):
+    if not all(hasattr(data, attr) for attr in ['__getitem__', '__len__']):
         raise ValueError(
-            'The specified Data for Dataset does either not have a __getitem__ or a __len__ attribute.')
+            'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')
 
 
-def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DATASET_INPUT] = None) -> None:
+def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+    """To avoid unexpected behavior, we use loops over indices."""
     for i in range(len(train_tensors)):
         check_valid_data(train_tensors[i])
     if val_tensors is not None:
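
As a side note, the new generator-based check is equivalent to the old boolean expression; a minimal standalone sketch (independent of autoPyTorch) of what such a check accepts and rejects:

    import numpy as np

    def check_valid_data(data):
        # Reject anything that cannot be indexed and measured with len().
        if not all(hasattr(data, attr) for attr in ['__getitem__', '__len__']):
            raise ValueError('Data must have both __getitem__ and __len__ attributes.')

    check_valid_data(np.arange(10))  # ok: ndarrays support indexing and len()
    check_valid_data([1, 2, 3])      # ok: lists do too
    try:
        check_valid_data(iter([1, 2, 3]))  # a plain iterator has neither attribute
    except ValueError as err:
        print(err)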
@@ -42,12 +43,20 @@ def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DAT
 
 
 class TransformSubset(Subset):
-    """
-    Because the BaseDataset contains all the data (train/val/test), the transformations
-    have to be applied with some directions. That is, if yielding train data,
-    we expect to apply train transformation (which have augmentations exclusively).
+    """Wrapper of BaseDataset for splitted datasets
+
+    Since the BaseDataset contains all the data points (train/val/test),
+    we require different transformation for each data point.
+    This class helps to take the subset of the dataset
+    with either training or validation transformation.
 
     We achieve so by adding a train flag to the pytorch subset
+
+    Attributes:
+        dataset (BaseDataset/Dataset): Dataset to sample the subset
+        indices names (Sequence[int]): Indices to sample from the dataset
+        train (bool): If we apply train or validation transformation
+
     """
 
     def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
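
The docstring above describes the intent of TransformSubset; purely as an illustration (not the actual autoPyTorch implementation), a Subset that forwards a train flag to its parent dataset could look roughly like this, assuming the parent's __getitem__ accepts a train argument as BaseDataset.__getitem__ does:

    from typing import Sequence
    from torch.utils.data import Dataset, Subset

    class FlaggedSubset(Subset):  # hypothetical name, for illustration only
        """Subset that tells the parent dataset whether to apply train or val transforms."""

        def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
            super().__init__(dataset, indices)
            self.train = train

        def __getitem__(self, idx: int):
            # Delegate to the parent dataset, passing the flag so it can pick
            # the matching transformation pipeline.
            return self.dataset.__getitem__(self.indices[idx], self.train)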
@@ -62,10 +71,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BASE_DATASET_INPUT,
+        train_tensors: BaseDatasetType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BASE_DATASET_INPUT] = None,
-        test_tensors: Optional[BASE_DATASET_INPUT] = None,
+        val_tensors: Optional[BaseDatasetType] = None,
+        test_tensors: Optional[BaseDatasetType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -97,18 +106,15 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        if dataset_name is not None:
-            self.dataset_name = dataset_name
-        else:
-            self.dataset_name = hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name if dataset_name is not None \
+            else hash_array_or_matrix(train_tensors[0])
+
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
-        self.train_tensors = train_tensors
-        self.val_tensors = val_tensors
-        self.test_tensors = test_tensors
+        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
         self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
         self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.rand = np.random.RandomState(seed=seed)
+        self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
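
Independent of the rand -> rng rename, keeping a seeded np.random.RandomState on the dataset is what makes the later shuffling reproducible; a small standalone illustration:

    import numpy as np

    rng_a = np.random.RandomState(seed=42)
    rng_b = np.random.RandomState(seed=42)

    # Two generators built from the same seed produce identical permutations,
    # so splits derived from them are reproducible across runs.
    assert np.array_equal(rng_a.permutation(10), rng_b.permutation(10))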
@@ -128,16 +134,8 @@ def __init__(
             self.is_small_preprocess = True
 
         # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(
-            CrossValTypes.stratified_k_fold_cross_validation,
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation,
-            CrossValTypes.stratified_shuffle_split_cross_validation
-        )
-        self.holdout_validators = get_holdout_validators(
-            HoldoutValTypes.holdout_validation,
-            HoldoutValTypes.stratified_holdout_validation
-        )
+        self.cross_validators = get_cross_validators(*CrossValTypes)
+        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
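
The *CrossValTypes and *HoldoutValTypes calls rely on the fact that iterating (and therefore star-unpacking) an Enum class yields all of its members, which is what the usage of these types suggests they are; a self-contained illustration with a made-up enum:

    from enum import Enum

    class Color(Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    def register(*members):
        return [m.name for m in members]

    # Star-unpacking the class passes every member, equivalent to listing them by hand.
    assert register(*Color) == ['RED', 'GREEN', 'BLUE']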
@@ -146,19 +144,19 @@ def __init__(
         self.val_transform = val_transforms
 
     def update_transform(self, transform: Optional[torchvision.transforms.Compose],
-                         train: bool = True,
-                         ) -> 'BaseDataset':
+                         train: bool = True) -> 'BaseDataset':
         """
         During the pipeline execution, the pipeline object might propose transformations
         as a product of the current pipeline configuration being tested.
 
-        This utility allows to return a self with the updated transformation, so that
+        This utility allows to return self with the updated transformation, so that
         a dataloader can yield this dataset with the desired transformations
 
         Args:
-            transform (torchvision.transforms.Compose): The transformations proposed
-                by the current pipeline
-            train (bool): Whether to update the train or validation transform
+            transform (torchvision.transforms.Compose):
+                The transformations proposed by the current pipeline
+            train (bool):
+                Whether to update the train or validation transform
 
         Returns:
             self: A copy of the update pipeline
@@ -171,9 +169,9 @@ def update_transform(self, transform: Optional[torchvision.transforms.Compose],
 
     def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
         """
-        The base dataset uses a Subset of the data. Nevertheless, the base dataset expect
-        both validation and test data to be present in the same dataset, which motivated the
-        need to dynamically give train/test data with the __getitem__ command.
+        The base dataset uses a Subset of the data. Nevertheless, the base dataset expects
+        both validation and test data to be present in the same dataset, which motivates
+        the need to dynamically give train/test data with the __getitem__ command.
 
         This method yields a datapoint of the whole data (after a Subset has selected a given
         item, based on the resampling strategy) and applies a train/testing transformation, if any.
@@ -186,34 +184,24 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
             A transformed single point prediction
         """
 
-        if hasattr(self.train_tensors[0], 'loc'):
-            X = self.train_tensors[0].iloc[[index]]
-        else:
-            X = self.train_tensors[0][index]
+        X = self.train_tensors[0].iloc[[index]] if hasattr(self.train_tensors[0], 'loc') \
+            else self.train_tensors[0][index]
 
         if self.train_transform is not None and train:
            X = self.train_transform(X)
         elif self.val_transform is not None and not train:
            X = self.val_transform(X)
 
         # In case of prediction, the targets are not provided
-        Y = self.train_tensors[1]
-        if Y is not None:
-            Y = Y[index]
-        else:
-            Y = None
+        Y = self.train_tensors[1][index] if self.train_tensors[1] is not None else None
 
         return X, Y
 
     def __len__(self) -> int:
         return self.train_tensors[0].shape[0]
 
     def _get_indices(self) -> np.ndarray:
-        if self.shuffle:
-            indices = self.rand.permutation(len(self))
-        else:
-            indices = np.arange(len(self))
-        return indices
+        return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
 
     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
@@ -333,7 +321,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))
 
-    def replace_data(self, X_train: BASE_DATASET_INPUT, X_test: Optional[BASE_DATASET_INPUT]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
@@ -361,7 +349,8 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
             contain.
 
         Returns:
-
+            dataset_properties (Dict[str, Any]):
+                Dict of the dataset properties.
         """
         dataset_properties = dict()
         for dataset_requirement in dataset_requirements:

test/conftest.py

Lines changed: 0 additions & 1 deletion
@@ -276,7 +276,6 @@ def get_fit_dictionary(X, y, validator, backend):
     info = datamanager.get_required_dataset_info()
 
     dataset_properties = datamanager.get_dataset_properties(get_dataset_requirements(info))
-
     fit_dictionary = {
         'X_train': datamanager.train_tensors[0],
        'y_train': datamanager.train_tensors[1],

0 commit comments
