1
- {% extends "base/main.py.jinja" %}
2
- {% block datasets_and_dataloaders %}
3
- train_dataset, eval_dataset = get_datasets(root=config.data_path)
1
+ {% block imports %}
2
+ from argparse import ArgumentParser
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import logging
8
+
9
+ import ignite.distributed as idist
10
+ from ignite.engine import create_supervised_evaluator, create_supervised_trainer
11
+ from ignite.engine.events import Events
12
+ from ignite.utils import setup_logger, manual_seed
13
+ from ignite.metrics import Accuracy, Loss
14
+
15
+ from datasets import get_datasets, get_data_loaders
16
+ from utils import log_metrics, get_default_parser, initialize, setup_common_handlers, setup_exp_logging
17
+ {% endblock %}
18
+
19
+
20
+ {% block run %}
21
+ def run(local_rank: int, config: Any, *args: Any, **kwags: Any):
22
+
23
+ # -----------------------------
24
+ # datasets and dataloaders
25
+ # -----------------------------
26
+ {% block datasets_and_dataloaders %}
27
+ train_dataset, eval_dataset = get_datasets(config.data_path)
4
28
train_dataloader, eval_dataloader = get_data_loaders(
5
29
train_dataset=train_dataset,
6
30
eval_dataset=eval_dataset,
7
31
train_batch_size=config.train_batch_size,
8
32
eval_batch_size=config.eval_batch_size,
9
33
num_workers=config.num_workers,
10
34
)
11
- {% endblock %}
35
+ {% endblock %}
12
36
13
- {% block model_optimizer_loss %}
14
- model = idist.auto_model(get_model(config.model_name))
15
- optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
16
- loss_fn = nn.CrossEntropyLoss()
17
- {% endblock %}
37
+ # ------------------------------------------
38
+ # model, optimizer, loss function, device
39
+ # ------------------------------------------
40
+ {% block model_optimizer_loss %}
41
+ device, model, optimizer, loss_fn = initialize(config)
42
+ {% endblock %}
43
+
44
+ # ----------------------
45
+ # train / eval engine
46
+ # ----------------------
47
+ {% block engines %}
48
+ train_engine = create_supervised_trainer(
49
+ model=model,
50
+ optimizer=optimizer,
51
+ loss_fn=loss_fn,
52
+ device=device,
53
+ output_transform=lambda x, y, y_pred, loss: {'train_loss': loss.item()},
54
+ )
55
+ metrics = {
56
+ 'eval_accuracy': Accuracy(device=device),
57
+ 'eval_loss': Loss(loss_fn=loss_fn, device=device)
58
+ }
59
+ eval_engine = create_supervised_evaluator(
60
+ model=model,
61
+ metrics=metrics,
62
+ device=device,
63
+ )
64
+ {% endblock %}
18
65
19
- {% block metrics %}
20
- Accuracy(device=config.device).attach(eval_engine, "eval_accuracy")
66
+ # ---------------
67
+ # setup logging
68
+ # ---------------
69
+ {% block loggers %}
70
+ name = f"bs{config.train_batch_size}-lr{config.lr}-{optimizer.__class__.__name__}"
71
+ now = datetime.now().strftime("%Y%m%d-%X")
72
+ train_engine.logger = setup_logger("trainer", level=config.verbose, filepath=config.filepath / f"{name}-{now}.log")
73
+ eval_engine.logger = setup_logger("evaluator", level=config.verbose, filepath=config.filepath / f"{name}-{now}.log")
74
+ {% endblock %}
75
+
76
+ # -----------------------------------------
77
+ # checkpoint and common training handlers
78
+ # -----------------------------------------
79
+ {% block eval_ckpt_common_training %}
80
+ eval_ckpt_handler = setup_common_handlers(
81
+ config=config,
82
+ eval_engine=eval_engine,
83
+ train_engine=train_engine,
84
+ model=model,
85
+ optimizer=optimizer
86
+ )
87
+ {% endblock %}
88
+
89
+ # --------------------------------
90
+ # setup common experiment loggers
91
+ # --------------------------------
92
+ {% block exp_loggers %}
93
+ exp_logger = setup_exp_logging(
94
+ config=config,
95
+ eval_engine=eval_engine,
96
+ train_engine=train_engine,
97
+ optimizer=optimizer,
98
+ name=name
99
+ )
100
+ {% endblock %}
101
+
102
+ # ----------------------
103
+ # engines log and run
104
+ # ----------------------
105
+ {% block engines_run_and_log %}
106
+ {% block log_training_results %}
107
+ @train_engine.on(Events.ITERATION_COMPLETED(every=config.log_train))
108
+ def log_training_results(engine):
109
+ train_engine.state.metrics = train_engine.state.output
110
+ log_metrics(train_engine, "Train", device)
111
+ {% endblock %}
112
+
113
+ {% block run_eval_engine_and_log %}
114
+ @train_engine.on(Events.EPOCH_COMPLETED(every=config.log_eval))
115
+ def run_eval_engine_and_log(engine):
116
+ eval_engine.run(
117
+ eval_dataloader,
118
+ max_epochs=config.eval_max_epochs,
119
+ epoch_length=config.eval_epoch_length
120
+ )
121
+ log_metrics(eval_engine, "Eval", device)
122
+ {% endblock %}
123
+
124
+ train_engine.run(
125
+ train_dataloader,
126
+ max_epochs=config.train_max_epochs,
127
+ epoch_length=config.train_epoch_length
128
+ )
129
+ {% endblock %}
21
130
{% endblock %}
22
131
23
132
{% block main_fn %}
24
133
def main():
25
134
parser = ArgumentParser(parents=[get_default_parser()])
26
- parser.add_argument(
27
- "--model_name",
28
- default="{{ model_name }}",
29
- type=str,
30
- help="Image classification model name ({{ model_name}})"
31
- )
32
135
config = parser.parse_args()
33
136
manual_seed(config.seed)
34
137
config.verbose = logging.INFO if config.verbose else logging.WARNING
@@ -46,3 +149,9 @@ def main():
46
149
) as parallel:
47
150
parallel.run(run, config=config)
48
151
{% endblock %}
152
+
153
+
154
+ {% block entrypoint %}
155
+ if __name__ == "__main__":
156
+ main()
157
+ {% endblock %}
0 commit comments