@@ -60,10 +60,15 @@ def model_filename(self) -> str:
6060 """ str: The full model filename """
6161 return os .path .join (self ._model_dir , self ._model_name )
6262
63+ @property
64+ def have_session_data (self ) -> bool :
65+ """ bool : ``True`` if session data is available otherwise ``False`` """
66+ return bool (self ._state and self ._state ["sessions" ])
67+
6368 @property
6469 def batch_sizes (self ) -> dict [int , int ]:
6570 """ dict: The batch sizes for each session_id for the model. """
66- if not self ._state :
71+ if not self .have_session_data :
6772 return {}
6873 return {int (sess_id ): sess ["batchsize" ]
6974 for sess_id , sess in self ._state .get ("sessions" , {}).items ()}
@@ -76,9 +81,9 @@ def full_summary(self) -> list[dict]:
7681
7782 @property
7883 def logging_disabled (self ) -> bool :
79- """ bool: ``True`` if logging is enabled for the currently training session otherwise
84+ """ bool: ``True`` if logging is disabled for the currently training session otherwise
8085 ``False``. """
81- if not self ._state :
86+ if not self .have_session_data :
8287 return True
8388 max_id = str (max (int (idx ) for idx in self ._state ["sessions" ]))
8489 return self ._state ["sessions" ][max_id ]["no_logs" ]
@@ -311,6 +316,10 @@ def get_summary_stats(self) -> list[dict]:
311316 within the loaded data as well as the totals.
312317 """
313318 logger .debug ("Compiling sessions summary data" )
319+ if not self ._session .have_session_data :
320+ logger .debug ("Session data doesn't exist. Most likely task has been "
321+ "terminated during compilation, or is from LR finder" )
322+ return []
314323 self ._get_time_stats ()
315324 self ._get_per_session_stats ()
316325 if not self ._per_session_stats :
@@ -365,9 +374,9 @@ def _get_per_session_stats(self) -> None:
365374 compiled = []
366375 for session_id in self ._time_stats :
367376 logger .debug ("Compiling session ID: %s" , session_id )
368- if not self ._state :
369- logger .debug ("Session state dict doesn't exist. Most likely task has been "
370- "terminated during compilation" )
377+ if not self ._session . have_session_data :
378+ logger .debug ("Session data doesn't exist. Most likely task has been "
379+ "terminated during compilation, or is from LR finder " )
371380 return
372381 compiled .append (self ._collate_stats (session_id ))
373382
@@ -435,6 +444,8 @@ def _total_stats(self) -> dict[str, str | int | float]:
435444 iterations for all session ids within the loaded data.
436445 """
437446 logger .debug ("Compiling Totals" )
447+ starttime = 0.0
448+ endtime = 0.0
438449 elapsed = 0
439450 examples = 0
440451 iterations = 0
@@ -450,13 +461,14 @@ def _total_stats(self) -> dict[str, str | int | float]:
450461 batchset .add (summary ["batch" ])
451462 iterations += summary ["iterations" ]
452463 batch = "," .join (str (bs ) for bs in batchset )
453- totals = {"session" : "Total" ,
454- "start" : starttime ,
455- "end" : endtime ,
456- "elapsed" : elapsed ,
457- "rate" : examples / elapsed if elapsed != 0 else 0 ,
458- "batch" : batch ,
459- "iterations" : iterations }
464+ totals : dict [str , str | int | float ] = {
465+ "session" : "Total" ,
466+ "start" : starttime ,
467+ "end" : endtime ,
468+ "elapsed" : elapsed ,
469+ "rate" : examples / elapsed if elapsed != 0 else 0 ,
470+ "batch" : batch ,
471+ "iterations" : iterations }
460472 logger .debug (totals )
461473 return totals
462474
@@ -533,7 +545,7 @@ class Calculations():
533545 ``True`` if values significantly away from the average should be excluded, otherwise
534546 ``False``. Default: ``False``
535547 """
536- def __init__ (self , session_id ,
548+ def __init__ (self , session_id , # pylint:disable=too-many-positional-arguments
537549 display : str = "loss" ,
538550 loss_keys : list [str ] | str = "loss" ,
539551 selections : list [str ] | str = "raw" ,
0 commit comments