Skip to content

Commit a84bc24

Browse files
committed
Tests found another bug in epoch counting for iterable datasets
1 parent 396af13 commit a84bc24

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/ignite/engine/test_engine.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,45 @@ def check_iter_epoch(first_epoch_iter):
12781278
engine.run(data, max_epochs=10, epoch_length=epoch_length)
12791279
assert engine.state.epoch == 10
12801280
assert engine.state.iteration == 10 * real_epoch_length
1281+
1282+
def test_iterator_state_output(self):
1283+
torch.manual_seed(12)
1284+
1285+
def finite_iterator(length, batch_size):
1286+
for _ in range(length):
1287+
batch = torch.rand(batch_size, 3, 32, 32)
1288+
yield batch
1289+
1290+
def train_step(trainer, batch):
1291+
s = trainer.state
1292+
print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
1293+
return "flag_value"
1294+
1295+
trainer = Engine(train_step)
1296+
trainer.run(finite_iterator(4,4), max_epochs=2)
1297+
1298+
assert trainer.state.output == "flag_value"
1299+
assert trainer.state.epoch == 2
1300+
#assert trainer.state.iteration == 2*4
1301+
1302+
def test_map_state_output(self):
1303+
torch.manual_seed(12)
1304+
1305+
batch_size = 4
1306+
finite_map = [torch.rand(batch_size, 3, 32,32) for _ in range(4)]
1307+
1308+
def train_step(trainer, batch):
1309+
s = trainer.state
1310+
print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
1311+
return "flag_value"
1312+
1313+
trainer = Engine(train_step)
1314+
trainer.run(finite_map, max_epochs=2)
1315+
1316+
assert trainer.state.output == "flag_value"
1317+
assert trainer.state.epoch == 2
1318+
assert trainer.state.iteration == 2*4
1319+
12811320

12821321

12831322
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)