@@ -112,13 +112,7 @@ def __init__(
112
112
dataset_name : Optional [str ] = None ,
113
113
val_tensors : Optional [BaseDatasetInputType ] = None ,
114
114
test_tensors : Optional [BaseDatasetInputType ] = None ,
115
- << << << < HEAD
116
115
resampling_strategy : ResamplingStrategies = HoldoutValTypes .holdout_validation ,
117
- == == == =
118
- resampling_strategy : Union [CrossValTypes ,
119
- HoldoutValTypes ,
120
- NoResamplingStrategyTypes ] = HoldoutValTypes .holdout_validation ,
121
- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
122
116
resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
123
117
shuffle : Optional [bool ] = True ,
124
118
seed : Optional [int ] = 42 ,
@@ -135,12 +129,7 @@ def __init__(
135
129
validation data
136
130
test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
137
131
test data
138
- <<<<<<< HEAD
139
132
resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
140
- =======
141
- resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
142
- (default=HoldoutValTypes.holdout_validation):
143
- >>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
144
133
strategy to split the training data.
145
134
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
146
135
required for the chosen resampling strategy. If None, uses
@@ -162,17 +151,11 @@ def __init__(
162
151
if not hasattr (train_tensors [0 ], 'shape' ):
163
152
type_check (train_tensors , val_tensors )
164
153
self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
165
- << << << < HEAD
166
154
self .cross_validators : Dict [str , CrossValFunc ] = {}
167
155
self .holdout_validators : Dict [str , HoldOutFunc ] = {}
168
156
self .no_resampling_validators : Dict [str , NoResamplingFunc ] = {}
169
157
self .random_state = np .random .RandomState (seed = seed )
170
- == == == =
171
- self .cross_validators : Dict [str , CROSS_VAL_FN ] = {}
172
- self .holdout_validators : Dict [str , HOLDOUT_FN ] = {}
173
- self .no_resampling_validators : Dict [str , NO_RESAMPLING_FN ] = {}
174
- self .rng = np .random .RandomState (seed = seed )
175
- >> >> >> > Fix mypy and flake
158
+ self .no_resampling_validators : Dict [str , NoResamplingFunc ] = {}
176
159
self .shuffle = shuffle
177
160
self .resampling_strategy = resampling_strategy
178
161
self .resampling_strategy_args = resampling_strategy_args
@@ -189,11 +172,8 @@ def __init__(
189
172
# Make sure cross validation splits are created once
190
173
self .cross_validators = CrossValFuncs .get_cross_validators (* CrossValTypes )
191
174
self .holdout_validators = HoldOutFuncs .get_holdout_validators (* HoldoutValTypes )
192
- < << << << HEAD
175
+
193
176
self .no_resampling_validators = NoResamplingFuncs .get_no_resampling_validators (* NoResamplingStrategyTypes )
194
- == == == =
195
- self .no_resampling_validators = get_no_resampling_validators (* NoResamplingStrategyTypes )
196
- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
197
177
198
178
self .splits = self .get_splits_from_resampling_strategy ()
199
179
@@ -294,12 +274,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[
294
274
)
295
275
)
296
276
elif isinstance (self .resampling_strategy , NoResamplingStrategyTypes ):
297
- << << << < HEAD
298
277
splits .append ((self .no_resampling_validators [self .resampling_strategy .name ](self .random_state ,
299
278
self ._get_indices ()), None ))
300
- == == == =
301
- splits .append ((self .no_resampling_validators [self .resampling_strategy .name ](self ._get_indices ()), None ))
302
- >> > >> > > Create fit evaluator , no resampling strategy and fix bug for test statistics
303
279
else :
304
280
raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
305
281
return splits
@@ -371,11 +347,7 @@ def create_holdout_val_split(
371
347
self .random_state , val_share , self ._get_indices (), ** kwargs )
372
348
return train , val
373
349
374
- << << < << HEAD
375
350
def get_dataset (self , split_id : int , train : bool ) -> Dataset :
376
- == == == =
377
- def get_dataset_for_training (self , split_id : int , train : bool ) - > Dataset :
378
- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
379
351
"""
380
352
The above split methods employ the Subset to internally subsample the whole dataset.
381
353
@@ -390,7 +362,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
390
362
Dataset: the reduced dataset to be used for testing
391
363
"""
392
364
# Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
393
- << << < << HEAD
394
365
if split_id >= len (self .splits ): # old version: split_id > len(self.splits)
395
366
raise IndexError (f"self.splits index out of range, got split_id={ split_id } "
396
367
f" (>= num_splits={ len (self .splits )} )" )
@@ -399,9 +370,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
399
370
raise ValueError ("Specified fold (or subset) does not exist" )
400
371
401
372
return TransformSubset (self , indices , train = train )
402
- == == == =
403
- return TransformSubset (self , self .splits [split_id ][0 ], train = train )
404
- >> >> > >> Create fit evaluator , no resampling strategy and fix bug for test statistics
405
373
406
374
def replace_data (self , X_train : BaseDatasetInputType ,
407
375
X_test : Optional [BaseDatasetInputType ]) -> 'BaseDataset' :
0 commit comments