@@ -1278,6 +1278,47 @@ def check_iter_epoch(first_epoch_iter):
         engine.run(data, max_epochs=10, epoch_length=epoch_length)
         assert engine.state.epoch == 10
         assert engine.state.iteration == 10 * real_epoch_length
+
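+    # Running from a finite generator (no len() available): check state.output and state.epoch.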
+    def test_iterator_state_output(self):
+        torch.manual_seed(12)
+
+        def finite_iterator(length, batch_size):
+            for _ in range(length):
+                batch = torch.rand(batch_size, 3, 32, 32)
+                yield batch
+
+        def train_step(trainer, batch):
+            s = trainer.state
+            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+            return "flag_value"
+
+        trainer = Engine(train_step)
+        trainer.run(finite_iterator(4, 4), max_epochs=2)
+
+        assert trainer.state.output == "flag_value"
+        assert trainer.state.epoch == 2
+        # assert trainer.state.iteration == 2 * 4
+
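+    # Running from a map-style dataset (a list with len()): the final iteration count is deterministic.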
+    def test_map_state_output(self):
+        torch.manual_seed(12)
+
+        batch_size = 4
+        finite_map = [torch.rand(batch_size, 3, 32, 32) for _ in range(4)]
+
+        def train_step(trainer, batch):
+            s = trainer.state
+            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+            return "flag_value"
+
+        trainer = Engine(train_step)
+        trainer.run(finite_map, max_epochs=2)
+
+        assert trainer.state.output == "flag_value"
+        assert trainer.state.epoch == 2
+        assert trainer.state.iteration == 2 * 4
+
 
 
 @pytest.mark.parametrize(