Skip to content

Commit 5122bff

Browse files
Author: Jeff Yang — refactor: explicit template for image classification (#17)
1 parent f679dc3 · commit 5122bff

15 files changed: +495 −157 lines changed

.github/workflows/ci.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ jobs:
3838
restore-keys: |
3939
${{ steps.get-date.outputs.date }}-${{ runner.os }}-${{ matrix.python-version }}-
4040
41-
- run: pip install --pre -r requirements-dev.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html --progress-bar off
41+
- run: pip install -r requirements.txt --progress-bar off
42+
- run: pip install -r requirements-dev.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html --progress-bar off
4243
- run: python -m torch.utils.collect_env
4344
- run: bash .github/run_test.sh generate
4445
- run: bash .github/run_test.sh unittest

app/streamlit_app.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import shutil
2-
from pathlib import Path
32
from datetime import datetime
3+
from pathlib import Path
44

55
import streamlit as st
66
from codegen import CodeGenerator
@@ -51,7 +51,7 @@ def render_code(self, fname="", code="", fold=False):
5151

5252
def add_sidebar(self):
5353
def config(template_name):
54-
return import_from_file("template_config", f"./templates/{template_name}/{template_name}_config.py")
54+
return import_from_file("template_config", f"./templates/{template_name}/sidebar.py")
5555

5656
self.sidebar(self.codegen.template_list, config)
5757

templates/base/base_config.py renamed to templates/base/sidebar.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,6 @@
33
import streamlit as st
44

55
params = {
6-
"amp_mode": {
7-
"app": ["None", "amp", "apex"],
8-
"test": ["None", "amp", "apex"],
9-
},
10-
"device": {
11-
"app": ["cpu", "cuda", "xla"],
12-
"test": ["cpu", "cuda"],
13-
},
146
"data_path": {
157
"app": {"value": "./"},
168
"test": {"prefix": "tmp", "suffix": ""},
@@ -20,11 +12,11 @@
2012
"test": {"prefix": "tmp", "suffix": ""},
2113
},
2214
"train_batch_size": {
23-
"app": {"min_value": 1, "value": 1},
15+
"app": {"min_value": 1, "value": 4},
2416
"test": {"min_value": 1, "max_value": 2},
2517
},
2618
"eval_batch_size": {
27-
"app": {"min_value": 1, "value": 1},
19+
"app": {"min_value": 1, "value": 4},
2820
"test": {"min_value": 1, "max_value": 2},
2921
},
3022
"num_workers": {
@@ -96,9 +88,6 @@ def get_configs() -> dict:
9688
st.info("Common base training configurations. Those in the parenthesis are used in the code.")
9789

9890
# group by streamlit function type
99-
config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", params.amp_mode.app)
100-
config["device"] = st.selectbox("Device to use (device)", params.device.app)
101-
10291
config["data_path"] = st.text_input("Dataset path (data_path)", **params.data_path.app)
10392
config["filepath"] = st.text_input("Logging file path (filepath)", **params.filepath.app)
10493

templates/base/utils.py.jinja

-10
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@ import torch
88

99
{% block get_default_parser %}
1010
DEFAULTS = {
11-
"amp_mode": {
12-
"default": "{{ amp_mode }}",
13-
"type": str,
14-
"help": "automatic mixed precision mode to use: `amp` or `apex` ({{ amp_mode }})",
15-
},
1611
"train_batch_size": {
1712
"default": {{ train_batch_size }},
1813
"type": int,
@@ -23,11 +18,6 @@ DEFAULTS = {
2318
"type": str,
2419
"help": "datasets path ({{ data_path }})",
2520
},
26-
"device": {
27-
"default": "{{ device }}",
28-
"type": torch.device,
29-
"help": "device to use for training / evaluation / testing ({{ device }})",
30-
},
3121
"filepath": {
3222
"default": "{{ filepath }}",
3323
"type": str,

templates/image_classification/fn.py.jinja

-1
This file was deleted.

templates/image_classification/generate_metadata.py

-30
This file was deleted.

templates/image_classification/image_classification_config.py

-19
This file was deleted.
+126-17
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,137 @@
1-
{% extends "base/main.py.jinja" %}
2-
{% block datasets_and_dataloaders %}
3-
train_dataset, eval_dataset = get_datasets(root=config.data_path)
1+
{% block imports %}
2+
from argparse import ArgumentParser
3+
from datetime import datetime
4+
from pathlib import Path
5+
from typing import Any
6+
7+
import logging
8+
9+
import ignite.distributed as idist
10+
from ignite.engine import create_supervised_evaluator, create_supervised_trainer
11+
from ignite.engine.events import Events
12+
from ignite.utils import setup_logger, manual_seed
13+
from ignite.metrics import Accuracy, Loss
14+
15+
from datasets import get_datasets, get_data_loaders
16+
from utils import log_metrics, get_default_parser, initialize, setup_common_handlers, setup_exp_logging
17+
{% endblock %}
18+
19+
20+
{% block run %}
21+
def run(local_rank: int, config: Any, *args: Any, **kwags: Any):
22+
23+
# -----------------------------
24+
# datasets and dataloaders
25+
# -----------------------------
26+
{% block datasets_and_dataloaders %}
27+
train_dataset, eval_dataset = get_datasets(config.data_path)
428
train_dataloader, eval_dataloader = get_data_loaders(
529
train_dataset=train_dataset,
630
eval_dataset=eval_dataset,
731
train_batch_size=config.train_batch_size,
832
eval_batch_size=config.eval_batch_size,
933
num_workers=config.num_workers,
1034
)
11-
{% endblock %}
35+
{% endblock %}
1236

13-
{% block model_optimizer_loss %}
14-
model = idist.auto_model(get_model(config.model_name))
15-
optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
16-
loss_fn = nn.CrossEntropyLoss()
17-
{% endblock %}
37+
# ------------------------------------------
38+
# model, optimizer, loss function, device
39+
# ------------------------------------------
40+
{% block model_optimizer_loss %}
41+
device, model, optimizer, loss_fn = initialize(config)
42+
{% endblock %}
43+
44+
# ----------------------
45+
# train / eval engine
46+
# ----------------------
47+
{% block engines %}
48+
train_engine = create_supervised_trainer(
49+
model=model,
50+
optimizer=optimizer,
51+
loss_fn=loss_fn,
52+
device=device,
53+
output_transform=lambda x, y, y_pred, loss: {'train_loss': loss.item()},
54+
)
55+
metrics = {
56+
'eval_accuracy': Accuracy(device=device),
57+
'eval_loss': Loss(loss_fn=loss_fn, device=device)
58+
}
59+
eval_engine = create_supervised_evaluator(
60+
model=model,
61+
metrics=metrics,
62+
device=device,
63+
)
64+
{% endblock %}
1865

19-
{% block metrics %}
20-
Accuracy(device=config.device).attach(eval_engine, "eval_accuracy")
66+
# ---------------
67+
# setup logging
68+
# ---------------
69+
{% block loggers %}
70+
name = f"bs{config.train_batch_size}-lr{config.lr}-{optimizer.__class__.__name__}"
71+
now = datetime.now().strftime("%Y%m%d-%X")
72+
train_engine.logger = setup_logger("trainer", level=config.verbose, filepath=config.filepath / f"{name}-{now}.log")
73+
eval_engine.logger = setup_logger("evaluator", level=config.verbose, filepath=config.filepath / f"{name}-{now}.log")
74+
{% endblock %}
75+
76+
# -----------------------------------------
77+
# checkpoint and common training handlers
78+
# -----------------------------------------
79+
{% block eval_ckpt_common_training %}
80+
eval_ckpt_handler = setup_common_handlers(
81+
config=config,
82+
eval_engine=eval_engine,
83+
train_engine=train_engine,
84+
model=model,
85+
optimizer=optimizer
86+
)
87+
{% endblock %}
88+
89+
# --------------------------------
90+
# setup common experiment loggers
91+
# --------------------------------
92+
{% block exp_loggers %}
93+
exp_logger = setup_exp_logging(
94+
config=config,
95+
eval_engine=eval_engine,
96+
train_engine=train_engine,
97+
optimizer=optimizer,
98+
name=name
99+
)
100+
{% endblock %}
101+
102+
# ----------------------
103+
# engines log and run
104+
# ----------------------
105+
{% block engines_run_and_log %}
106+
{% block log_training_results %}
107+
@train_engine.on(Events.ITERATION_COMPLETED(every=config.log_train))
108+
def log_training_results(engine):
109+
train_engine.state.metrics = train_engine.state.output
110+
log_metrics(train_engine, "Train", device)
111+
{% endblock %}
112+
113+
{% block run_eval_engine_and_log %}
114+
@train_engine.on(Events.EPOCH_COMPLETED(every=config.log_eval))
115+
def run_eval_engine_and_log(engine):
116+
eval_engine.run(
117+
eval_dataloader,
118+
max_epochs=config.eval_max_epochs,
119+
epoch_length=config.eval_epoch_length
120+
)
121+
log_metrics(eval_engine, "Eval", device)
122+
{% endblock %}
123+
124+
train_engine.run(
125+
train_dataloader,
126+
max_epochs=config.train_max_epochs,
127+
epoch_length=config.train_epoch_length
128+
)
129+
{% endblock %}
21130
{% endblock %}
22131

23132
{% block main_fn %}
24133
def main():
25134
parser = ArgumentParser(parents=[get_default_parser()])
26-
parser.add_argument(
27-
"--model_name",
28-
default="{{ model_name }}",
29-
type=str,
30-
help="Image classification model name ({{ model_name}})"
31-
)
32135
config = parser.parse_args()
33136
manual_seed(config.seed)
34137
config.verbose = logging.INFO if config.verbose else logging.WARNING
@@ -46,3 +149,9 @@ def main():
46149
) as parallel:
47150
parallel.run(run, config=config)
48151
{% endblock %}
152+
153+
154+
{% block entrypoint %}
155+
if __name__ == "__main__":
156+
main()
157+
{% endblock %}

templates/image_classification/metadata.json

-51
This file was deleted.

Comments (0)