 import mlflow

 import torch
-from ignite.engine import Engine, Events
+from ignite.engine import Engine, EventEnum, Events
 from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
 from ignite.handlers.tqdm_logger import ProgressBar
 from tensorboardX import SummaryWriter
@@ -103,6 +103,9 @@ def dist_data_loader(
     For multiple splits, we return a dictionary where the keys are the names of the splits
     and the value is either a Dataloader as described above or the value None if the split
     was not configured.
+
+    If an iterable dataset is passed, we cannot create multiple splits with a PyTorch sampler
+    object, so we return the same dataloader for every split, representing the entire iterable.
     """
     # Handle case where no split is needed.
     if isinstance(split, bool):
@@ -118,18 +121,25 @@ def dist_data_loader(
     if seed is not None:
         torch_rng.manual_seed(seed)

-    # Create the indexes for all splits based on config.
-    indexes = create_splits(data_set, config)
-
-    # Create samplers and dataloaders for each split we are interested in
-    samplers = {
-        s: SubsetRandomSampler(indexes[s], generator=torch_rng) if indexes.get(s) else None for s in split
-    }
-
-    dataloaders = {
-        split: idist.auto_dataloader(data_set, sampler=sampler, **config["data_loader"]) if sampler else None
-        for split, sampler in samplers.items()
-    }
+    if data_set.is_iterable():
+        dataloaders = {
+            s: idist.auto_dataloader(data_set, pin_memory=True, **config["data_loader"]) for s in split
+        }
+    else:
+        # Create the indexes for all splits based on config.
+        indexes = create_splits(data_set, config)
+
+        # Create samplers and dataloaders for each split we are interested in
+        samplers = {
+            s: SubsetRandomSampler(indexes[s], generator=torch_rng) if indexes.get(s) else None for s in split
+        }
+
+        dataloaders = {
+            split: idist.auto_dataloader(data_set, sampler=sampler, **config["data_loader"])
+            if sampler
+            else None
+            for split, sampler in samplers.items()
+        }

     # Return only one if we were only passed one split in, return the dictionary otherwise.
     return dataloaders[split[0]] if len(split) == 1 else dataloaders
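
For context, a minimal usage sketch of the split handling in the hunk above. The positional argument order and the config keys are assumptions for illustration, not taken from this commit:

    config = {"data_loader": {"batch_size": 32}}

    # Map-style dataset: each requested split gets its own SubsetRandomSampler-backed loader.
    loaders = dist_data_loader(data_set, config, ["train", "validate"])
    train_loader, validate_loader = loaders["train"], loaders["validate"]

    # Iterable dataset: no sampler can be built, so every split shares a loader over the whole iterable.
    stream_loaders = dist_data_loader(iterable_data_set, config, ["train", "validate"])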
@@ -363,6 +373,7 @@ def create_validator(
     model = idist.auto_model(model)

     validator = create_engine("train_step", device, model)
+    fixup_engine(validator)

     @validator.on(Events.STARTED)
     def set_model_to_eval_mode():
@@ -372,12 +383,12 @@ def set_model_to_eval_mode():
     def set_model_to_train_mode():
         model.train()

-    @validator.on(Events.EPOCH_COMPLETED)
+    @validator.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def log_training_loss():
         logger.debug(f"Validation run time: {validator.state.times['EPOCH_COMPLETED']:.2f} [s]")
         logger.debug(f"Validation metrics: {validator.state.output}")

-    @trainer.on(Events.EPOCH_COMPLETED)
+    @trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def run_validation():
         validator.run(validation_data_loader)

@@ -386,7 +397,7 @@ def log_validation_loss(validator, trainer):
         tensorboardx_logger.add_scalar("training/validation/loss", validator.state.output["loss"], step)
         mlflow.log_metrics({"validation/loss": validator.state.output["loss"]}, step=step)

-    validator.add_event_handler(Events.EPOCH_COMPLETED, log_validation_loss, trainer)
+    validator.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, log_validation_loss, trainer)

     return validator

@@ -419,6 +430,7 @@ def create_trainer(
     model.train()
     model = idist.auto_model(model)
     trainer = create_engine("train_step", device, model)
+    fixup_engine(trainer)

     optimizer = extract_model_method(model, "optimizer")

@@ -435,18 +447,19 @@ def create_trainer(
         to_save,
         DiskSaver(results_directory, require_empty=False),
         n_saved=1,
-        global_step_transform=global_step_from_engine(trainer),
+        global_step_transform=global_step_from_engine(trainer, Events.EPOCH_COMPLETED),
         filename_pattern="{name}_epoch_{global_step}.{ext}",
     )

     def neg_loss_score(engine):
+        print(engine.state)
         return -engine.state.output["loss"]

     best_checkpoint = Checkpoint(
         to_save,
         DiskSaver(results_directory, require_empty=False),
         n_saved=1,
-        global_step_transform=global_step_from_engine(trainer),
+        global_step_transform=global_step_from_engine(trainer, Events.EPOCH_COMPLETED),
         score_name="loss",
         score_function=neg_loss_score,
         greater_or_equal=True,
@@ -473,13 +486,13 @@ def log_training_loss_tensorboard(trainer):
         tensorboardx_logger.add_scalar("training/training/loss", trainer.state.output["loss"], step)
         mlflow.log_metrics({"training/loss": trainer.state.output["loss"]}, step=step)

-    @trainer.on(Events.EPOCH_COMPLETED)
+    @trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def log_training_loss(trainer):
         logger.debug(f"Epoch {trainer.state.epoch} run time: {trainer.state.times['EPOCH_COMPLETED']:.2f} [s]")
         logger.debug(f"Epoch {trainer.state.epoch} metrics: {trainer.state.output}")

-    trainer.add_event_handler(Events.EPOCH_COMPLETED, latest_checkpoint)
-    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint)
+    trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, latest_checkpoint)
+    trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, best_checkpoint)

     @trainer.on(Events.COMPLETED)
     def log_total_time(trainer):
@@ -498,3 +511,38 @@ def log_best_checkpoint_location(_, best_checkpoint):
     pbar.attach(trainer)
     return trainer

+
+
+class HyraxEvents(EventEnum):
+    """
+    Workaround event for a PyTorch Ignite bug. See fixup_engine for details.
+    """
+
+    HYRAX_EPOCH_COMPLETED = "HyraxEpochCompleted"
+
+
+def fixup_engine(engine: Engine) -> Engine:
+    """
+    Workaround for a PyTorch Ignite bug (https://github.com/pytorch/ignite/issues/3372) where
+    engine.state.output is not available at EPOCH_COMPLETED or later (COMPLETED, etc.).
+
+    We create a new event, HYRAX_EPOCH_COMPLETED, which triggers at ITERATION_COMPLETED, but only
+    on the final iteration. This is just before the erroneous state reset.
+
+    This hack relies on PyTorch Ignite internal state, but can be removed as soon as our fix is
+    mainlined (https://github.com/pytorch/ignite/pull/3373) in version 0.6.0, estimated August 2025.
+    """
+    from more_itertools import peekable
+
+    engine.register_events(*HyraxEvents)
+
+    @engine.on(Events.ITERATION_COMPLETED)
+    def maintain_event_handler(engine):
+        # Ensure we have a peekable iterator in the engine.
+        if not hasattr(engine._dataloader_iter, "peek"):
+            # Replace with a pass-through peekable iterator
+            engine._dataloader_iter = peekable(engine._dataloader_iter)
+
+        # An exhausted peekable evaluates as falsy, so this is the final iteration of the epoch.
+        if not engine._dataloader_iter:
+            engine.fire_event(HyraxEvents.HYRAX_EPOCH_COMPLETED)
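
For illustration, a minimal sketch of how this workaround is consumed, mirroring the handler wiring in the hunks above; the engine construction here is illustrative, not part of this commit:

    engine = create_engine("train_step", device, model)
    fixup_engine(engine)

    @engine.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
    def log_epoch_output(engine):
        # Fires on the epoch's final ITERATION_COMPLETED, while engine.state.output is still populated.
        logger.debug(engine.state.output["loss"])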