
Commit 04787b1

Committed by: ydcjeff
fix: all in 1 deps install, amp option fix #32
1 parent 11f5212 commit 04787b1

File tree

7 files changed: +74 -49 lines changed


requirements-dev.txt (+2)

@@ -1,3 +1,5 @@
+-r requirements.txt
+
 # dev
 pytorch-ignite
 torch
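With -r requirements.txt chained into requirements-dev.txt, a single command such as pip install -r requirements-dev.txt now pulls in both the runtime and the development dependencies, which is the "all in 1 deps install" part of the commit title.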

templates/_base/_argparse.pyi (+4)

@@ -4,6 +4,10 @@ from argparse import ArgumentParser
 
 {% block defaults %}
 DEFAULTS = {
+    "use_amp": {
+        "action": "store_true",
+        "help": "use torch.cuda.amp for automatic mixed precision"
+    },
     "seed": {
         "default": 666,
         "type": int,

templates/_base/_handlers.pyi (+3 -1)

@@ -157,6 +157,7 @@ def get_logger(
         Config object for setting up loggers
 
     `config` has to contain
+    - `filepath`: logging path to output file
     - `logger_log_every_iters`: logging iteration interval for loggers
 
     train_engine
@@ -232,5 +233,6 @@ def get_logger(
             **kwargs,
         )
     {% else %}
-    return None
+    logger_handler = None
     {% endif %}
+    return logger_handler

templates/single/_sidebar.py (+1 -1)

@@ -19,7 +19,7 @@ def get_configs():
     default_none_options(config)
 
     with st.beta_expander("Single Model, Single Optimizer Template Configurations", expanded=True):
-        st.info("Those in the parenthesis are used in the generated code.")
+        st.info("Names in the parenthesis are variable names used in the generated code.")
 
         # group by configurations type
         distributed_options(config)

templates/single/single_cg/engines.pyi (+20 -16)

@@ -2,22 +2,24 @@
 `train_engine` and `eval_engine` like trainer and evaluator
 """
 from typing import Any, Tuple
+
 import torch
 from ignite.engine import Engine
+from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer
 
 from single_cg.events import TrainEvents, train_events_to_attr
 
 
 # Edit below functions the way how the model will be training
 
-# train_fn is how the model will be learning with given batch
-# below in the train_fn, common parameters are provided
+# train_function is how the model will be learning with given batch
+# below in the train_function, common parameters are provided
 # you can add any additional parameters depending on the training
 # NOTE : engine and batch parameters are needed to work with
 # Ignite's Engine.
 # TODO: Extend with your custom training.
-def train_fn(
+def train_function(
     config: Any,
     engine: Engine,
     batch: Any,
@@ -26,7 +28,7 @@ def train_fn(
     optimizer: Optimizer,
     device: torch.device,
 ):
-    """Training.
+    """Model training step.
 
     Parameters
     ----------
@@ -48,8 +50,9 @@ def train_fn(
     samples = batch[0].to(device, non_blocking=True)
     targets = batch[1].to(device, non_blocking=True)
 
-    outputs = model(samples)
-    loss = loss_fn(outputs, targets)
+    with autocast(enabled=config.use_amp):
+        outputs = model(samples)
+        loss = loss_fn(outputs, targets)
 
     loss.backward()
     engine.state.backward_completed += 1
@@ -66,22 +69,22 @@ def train_fn(
     return loss_value
 
 
-# evaluate_fn is how the model will be learning with given batch
-# below in the evaluate_fn, common parameters are provided
+# evaluate_function is how the model will be learning with given batch
+# below in the evaluate_function, common parameters are provided
 # you can add any additional parameters depending on the training
 # NOTE : engine and batch parameters are needed to work with
 # Ignite's Engine.
 # TODO: Extend with your custom evaluation.
 @torch.no_grad()
-def evaluate_fn(
+def evaluate_function(
     config: Any,
     engine: Engine,
     batch: Any,
     model: torch.nn.Module,
     loss_fn: torch.nn.Module,
     device: torch.device,
 ):
-    """Evaluating.
+    """Model evaluating step.
 
     Parameters
     ----------
@@ -102,10 +105,11 @@ def evaluate_fn(
     samples = batch[0].to(device, non_blocking=True)
     targets = batch[1].to(device, non_blocking=True)
 
-    outputs = model(samples)
-    loss = loss_fn(outputs, targets)
-    loss_value = loss.item()
+    with autocast(enabled=config.use_amp):
+        outputs = model(samples)
+        loss = loss_fn(outputs, targets)
 
+    loss_value = loss.item()
     engine.state.metrics = {"eval_loss": loss_value}
     return loss_value
 
@@ -117,21 +121,21 @@ def create_engines(**kwargs) -> Tuple[Engine, Engine]:
 
     Parameters
     ----------
-    kwargs: keyword arguments passed to both train_fn and evaluate_fn
+    kwargs: keyword arguments passed to both train_function and evaluate_function
 
     Returns
     -------
     train_engine, eval_engine
     """
     train_engine = Engine(
-        lambda e, b: train_fn(
+        lambda e, b: train_function(
             engine=e,
             batch=b,
             **kwargs,
         )
     )
     eval_engine = Engine(
-        lambda e, b: evaluate_fn(
+        lambda e, b: evaluate_function(
             engine=e,
             batch=b,
             **kwargs,
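The forward pass and loss computation are now wrapped in torch.cuda.amp.autocast, toggled by the new use_amp option; with enabled=False the context manager is a no-op, so the default behaviour is unchanged. A minimal standalone sketch of the pattern follows; the model, data, and use_amp switch are illustrative, not taken from the template:

import torch
from torch.cuda.amp import autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 2).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
samples = torch.randn(8, 4, device=device)
targets = torch.randint(0, 2, (8,), device=device)

use_amp = device.type == "cuda"  # mixed precision only has an effect on CUDA
with autocast(enabled=use_amp):  # no-op when enabled=False
    outputs = model(samples)
    loss = loss_fn(outputs, targets)
loss.backward()

Full mixed-precision training is usually paired with a gradient scaler as well; this commit only adds the autocast context controlled by the flag.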

templates/single/tests/test_engines.py (+19 -7)

@@ -1,11 +1,12 @@
 import unittest
+from argparse import Namespace
 from numbers import Number
 from unittest.mock import MagicMock
 
 import ignite.distributed as idist
 import torch
 from ignite.engine.engine import Engine
-from single_cg.engines import create_engines, evaluate_fn, train_fn
+from single_cg.engines import create_engines, evaluate_function, train_function
 from single_cg.events import TrainEvents, train_events_to_attr
 from torch import nn, optim
 
@@ -27,7 +28,8 @@ def test_train_fn(self):
         optim = MagicMock()
         engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED, backward)
         engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED, optim)
-        output = train_fn(None, engine, self.batch, self.model, self.loss_fn, self.optimizer, self.device)
+        config = Namespace(use_amp=False)
+        output = train_function(config, engine, self.batch, self.model, self.loss_fn, self.optimizer, self.device)
         self.assertIsInstance(output, Number)
         self.assertTrue(hasattr(engine.state, "backward_completed"))
         self.assertTrue(hasattr(engine.state, "optim_step_completed"))
@@ -39,7 +41,10 @@ def test_train_fn(self):
         self.assertTrue(optim.called)
 
     def test_train_fn_event_filter(self):
-        engine = Engine(lambda e, b: train_fn(None, e, b, self.model, self.loss_fn, self.optimizer, self.device))
+        config = Namespace(use_amp=False)
+        engine = Engine(
+            lambda e, b: train_function(config, e, b, self.model, self.loss_fn, self.optimizer, self.device)
+        )
         engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
         backward = MagicMock()
         optim = MagicMock()
@@ -60,7 +65,10 @@ def test_train_fn_event_filter(self):
         self.assertTrue(optim.called)
 
     def test_train_fn_every(self):
-        engine = Engine(lambda e, b: train_fn(None, e, b, self.model, self.loss_fn, self.optimizer, self.device))
+        config = Namespace(use_amp=False)
+        engine = Engine(
+            lambda e, b: train_function(config, e, b, self.model, self.loss_fn, self.optimizer, self.device)
+        )
         engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
         backward = MagicMock()
         optim = MagicMock()
@@ -77,7 +85,10 @@ def test_train_fn_every(self):
         self.assertTrue(optim.called)
 
     def test_train_fn_once(self):
-        engine = Engine(lambda e, b: train_fn(None, e, b, self.model, self.loss_fn, self.optimizer, self.device))
+        config = Namespace(use_amp=False)
+        engine = Engine(
+            lambda e, b: train_function(config, e, b, self.model, self.loss_fn, self.optimizer, self.device)
+        )
         engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
         backward = MagicMock()
         optim = MagicMock()
@@ -95,12 +106,13 @@ def test_train_fn_once(self):
 
     def test_evaluate_fn(self):
         engine = Engine(lambda e, b: 1)
-        output = evaluate_fn(None, engine, self.batch, self.model, self.loss_fn, self.device)
+        config = Namespace(use_amp=False)
+        output = evaluate_function(config, engine, self.batch, self.model, self.loss_fn, self.device)
         self.assertIsInstance(output, Number)
 
     def test_create_engines(self):
         train_engine, eval_engine = create_engines(
-            config=None,
+            config=Namespace(use_amp=True),
             model=self.model,
             loss_fn=self.loss_fn,
             optimizer=self.optimizer,
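The tests now pass a lightweight argparse.Namespace instead of None so that config.use_amp resolves inside train_function and evaluate_function; Namespace works as a stand-in for parsed CLI arguments, for example:

from argparse import Namespace

# minimal config stub with only the attribute the engine functions read
config = Namespace(use_amp=False)
assert config.use_amp is False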

templates/single/tests/test_handlers.py (+25 -24)

@@ -52,30 +52,31 @@ def test_get_handlers(self):
         self.assertIsInstance(timer_handler, (type(None), Timer), "Shoulde be Timer or None")
 
     def test_get_logger(self):
-        config = Namespace(logger_log_every_iters=1)
-        train_engine = Engine(lambda e, b: b)
-        optimizer = optim.Adam(nn.Linear(1, 1).parameters())
-        logger_handler = get_logger(
-            config=config,
-            train_engine=train_engine,
-            eval_engine=train_engine,
-            optimizers=optimizer,
-        )
-        self.assertIsInstance(
-            logger_handler,
-            (
-                BaseLogger,
-                ClearMLLogger,
-                MLflowLogger,
-                NeptuneLogger,
-                PolyaxonLogger,
-                TensorboardLogger,
-                VisdomLogger,
-                WandBLogger,
-                type(None),
-            ),
-            "Should be Ignite provided loggers or None",
-        )
+        with TemporaryDirectory() as tmp:
+            config = Namespace(filepath=tmp, logger_log_every_iters=1)
+            train_engine = Engine(lambda e, b: b)
+            optimizer = optim.Adam(nn.Linear(1, 1).parameters())
+            logger_handler = get_logger(
+                config=config,
+                train_engine=train_engine,
+                eval_engine=train_engine,
+                optimizers=optimizer,
+            )
+            self.assertIsInstance(
+                logger_handler,
+                (
+                    BaseLogger,
+                    ClearMLLogger,
+                    MLflowLogger,
+                    NeptuneLogger,
+                    PolyaxonLogger,
+                    TensorboardLogger,
+                    VisdomLogger,
+                    WandBLogger,
+                    type(None),
+                ),
+                "Should be Ignite provided loggers or None",
+            )
 
 
 if __name__ == "__main__":
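The rewritten test now supplies the filepath key that get_logger documents above, pointing it at a throwaway directory. A small sketch of the tempfile pattern used here, assuming TemporaryDirectory is imported in the test module (the import itself is not part of this hunk):

from argparse import Namespace
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    # tmp is a real directory path; it is removed when the block exits
    config = Namespace(filepath=tmp, logger_log_every_iters=1)
    print(config.filepath)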
