1
1
import unittest
2
+ from argparse import Namespace
2
3
from numbers import Number
3
4
from unittest .mock import MagicMock
4
5
5
6
import ignite .distributed as idist
6
7
import torch
7
8
from ignite .engine .engine import Engine
8
- from single_cg .engines import create_engines , evaluate_fn , train_fn
9
+ from single_cg .engines import create_engines , evaluate_function , train_function
9
10
from single_cg .events import TrainEvents , train_events_to_attr
10
11
from torch import nn , optim
11
12
@@ -27,7 +28,8 @@ def test_train_fn(self):
27
28
optim = MagicMock ()
28
29
engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED , backward )
29
30
engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED , optim )
30
- output = train_fn (None , engine , self .batch , self .model , self .loss_fn , self .optimizer , self .device )
31
+ config = Namespace (use_amp = False )
32
+ output = train_function (config , engine , self .batch , self .model , self .loss_fn , self .optimizer , self .device )
31
33
self .assertIsInstance (output , Number )
32
34
self .assertTrue (hasattr (engine .state , "backward_completed" ))
33
35
self .assertTrue (hasattr (engine .state , "optim_step_completed" ))
@@ -39,7 +41,10 @@ def test_train_fn(self):
39
41
self .assertTrue (optim .called )
40
42
41
43
def test_train_fn_event_filter (self ):
42
- engine = Engine (lambda e , b : train_fn (None , e , b , self .model , self .loss_fn , self .optimizer , self .device ))
44
+ config = Namespace (use_amp = False )
45
+ engine = Engine (
46
+ lambda e , b : train_function (config , e , b , self .model , self .loss_fn , self .optimizer , self .device )
47
+ )
43
48
engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
44
49
backward = MagicMock ()
45
50
optim = MagicMock ()
@@ -60,7 +65,10 @@ def test_train_fn_event_filter(self):
60
65
self .assertTrue (optim .called )
61
66
62
67
def test_train_fn_every (self ):
63
- engine = Engine (lambda e , b : train_fn (None , e , b , self .model , self .loss_fn , self .optimizer , self .device ))
68
+ config = Namespace (use_amp = False )
69
+ engine = Engine (
70
+ lambda e , b : train_function (config , e , b , self .model , self .loss_fn , self .optimizer , self .device )
71
+ )
64
72
engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
65
73
backward = MagicMock ()
66
74
optim = MagicMock ()
@@ -77,7 +85,10 @@ def test_train_fn_every(self):
77
85
self .assertTrue (optim .called )
78
86
79
87
def test_train_fn_once (self ):
80
- engine = Engine (lambda e , b : train_fn (None , e , b , self .model , self .loss_fn , self .optimizer , self .device ))
88
+ config = Namespace (use_amp = False )
89
+ engine = Engine (
90
+ lambda e , b : train_function (config , e , b , self .model , self .loss_fn , self .optimizer , self .device )
91
+ )
81
92
engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
82
93
backward = MagicMock ()
83
94
optim = MagicMock ()
@@ -95,12 +106,13 @@ def test_train_fn_once(self):
95
106
96
107
def test_evaluate_fn (self ):
97
108
engine = Engine (lambda e , b : 1 )
98
- output = evaluate_fn (None , engine , self .batch , self .model , self .loss_fn , self .device )
109
+ config = Namespace (use_amp = False )
110
+ output = evaluate_function (config , engine , self .batch , self .model , self .loss_fn , self .device )
99
111
self .assertIsInstance (output , Number )
100
112
101
113
def test_create_engines (self ):
102
114
train_engine , eval_engine = create_engines (
103
- config = None ,
115
+ config = Namespace ( use_amp = True ) ,
104
116
model = self .model ,
105
117
loss_fn = self .loss_fn ,
106
118
optimizer = self .optimizer ,
0 commit comments