
Commit 34c704d

[add] documentation update in base trainer (#468)
1 parent afddca5 commit 34c704d


autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 181 additions & 54 deletions
@@ -45,25 +45,53 @@ def __init__(self,
         An object for tracking when to stop the network training.
         It handles epoch based criteria as well as training based criteria.

-        It also allows to define a 'epoch_or_time' budget type, which means,
-        the first of them both which is exhausted, is honored
+        It also allows to define a 'epoch_or_time' budget type, which means, the first of them both which is
+        exhausted, is honored
+
+        Args:
+            budget_type (str):
+                Type of budget to be used when fitting the pipeline.
+                Possible values are 'epochs', 'runtime', or 'epoch_or_time'
+            max_epochs (Optional[int], default=None):
+                Maximum number of epochs to train the pipeline for
+            max_runtime (Optional[int], default=None):
+                Maximum number of seconds to train the pipeline for
         """
         self.start_time = time.time()
         self.budget_type = budget_type
         self.max_epochs = max_epochs
         self.max_runtime = max_runtime

     def is_max_epoch_reached(self, epoch: int) -> bool:
+        """
+        For budget type 'epoch' or 'epoch_or_time' return True if the maximum number of epochs is reached.
+
+        Args:
+            epoch (int):
+                the current epoch

-        # Make None a method to run without this constrain
+        Returns:
+            bool:
+                True if the current epoch is larger than the maximum epochs, False otherwise.
+                Additionally, returns False if the run is without this constraint.
+        """
+        # Make None a method to run without this constraint
         if self.max_epochs is None:
             return False
         if self.budget_type in ['epochs', 'epoch_or_time'] and epoch > self.max_epochs:
             return True
         return False

     def is_max_time_reached(self) -> bool:
-        # Make None a method to run without this constrain
+        """
+        For budget type 'runtime' or 'epoch_or_time' return True if the maximum runtime is reached.
+
+        Returns:
+            bool:
+                True if the maximum runtime is reached, False otherwise.
+                Additionally, returns False if the run is without this constraint.
+        """
+        # Make None a method to run without this constraint
         if self.max_runtime is None:
             return False
         elapsed_time = time.time() - self.start_time
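The hunk above documents how the two budgets interact. Below is a minimal stand-alone sketch of the 'epoch_or_time' behaviour, assuming only the constructor arguments documented above; SimpleBudgetTracker is an illustrative toy class, not the library's BudgetTracker.

```python
import time
from typing import Optional


class SimpleBudgetTracker:
    """Toy re-implementation of the documented interface, for illustration only."""

    def __init__(self, budget_type: str,
                 max_epochs: Optional[int] = None,
                 max_runtime: Optional[int] = None) -> None:
        self.start_time = time.time()
        self.budget_type = budget_type
        self.max_epochs = max_epochs
        self.max_runtime = max_runtime

    def is_exhausted(self, epoch: int) -> bool:
        # 'epoch_or_time' honours whichever of the two budgets runs out first
        epoch_done = (self.budget_type in ('epochs', 'epoch_or_time')
                      and self.max_epochs is not None
                      and epoch > self.max_epochs)
        time_done = (self.budget_type in ('runtime', 'epoch_or_time')
                     and self.max_runtime is not None
                     and time.time() - self.start_time > self.max_runtime)
        return epoch_done or time_done


tracker = SimpleBudgetTracker('epoch_or_time', max_epochs=50, max_runtime=3600)
print(tracker.is_exhausted(epoch=1))  # False right after construction
```

With 'epoch_or_time', a run with a generous epoch budget can still be cut short by wall-clock time, and vice versa.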
@@ -78,14 +106,22 @@ def __init__(
         total_parameter_count: float,
         trainable_parameter_count: float,
         optimize_metric: Optional[str] = None,
-    ):
+    ) -> None:
         """
         A useful object to track performance per epoch.

-        It allows to track train, validation and test information not only for
-        debug, but for research purposes (Like understanding overfit).
+        It allows to track train, validation and test information not only for debug, but for research purposes
+        (Like understanding overfit).

         It does so by tracking a metric/loss at the end of each epoch.
+
+        Args:
+            total_parameter_count (float):
+                the total number of parameters of the model
+            trainable_parameter_count (float):
+                only the parameters being optimized
+            optimize_metric (Optional[str], default=None):
+                name of the metric that is used to evaluate a pipeline.
         """
         self.performance_tracker: Dict[str, Dict] = {
             'start_time': {},
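For context on the two parameter counts the new Args section documents, here is one common way to obtain them for an arbitrary torch module before handing them to RunSummary; the Sequential model is only an illustrative stand-in.

```python
import torch.nn as nn

# Illustrative stand-in model; the two counts below match the Args documented above.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
total_parameter_count = sum(p.numel() for p in model.parameters())
trainable_parameter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_parameter_count, trainable_parameter_count)
```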
@@ -121,8 +157,30 @@ def add_performance(self,
                         test_loss: Optional[float] = None,
                         ) -> None:
         """
-        Tracks performance information about the run, useful for
-        plotting individual runs
+        Tracks performance information about the run, useful for plotting individual runs.
+
+        Args:
+            epoch (int):
+                the current epoch
+            start_time (float):
+                timestamp at the beginning of current epoch
+            end_time (float):
+                timestamp when gathering the information after the current epoch
+            train_loss (float):
+                the training loss
+            train_metrics (Dict[str, float]):
+                training scores for each desired metric
+            val_metrics (Dict[str, float]):
+                validation scores for each desired metric
+            test_metrics (Dict[str, float]):
+                test scores for each desired metric
+            val_loss (Optional[float], default=None):
+                the validation loss
+            test_loss (Optional[float], default=None):
+                the test loss
+
+        Returns:
+            None
         """
         self.performance_tracker['train_loss'][epoch] = train_loss
         self.performance_tracker['val_loss'][epoch] = val_loss
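A hypothetical per-epoch bookkeeping loop built around the signature documented above. `run_summary` is assumed to be an already constructed RunSummary, and the two small functions are placeholders standing in for the real training and validation calls.

```python
import time


def train_one_epoch(epoch: int):   # placeholder for the real training call
    return 0.7 / epoch, {'accuracy': 0.80}


def validate(epoch: int):          # placeholder for the real evaluation call
    return 0.8 / epoch, {'accuracy': 0.75}


for epoch in range(1, 4):
    start_time = time.time()
    train_loss, train_metrics = train_one_epoch(epoch)
    val_loss, val_metrics = validate(epoch)
    # run_summary is assumed to be a previously constructed RunSummary
    run_summary.add_performance(
        epoch=epoch,
        start_time=start_time,
        end_time=time.time(),
        train_loss=train_loss,
        train_metrics=train_metrics,
        val_metrics=val_metrics,
        test_metrics={},
        val_loss=val_loss,
    )
```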
@@ -134,6 +192,18 @@ def add_performance(self,
         self.performance_tracker['test_metrics'][epoch] = test_metrics

     def get_best_epoch(self, split_type: str = 'val') -> int:
+        """
+        Get the epoch with the best metric.
+
+        Args:
+            split_type (str, default=val):
+                Which split's metric to consider.
+                Possible values are 'train' or 'val
+
+        Returns:
+            int:
+                the epoch with the best metric
+        """
         # If we compute for optimization, prefer the performance
         # metric to the loss
         if self.optimize_metric is not None:
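The newly documented get_best_epoch resolves "best" against the tracked metric for the chosen split. A minimal sketch of that selection over a per-epoch metric dictionary, mirroring the shape of the performance tracker described above (illustrative only, not the library implementation):

```python
# Illustrative selection of the "best" epoch from a per-epoch metric dict.
val_metrics_by_epoch = {1: {'accuracy': 0.71}, 2: {'accuracy': 0.78}, 3: {'accuracy': 0.75}}
best_epoch = max(val_metrics_by_epoch, key=lambda e: val_metrics_by_epoch[e]['accuracy'])
print(best_epoch)  # -> 2
```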
@@ -159,6 +229,13 @@ def get_best_epoch(self, split_type: str = 'val') -> int:
         )) + 1  # Epochs start at 1

     def get_last_epoch(self) -> int:
+        """
+        Get the last epoch.
+
+        Returns:
+            int:
+                the last epoch
+        """
         if 'train_loss' not in self.performance_tracker:
             return 0
         else:
@@ -170,7 +247,8 @@ def repr_last_epoch(self) -> str:
         performance

         Returns:
-            str: A nice representation of the last epoch
+            str:
+                A nice representation of the last epoch
         """
         last_epoch = len(self.performance_tracker['train_loss'])
         string = "\n"
@@ -202,30 +280,43 @@ def is_empty(self) -> bool:
         Checks if the object is empty or not

         Returns:
-            bool
+            bool:
+                True if the object is empty, False otherwise
         """
         # if train_loss is empty, we can be sure that RunSummary is empty.
         return not bool(self.performance_tracker['train_loss'])


 class BaseTrainerComponent(autoPyTorchTrainingComponent):
     """
-    Base class for training
+    Base class for training.
+
     Args:
-        weighted_loss (int, default=0): In case for classification, whether to weight
-            the loss function according to the distribution of classes in the target
-        use_stochastic_weight_averaging (bool, default=True): whether to use stochastic
-            weight averaging. Stochastic weight averaging is a simple average of
-            multiple points(model parameters) along the trajectory of SGD. SWA
-            has been proposed in
+        weighted_loss (int, default=0):
+            In case for classification, whether to weight the loss function according to the distribution of classes
+            in the target
+        use_stochastic_weight_averaging (bool, default=True):
+            whether to use stochastic weight averaging. Stochastic weight averaging is a simple average of
+            multiple points(model parameters) along the trajectory of SGD. SWA has been proposed in
             [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
-        use_snapshot_ensemble (bool, default=True): whether to use snapshot
-            ensemble
-        se_lastk (int, default=3): Number of snapshots of the network to maintain
-        use_lookahead_optimizer (bool, default=True): whether to use lookahead
-            optimizer
-        random_state:
-        **lookahead_config:
+        use_snapshot_ensemble (bool, default=True):
+            whether to use snapshot ensemble
+        se_lastk (int, default=3):
+            Number of snapshots of the network to maintain
+        use_lookahead_optimizer (bool, default=True):
+            whether to use lookahead optimizer
+        random_state (Optional[np.random.RandomState]):
+            Object that contains a seed and allows for reproducible results
+        swa_model (Optional[torch.nn.Module], default=None):
+            Averaged model used for Stochastic Weight Averaging
+        model_snapshots (Optional[List[torch.nn.Module]], default=None):
+            List of model snapshots in case snapshot ensemble is used
+        **lookahead_config (Any):
+            keyword arguments for the lookahead optimizer including:
+            la_steps (int):
+                number of lookahead steps
+            la_alpha (float):
+                linear interpolation factor. 1.0 recovers the inner optimizer.
     """
     def __init__(self, weighted_loss: int = 0,
                  use_stochastic_weight_averaging: bool = True,
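The class docstring above references stochastic weight averaging. Below is a compact sketch of the idea using the standard torch.optim.swa_utils helper; the trainer keeps its own swa_model attribute, so this only illustrates the concept, not the component's code.

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(8, 1)
swa_model = AveragedModel(model)           # running average of the weights
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(10):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    swa_model.update_parameters(model)     # average along the SGD trajectory
```

Averaging several points along the SGD trajectory tends to land in wider optima, which is the motivation cited in the docstring.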
@@ -336,15 +427,21 @@ def prepare(

     def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
-        Optional place holder for AutoPytorch Extensions.
+        Optional placeholder for AutoPytorch Extensions.
+        A user can define what happens on every epoch start or every epoch end.

-        An user can define what happens on every epoch start or every epoch end.
+        Args:
+            X (Dict[str, Any]):
+                Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
+                a components adds relevant information so that further stages can be properly fitted
+            epoch (int):
+                the current epoch
         """
         pass

     def _swa_update(self) -> None:
         """
-        perform swa model update
+        Perform Stochastic Weight Averaging model update
         """
         if self.swa_model is None:
             raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
@@ -354,6 +451,7 @@ def _swa_update(self) -> None:
     def _se_update(self, epoch: int) -> None:
         """
         Add latest model or swa_model to model snapshot ensemble
+
         Args:
             epoch (int):
                 current epoch
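_se_update appends the latest (or SWA) model to the snapshot list, with se_lastk bounding its length. A minimal stand-alone sketch of that bookkeeping; take_snapshot is a hypothetical helper, not a method of the trainer.

```python
import copy

import torch

se_lastk = 3          # as in the Args documented for BaseTrainerComponent
model_snapshots = []  # plays the role of the model_snapshots list


def take_snapshot(model: torch.nn.Module) -> None:
    """Append a deep copy of the model and keep only the newest se_lastk entries."""
    model_snapshots.append(copy.deepcopy(model))
    del model_snapshots[:-se_lastk]


net = torch.nn.Linear(4, 2)
for epoch in range(5):
    take_snapshot(net)
print(len(model_snapshots))  # -> 3
```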
@@ -373,9 +471,16 @@ def _se_update(self, epoch: int) -> None:

     def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         """
-        Optional place holder for AutoPytorch Extensions.
-        An user can define what happens on every epoch start or every epoch end.
-        If returns True, the training is stopped
+        Optional placeholder for AutoPytorch Extensions.
+        A user can define what happens on every epoch start or every epoch end.
+        If returns True, the training is stopped.
+
+        Args:
+            X (Dict[str, Any]):
+                Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
+                a components adds relevant information so that further stages can be properly fitted
+            epoch (int):
+                the current epoch

         """
         if X['is_cyclic_scheduler']:
@@ -421,12 +526,18 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
         Train the model for a single epoch.

         Args:
-            train_loader (torch.utils.data.DataLoader): generator of features/label
-            epoch (int): The current epoch used solely for tracking purposes
+            train_loader (torch.utils.data.DataLoader):
+                generator of features/label
+            epoch (int):
+                The current epoch used solely for tracking purposes
+            writer (Optional[SummaryWriter]):
+                Object to keep track of the training loss in an event file

         Returns:
-            float: training loss
-            Dict[str, float]: scores for each desired metric
+            float:
+                training loss
+            Dict[str, float]:
+                scores for each desired metric
         """

         loss_sum = 0.0
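A rough sketch of an epoch loop matching the documented signature (loader, epoch, optional SummaryWriter): per-batch losses are accumulated and the epoch mean is logged. The model, optimizer, and criterion arguments are stand-ins for illustration, not the component's attributes.

```python
from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter


def train_epoch_sketch(model: torch.nn.Module,
                       loader,
                       optimizer: torch.optim.Optimizer,
                       criterion,
                       epoch: int,
                       writer: Optional[SummaryWriter] = None) -> float:
    model.train()
    loss_sum, n_batches = 0.0, 0
    for data, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(data), targets)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_batches += 1
    mean_loss = loss_sum / max(n_batches, 1)
    if writer is not None:
        # one scalar per epoch, mirroring the "event file" wording above
        writer.add_scalar('Train/loss', mean_loss, epoch)
    return mean_loss
```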
@@ -482,12 +593,16 @@ def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float,
         Allows to train 1 step of gradient descent, given a batch of train/labels

         Args:
-            data (torch.Tensor): input features to the network
-            targets (torch.Tensor): ground truth to calculate loss
+            data (torch.Tensor):
+                input features to the network
+            targets (torch.Tensor):
+                ground truth to calculate loss

         Returns:
-            torch.Tensor: The predictions of the network
-            float: the loss incurred in the prediction
+            torch.Tensor:
+                The predictions of the network
+            float:
+                the loss incurred in the prediction
         """
         # prepare
         data = data.float().to(self.device)
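For the documented train_step contract (predictions plus scalar loss for one batch), here is a minimal single gradient step in plain PyTorch, kept independent of the trainer class.

```python
import torch

model = torch.nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

data = torch.randn(32, 8)                 # batch of input features
targets = torch.randint(0, 3, (32,))      # ground truth labels

optimizer.zero_grad()
outputs = model(data)                     # predictions of the network
loss = criterion(outputs, targets)        # loss incurred in the prediction
loss.backward()
optimizer.step()
print(outputs.shape, loss.item())
```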
@@ -513,12 +628,18 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
         Evaluate the model in both metrics and criterion

         Args:
-            test_loader (torch.utils.data.DataLoader): generator of features/label
-            epoch (int): the current epoch for tracking purposes
+            test_loader (torch.utils.data.DataLoader):
+                generator of features/label
+            epoch (int):
+                the current epoch for tracking purposes
+            writer (Optional[SummaryWriter]):
+                Object to keep track of the test loss in an event file

         Returns:
-            float: test loss
-            Dict[str, float]: scores for each desired metric
+            float:
+                test loss
+            Dict[str, float]:
+                scores for each desired metric
         """
         self.model.eval()

@@ -576,14 +697,15 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n
     def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
                          ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
         """
-        Depending on the trainer choice, data fed to the network might be pre-processed
-        on a different way. That is, in standard training we provide the data to the
-        network as we receive it to the loader. Some regularization techniques, like mixup
-        alter the data.
+        Depending on the trainer choice, data fed to the network might be pre-processed on a different way. That is,
+        in standard training we provide the data to the network as we receive it to the loader. Some regularization
+        techniques, like mixup alter the data.

         Args:
-            X (torch.Tensor): The batch training features
-            y (torch.Tensor): The batch training labels
+            X (torch.Tensor):
+                The batch training features
+            y (torch.Tensor):
+                The batch training labels

         Returns:
             torch.Tensor: that processes data
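Mixup is the concrete example named in the docstring above: the data_preparation hook mixes the batch, and criterion_preparation (documented in the next hunk) weights the loss with the same lambda. A hedged sketch of both halves; mixup_data and mixup_criterion are illustrative names, not the trainer's methods.

```python
import numpy as np
import torch


def mixup_data(X: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """Mix the batch with a random permutation of itself (data_preparation side)."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(X.size(0))
    X_mixed = lam * X + (1 - lam) * X[perm]
    return X_mixed, {'y_a': y, 'y_b': y[perm], 'lam': lam}


def mixup_criterion(y_a: torch.Tensor, y_b: torch.Tensor, lam: float):
    """Return a callable weighting the loss by the same lambda (criterion_preparation side)."""
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```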
@@ -595,16 +717,21 @@ def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
     def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
                               ) -> Callable:  # type: ignore
         """
-        Depending on the trainer choice, the criterion is not directly applied to the
-        traditional y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
+        Depending on the trainer choice, the criterion is not directly applied to the traditional
+        y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
         For example, in the case of mixup training, we need to account for the lambda mixup

         Args:
-            kwargs (Dict): an expanded dictionary with modifiers to the
-                criterion calculation
+            y_a (torch.Tensor):
+                the batch label of the first training example used in trainer
+            y_b (torch.Tensor, default=None):
+                if applicable, the batch label of the second training example used in trainer
+            lam (float):
+                trainer coefficient

         Returns:
-            Callable: a lambda function that contains the new criterion calculation recipe
+            Callable:
+                a lambda function that contains the new criterion calculation recipe
         """
         raise NotImplementedError()
