diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index f3f95c9a2e27..c814f770e77a 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -1069,7 +1069,7 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
         )

         while True:
-            self.state.batch = self.state.output = None
+            self.state.batch = None
             try:
                 # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                 if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1081,6 +1081,9 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                     yield from self._maybe_terminate_or_interrupt()

                 self.state.batch = next(self._dataloader_iter)
+                # We reset state.output here on purpose: with iterable dataloaders
+                # it can otherwise be accidentally cleared once an epoch is completed.
+                self.state.output = None

                 # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                 # if no data was provided to engine.run(data=None, ...)
@@ -1254,7 +1257,7 @@ def _run_once_on_dataset_legacy(self) -> float:
         )

         while True:
-            self.state.batch = self.state.output = None
+            self.state.batch = None
             try:
                 # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                 if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1265,6 +1268,10 @@ def _run_once_on_dataset_legacy(self) -> float:
                     self._maybe_terminate_legacy()

                 self.state.batch = next(self._dataloader_iter)
+                # We reset state.output here on purpose: with iterable dataloaders
+                # it can otherwise be accidentally cleared once an epoch is completed.
+                self.state.output = None
+
                 # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                 # if no data was provided to engine.run(data=None, ...)
                 if self.state.dataloader is not None:
diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py
index 76e1ad837605..8748e03f574c 100644
--- a/tests/ignite/engine/test_engine.py
+++ b/tests/ignite/engine/test_engine.py
@@ -1279,6 +1279,44 @@ def check_iter_epoch(first_epoch_iter):
         assert engine.state.epoch == 10
         assert engine.state.iteration == 10 * real_epoch_length

+    def test_iterator_state_output(self):
+        torch.manual_seed(12)
+
+        def finite_iterator(length, batch_size):
+            for _ in range(length):
+                batch = torch.rand(batch_size, 3, 32, 32)
+                yield batch
+
+        def train_step(trainer, batch):
+            s = trainer.state
+            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+            return "flag_value"
+
+        trainer = Engine(train_step)
+        trainer.run(finite_iterator(4, 4), max_epochs=2)
+
+        assert trainer.state.output == "flag_value"
+        assert trainer.state.epoch == 2
+        # assert trainer.state.iteration == 2*4
+
+    def test_map_state_output(self):
+        torch.manual_seed(12)
+
+        batch_size = 4
+        finite_map = [torch.rand(batch_size, 3, 32, 32) for _ in range(4)]
+
+        def train_step(trainer, batch):
+            s = trainer.state
+            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+            return "flag_value"
+
+        trainer = Engine(train_step)
+        trainer.run(finite_map, max_epochs=2)
+
+        assert trainer.state.output == "flag_value"
+        assert trainer.state.epoch == 2
+        assert trainer.state.iteration == 2 * 4
+
     @pytest.mark.parametrize(
         "interrupt_event, e, i",