@@ -45,25 +45,53 @@ def __init__(self,
An object for tracking when to stop the network training.
It handles epoch based criteria as well as training based criteria.

- It also allows to define a 'epoch_or_time' budget type, which means,
- the first of them both which is exhausted, is honored
+ It also allows to define a 'epoch_or_time' budget type, which means, the first of them both which is
+ exhausted, is honored
+
+ Args:
+ budget_type (str):
+ Type of budget to be used when fitting the pipeline.
+ Possible values are 'epochs', 'runtime', or 'epoch_or_time'
+ max_epochs (Optional[int], default=None):
+ Maximum number of epochs to train the pipeline for
+ max_runtime (Optional[int], default=None):
+ Maximum number of seconds to train the pipeline for
"""
self.start_time = time.time()
self.budget_type = budget_type
self.max_epochs = max_epochs
self.max_runtime = max_runtime

def is_max_epoch_reached(self, epoch: int) -> bool:
+ """
+ For budget type 'epochs' or 'epoch_or_time' return True if the maximum number of epochs is reached.
+
+ Args:
+ epoch (int):
+ the current epoch

- # Make None a method to run without this constrain
+ Returns:
+ bool:
+ True if the current epoch is larger than the maximum epochs, False otherwise.
+ Additionally, returns False if the run is without this constraint.
+ """
+ # Make None a method to run without this constraint
if self.max_epochs is None:
return False
if self.budget_type in ['epochs', 'epoch_or_time'] and epoch > self.max_epochs:
return True
return False

def is_max_time_reached(self) -> bool:
- # Make None a method to run without this constrain
+ """
+ For budget type 'runtime' or 'epoch_or_time' return True if the maximum runtime is reached.
+
+ Returns:
+ bool:
+ True if the maximum runtime is reached, False otherwise.
+ Additionally, returns False if the run is without this constraint.
+ """
+ # Make None a method to run without this constraint
if self.max_runtime is None:
return False
elapsed_time = time.time() - self.start_time
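For readers of this diff, the 'epoch_or_time' behaviour documented above (the first budget to be exhausted wins) can be sketched in a few lines of plain Python. This is an illustrative, self-contained snippet, not the tracker class from this pull request; all names in it are assumptions.

# --- illustrative sketch, not part of the diff ---
import time

def budget_exhausted(budget_type, epoch, start_time,
                     max_epochs=None, max_runtime=None):
    # A budget of None means "run without this constraint".
    epochs_done = (max_epochs is not None
                   and budget_type in ('epochs', 'epoch_or_time')
                   and epoch > max_epochs)
    time_done = (max_runtime is not None
                 and budget_type in ('runtime', 'epoch_or_time')
                 and time.time() - start_time > max_runtime)
    # 'epoch_or_time' honours whichever budget runs out first.
    return epochs_done or time_done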
@@ -78,14 +106,22 @@ def __init__(
total_parameter_count: float,
trainable_parameter_count: float,
optimize_metric: Optional[str] = None,
- ):
+ ) -> None:
"""
A useful object to track performance per epoch.

- It allows to track train, validation and test information not only for
- debug, but for research purposes (Like understanding overfit).
+ It allows to track train, validation and test information not only for debug, but for research purposes
+ (Like understanding overfit).

It does so by tracking a metric/loss at the end of each epoch.
+
+ Args:
+ total_parameter_count (float):
+ the total number of parameters of the model
+ trainable_parameter_count (float):
+ only the parameters being optimized
+ optimize_metric (Optional[str], default=None):
+ name of the metric that is used to evaluate a pipeline.
"""
self.performance_tracker: Dict[str, Dict] = {
'start_time': {},
@@ -121,8 +157,30 @@ def add_performance(self,
test_loss: Optional[float] = None,
) -> None:
"""
- Tracks performance information about the run, useful for
- plotting individual runs
+ Tracks performance information about the run, useful for plotting individual runs.
+
+ Args:
+ epoch (int):
+ the current epoch
+ start_time (float):
+ timestamp at the beginning of the current epoch
+ end_time (float):
+ timestamp when gathering the information after the current epoch
+ train_loss (float):
+ the training loss
+ train_metrics (Dict[str, float]):
+ training scores for each desired metric
+ val_metrics (Dict[str, float]):
+ validation scores for each desired metric
+ test_metrics (Dict[str, float]):
+ test scores for each desired metric
+ val_loss (Optional[float], default=None):
+ the validation loss
+ test_loss (Optional[float], default=None):
+ the test loss
+
+ Returns:
+ None
"""
self.performance_tracker['train_loss'][epoch] = train_loss
self.performance_tracker['val_loss'][epoch] = val_loss
@@ -134,6 +192,18 @@ def add_performance(self,
self.performance_tracker['test_metrics'][epoch] = test_metrics

def get_best_epoch(self, split_type: str = 'val') -> int:
+ """
+ Get the epoch with the best metric.
+
+ Args:
+ split_type (str, default='val'):
+ Which split's metric to consider.
+ Possible values are 'train' or 'val'
+
+ Returns:
+ int:
+ the epoch with the best metric
+ """
# If we compute for optimization, prefer the performance
# metric to the loss
if self.optimize_metric is not None:
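The per-epoch bookkeeping described in the two hunks above (add_performance storing losses and metrics keyed by epoch, get_best_epoch picking the best one) can be illustrated with a minimal dictionary-based sketch. The names below are illustrative only and are not taken from this pull request.

# --- illustrative sketch, not part of the diff ---
val_metrics_by_epoch = {}  # epoch -> {'accuracy': ...}

def record(epoch, val_metrics):
    # Mirrors the idea of add_performance: one entry per epoch.
    val_metrics_by_epoch[epoch] = val_metrics

def best_epoch(metric_name='accuracy'):
    # Mirrors the idea of get_best_epoch: argmax over the tracked metric.
    return max(val_metrics_by_epoch,
               key=lambda e: val_metrics_by_epoch[e][metric_name])

record(1, {'accuracy': 0.71})
record(2, {'accuracy': 0.78})
record(3, {'accuracy': 0.75})
assert best_epoch() == 2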
@@ -159,6 +229,13 @@ def get_best_epoch(self, split_type: str = 'val') -> int:
)) + 1  # Epochs start at 1

def get_last_epoch(self) -> int:
+ """
+ Get the last epoch.
+
+ Returns:
+ int:
+ the last epoch
+ """
if 'train_loss' not in self.performance_tracker:
return 0
else:
@@ -170,7 +247,8 @@ def repr_last_epoch(self) -> str:
performance

Returns:
- str: A nice representation of the last epoch
+ str:
+ A nice representation of the last epoch
"""
last_epoch = len(self.performance_tracker['train_loss'])
string = "\n"
@@ -202,30 +280,43 @@ def is_empty(self) -> bool:
Checks if the object is empty or not

Returns:
- bool
+ bool:
+ True if the object is empty, False otherwise
"""
# if train_loss is empty, we can be sure that RunSummary is empty.
return not bool(self.performance_tracker['train_loss'])


class BaseTrainerComponent(autoPyTorchTrainingComponent):
"""
- Base class for training
+ Base class for training.
+
Args:
- weighted_loss (int, default=0): In case for classification, whether to weight
- the loss function according to the distribution of classes in the target
- use_stochastic_weight_averaging (bool, default=True): whether to use stochastic
- weight averaging. Stochastic weight averaging is a simple average of
- multiple points(model parameters) along the trajectory of SGD. SWA
- has been proposed in
+ weighted_loss (int, default=0):
+ In the case of classification, whether to weight the loss function according to the distribution of classes
+ in the target
+ use_stochastic_weight_averaging (bool, default=True):
+ whether to use stochastic weight averaging. Stochastic weight averaging is a simple average of
+ multiple points (model parameters) along the trajectory of SGD. SWA has been proposed in
[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
- use_snapshot_ensemble (bool, default=True): whether to use snapshot
- ensemble
- se_lastk (int, default=3): Number of snapshots of the network to maintain
- use_lookahead_optimizer (bool, default=True): whether to use lookahead
- optimizer
- random_state:
- **lookahead_config:
+ use_snapshot_ensemble (bool, default=True):
+ whether to use snapshot ensemble
+ se_lastk (int, default=3):
+ Number of snapshots of the network to maintain
+ use_lookahead_optimizer (bool, default=True):
+ whether to use lookahead optimizer
+ random_state (Optional[np.random.RandomState]):
+ Object that contains a seed and allows for reproducible results
+ swa_model (Optional[torch.nn.Module], default=None):
+ Averaged model used for Stochastic Weight Averaging
+ model_snapshots (Optional[List[torch.nn.Module]], default=None):
+ List of model snapshots in case snapshot ensemble is used
+ **lookahead_config (Any):
+ keyword arguments for the lookahead optimizer including:
+ la_steps (int):
+ number of lookahead steps
+ la_alpha (float):
+ linear interpolation factor. 1.0 recovers the inner optimizer.
"""
def __init__(self, weighted_loss: int = 0,
use_stochastic_weight_averaging: bool = True,
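The SWA, snapshot-ensemble and lookahead options documented in this class docstring can be illustrated with standard PyTorch utilities. The sketch below is a hedged example of the underlying ideas (averaging weights along the SGD trajectory and keeping the last k snapshots); it is not the AutoPyTorch implementation, and the toy model and loop are assumptions.

# --- illustrative sketch, not part of the diff ---
import copy
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 2)                    # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)                  # running average of weights (SWA)
model_snapshots, se_lastk = [], 3                 # snapshot ensemble bookkeeping

for epoch in range(10):
    # ... a normal training epoch on `model` would go here ...
    swa_model.update_parameters(model)            # fold current weights into the SWA average
    model_snapshots.append(copy.deepcopy(model))  # keep only the last k snapshots
    model_snapshots = model_snapshots[-se_lastk:]

# Lookahead keeps "slow" weights and pulls them towards the "fast" weights
# every la_steps optimizer steps: slow = slow + la_alpha * (fast - slow).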
@@ -336,15 +427,21 @@ def prepare(

def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
"""
- Optional place holder for AutoPytorch Extensions.
+ Optional placeholder for AutoPytorch Extensions.
+ A user can define what happens on every epoch start or every epoch end.

- An user can define what happens on every epoch start or every epoch end.
+ Args:
+ X (Dict[str, Any]):
+ Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
+ a component adds relevant information so that further stages can be properly fitted
+ epoch (int):
+ the current epoch
"""
pass

def _swa_update(self) -> None:
"""
- perform swa model update
+ Perform Stochastic Weight Averaging model update
"""
if self.swa_model is None:
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
@@ -354,6 +451,7 @@ def _swa_update(self) -> None:
def _se_update(self, epoch: int) -> None:
"""
Add latest model or swa_model to model snapshot ensemble
+
Args:
epoch (int):
current epoch
@@ -373,9 +471,16 @@ def _se_update(self, epoch: int) -> None:

def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
"""
- Optional place holder for AutoPytorch Extensions.
- An user can define what happens on every epoch start or every epoch end.
- If returns True, the training is stopped
+ Optional placeholder for AutoPytorch Extensions.
+ A user can define what happens on every epoch start or every epoch end.
+ If it returns True, the training is stopped.
+
+ Args:
+ X (Dict[str, Any]):
+ Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
+ a component adds relevant information so that further stages can be properly fitted
+ epoch (int):
+ the current epoch

"""
if X['is_cyclic_scheduler']:
@@ -421,12 +526,18 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
Train the model for a single epoch.

Args:
- train_loader (torch.utils.data.DataLoader): generator of features/label
- epoch (int): The current epoch used solely for tracking purposes
+ train_loader (torch.utils.data.DataLoader):
+ generator of features/label
+ epoch (int):
+ The current epoch used solely for tracking purposes
+ writer (Optional[SummaryWriter]):
+ Object to keep track of the training loss in an event file

Returns:
- float: training loss
- Dict[str, float]: scores for each desired metric
+ float:
+ training loss
+ Dict[str, float]:
+ scores for each desired metric
"""

loss_sum = 0.0
@@ -482,12 +593,16 @@ def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float,
Allows to train 1 step of gradient descent, given a batch of train/labels

Args:
- data (torch.Tensor): input features to the network
- targets (torch.Tensor): ground truth to calculate loss
+ data (torch.Tensor):
+ input features to the network
+ targets (torch.Tensor):
+ ground truth to calculate loss

Returns:
- torch.Tensor: The predictions of the network
- float: the loss incurred in the prediction
+ torch.Tensor:
+ The predictions of the network
+ float:
+ the loss incurred in the prediction
"""
# prepare
data = data.float().to(self.device)
@@ -513,12 +628,18 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
Evaluate the model in both metrics and criterion

Args:
- test_loader (torch.utils.data.DataLoader): generator of features/label
- epoch (int): the current epoch for tracking purposes
+ test_loader (torch.utils.data.DataLoader):
+ generator of features/label
+ epoch (int):
+ the current epoch for tracking purposes
+ writer (Optional[SummaryWriter]):
+ Object to keep track of the test loss in an event file

Returns:
- float: test loss
- Dict[str, float]: scores for each desired metric
+ float:
+ test loss
+ Dict[str, float]:
+ scores for each desired metric
"""
self.model.eval()
@@ -576,14 +697,15 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n

def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
"""
- Depending on the trainer choice, data fed to the network might be pre-processed
- on a different way. That is, in standard training we provide the data to the
- network as we receive it to the loader. Some regularization techniques, like mixup
- alter the data.
+ Depending on the trainer choice, data fed to the network might be pre-processed in a different way. That is,
+ in standard training we provide the data to the network as we receive it from the loader. Some regularization
+ techniques, like mixup, alter the data.

Args:
- X (torch.Tensor): The batch training features
- y (torch.Tensor): The batch training labels
+ X (torch.Tensor):
+ The batch training features
+ y (torch.Tensor):
+ The batch training labels

Returns:
torch.Tensor: that processes data
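As a concrete illustration of the mixup-style data alteration mentioned in this docstring, a hedged sketch of what a trainer's data preparation could do is shown below. The exact AutoPyTorch mixup trainer may differ; the function name and beta-distribution parameter are assumptions.

# --- illustrative sketch, not part of the diff ---
import numpy as np
import torch

def mixup_data(X: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    # Blend each example with a randomly permuted partner.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(X.size(0))
    mixed_X = lam * X + (1 - lam) * X[index]
    # The two label sets and lambda are what criterion_preparation consumes later.
    return mixed_X, {'y_a': y, 'y_b': y[index], 'lam': lam}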
@@ -595,16 +717,21 @@ def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
) -> Callable:  # type: ignore
"""
- Depending on the trainer choice, the criterion is not directly applied to the
- traditional y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
+ Depending on the trainer choice, the criterion is not directly applied to the traditional
+ y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
For example, in the case of mixup training, we need to account for the lambda mixup

Args:
- kwargs (Dict): an expanded dictionary with modifiers to the
- criterion calculation
+ y_a (torch.Tensor):
+ the batch label of the first training example used in trainer
+ y_b (torch.Tensor, default=None):
+ if applicable, the batch label of the second training example used in trainer
+ lam (float):
+ trainer coefficient

Returns:
- Callable: a lambda function that contains the new criterion calculation recipe
+ Callable:
+ a lambda function that contains the new criterion calculation recipe
"""
raise NotImplementedError()
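Correspondingly, the lambda described in the Returns section above would, for a mixup trainer, typically weight the criterion between the two label tensors. The snippet below is a hedged sketch of that idea, not the concrete subclass implementation in AutoPyTorch.

# --- illustrative sketch, not part of the diff ---
def mixup_criterion(y_a, y_b=None, lam=1.0):
    # With no second label set (or lam == 1.0) this degenerates to the plain criterion.
    if y_b is None:
        return lambda criterion, pred: criterion(pred, y_a)
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)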