6
6
7
7
import gzip
8
8
import json
9
+ import logging
9
10
import time
10
11
import warnings
11
12
from collections .abc import Iterable
17
18
from matplotlib import colors as mcolors
18
19
from pytorch3d .implicitron .tools .vis_utils import get_visdom_connection
19
20
21
+ logger = logging .getLogger (__name__ )
22
+
20
23
21
24
class AverageMeter (object ):
22
25
"""Computes and stores the average and current value"""
@@ -91,7 +94,9 @@ class Stats(object):
91
94
# stats.update() automatically parses the 'objective' and 'top1e' from
92
95
# the "output" dict and stores this into the db
93
96
stats.update(output)
94
- stats.print() # prints the averages over given epoch
97
+ # prints the metric averages over given epoch
98
+ std_out = stats.get_status_string()
99
+ logger.info(str_out)
95
100
# stores the training plots into '/tmp/epoch_stats.pdf'
96
101
# and plots into a visdom server running at localhost (if running)
97
102
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
@@ -101,7 +106,6 @@ class Stats(object):
101
106
def __init__ (
102
107
self ,
103
108
log_vars ,
104
- verbose = False ,
105
109
epoch = - 1 ,
106
110
visdom_env = "main" ,
107
111
do_plot = True ,
@@ -110,7 +114,6 @@ def __init__(
110
114
visdom_port = 8097 ,
111
115
):
112
116
113
- self .verbose = verbose
114
117
self .log_vars = log_vars
115
118
self .visdom_env = visdom_env
116
119
self .visdom_server = visdom_server
@@ -156,32 +159,29 @@ def __exit__(self, type, value, traceback):
156
159
iserr = type is not None and issubclass (type , Exception )
157
160
iserr = iserr or (type is KeyboardInterrupt )
158
161
if iserr :
159
- print ("error inside 'with' block" )
162
+ logger . error ("error inside 'with' block" )
160
163
return
161
164
if self .do_plot :
162
165
self .plot_stats (self .visdom_env )
163
166
164
167
def reset (self ): # to be called after each epoch
165
168
stat_sets = list (self .stats .keys ())
166
- if self .verbose :
167
- print ("stats: epoch %d - reset" % self .epoch )
169
+ logger .debug (f"stats: epoch { self .epoch } - reset" )
168
170
self .it = {k : - 1 for k in stat_sets }
169
171
for stat_set in stat_sets :
170
172
for stat in self .stats [stat_set ]:
171
173
self .stats [stat_set ][stat ].reset ()
172
174
173
175
def hard_reset (self , epoch = - 1 ): # to be called during object __init__
174
176
self .epoch = epoch
175
- if self .verbose :
176
- print ("stats: epoch %d - hard reset" % self .epoch )
177
+ logger .debug (f"stats: epoch { self .epoch } - hard reset" )
177
178
self .stats = {}
178
179
179
180
# reset
180
181
self .reset ()
181
182
182
183
def new_epoch (self ):
183
- if self .verbose :
184
- print ("stats: new epoch %d" % (self .epoch + 1 ))
184
+ logger .debug (f"stats: new epoch { (self .epoch + 1 )} " )
185
185
self .epoch += 1
186
186
self .reset () # zero the stats + increase epoch counter
187
187
@@ -193,18 +193,17 @@ def gather_value(self, val):
193
193
val = float (val .sum ())
194
194
return val
195
195
196
- def add_log_vars (self , added_log_vars , verbose = True ):
196
+ def add_log_vars (self , added_log_vars ):
197
197
for add_log_var in added_log_vars :
198
198
if add_log_var not in self .stats :
199
- if verbose :
200
- print (f"Adding { add_log_var } " )
199
+ logger .debug (f"Adding { add_log_var } " )
201
200
self .log_vars .append (add_log_var )
202
201
203
202
def update (self , preds , time_start = None , freeze_iter = False , stat_set = "train" ):
204
203
205
204
if self .epoch == - 1 : # uninitialized
206
- print (
207
- "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
205
+ logger . warning (
206
+ "epoch==-1 means uninitialized stats structure -> new_epoch() called"
208
207
)
209
208
self .new_epoch ()
210
209
@@ -284,6 +283,12 @@ def print(
284
283
skip_nan = False ,
285
284
stat_format = lambda s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
286
285
):
286
+ """
287
+ stats.print() is deprecated. Please use get_status_string() instead.
288
+ example:
289
+ std_out = stats.get_status_string()
290
+ logger.info(str_out)
291
+ """
287
292
288
293
epoch = self .epoch
289
294
stats = self .stats
@@ -311,8 +316,30 @@ def print(
311
316
if get_str :
312
317
return str_out
313
318
else :
319
+ warnings .warn (
320
+ "get_str=False is deprecated."
321
+ "Please enable this flag to get receive the output string." ,
322
+ DeprecationWarning ,
323
+ )
314
324
print (str_out )
315
325
326
+ def get_status_string (
327
+ self ,
328
+ max_it = None ,
329
+ stat_set = "train" ,
330
+ vars_print = None ,
331
+ skip_nan = False ,
332
+ stat_format = lambda s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
333
+ ):
334
+ return self .print (
335
+ max_it = max_it ,
336
+ stat_set = stat_set ,
337
+ vars_print = vars_print ,
338
+ get_str = True ,
339
+ skip_nan = skip_nan ,
340
+ stat_format = stat_format ,
341
+ )
342
+
316
343
def plot_stats (
317
344
self , visdom_env = None , plot_file = None , visdom_server = None , visdom_port = None
318
345
):
@@ -329,16 +356,15 @@ def plot_stats(
329
356
330
357
stat_sets = list (self .stats .keys ())
331
358
332
- print (
333
- "printing charts to visdom env '%s' (%s:%d)"
334
- % (visdom_env , visdom_server , visdom_port )
359
+ logger .debug (
360
+ f"printing charts to visdom env '{ visdom_env } ' ({ visdom_server } :{ visdom_port } )"
335
361
)
336
362
337
363
novisdom = False
338
364
339
365
viz = get_visdom_connection (server = visdom_server , port = visdom_port )
340
366
if viz is None or not viz .check_connection ():
341
- print ("no visdom server! -> skipping visdom plots" )
367
+ logger . info ("no visdom server! -> skipping visdom plots" )
342
368
novisdom = True
343
369
344
370
lines = []
@@ -385,7 +411,7 @@ def plot_stats(
385
411
)
386
412
387
413
if plot_file :
388
- print ( "exporting stats to %s" % plot_file )
414
+ logger . info ( f"plotting stats to { plot_file } " )
389
415
ncol = 3
390
416
nrow = int (np .ceil (float (len (lines )) / ncol ))
391
417
matplotlib .rcParams .update ({"font.size" : 5 })
@@ -423,15 +449,15 @@ def plot_stats(
423
449
except PermissionError :
424
450
warnings .warn ("Cant dump stats due to insufficient permissions!" )
425
451
426
- def synchronize_logged_vars (self , log_vars , default_val = float ("NaN" ), verbose = True ):
452
+ def synchronize_logged_vars (self , log_vars , default_val = float ("NaN" )):
427
453
428
454
stat_sets = list (self .stats .keys ())
429
455
430
456
# remove the additional log_vars
431
457
for stat_set in stat_sets :
432
458
for stat in self .stats [stat_set ].keys ():
433
459
if stat not in log_vars :
434
- print ( "additional stat %s:%s -> removing" % ( stat_set , stat ) )
460
+ logger . warning ( f "additional stat { stat_set } : { stat } -> removing" )
435
461
436
462
self .stats [stat_set ] = {
437
463
stat : v for stat , v in self .stats [stat_set ].items () if stat in log_vars
@@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr
442
468
for stat_set in stat_sets :
443
469
for stat in log_vars :
444
470
if stat not in self .stats [stat_set ]:
445
- if verbose :
446
- print (
447
- "missing stat %s:%s -> filling with default values (%1.2f)"
448
- % (stat_set , stat , default_val )
449
- )
471
+ logger .info (
472
+ "missing stat %s:%s -> filling with default values (%1.2f)"
473
+ % (stat_set , stat , default_val )
474
+ )
450
475
elif len (self .stats [stat_set ][stat ].history ) != self .epoch + 1 :
451
476
h = self .stats [stat_set ][stat ].history
452
477
if len (h ) == 0 : # just never updated stat ... skip
453
478
continue
454
479
else :
455
- if verbose :
456
- print (
457
- "incomplete stat %s:%s -> reseting with default values (%1.2f)"
458
- % (stat_set , stat , default_val )
459
- )
480
+ logger .info (
481
+ "incomplete stat %s:%s -> reseting with default values (%1.2f)"
482
+ % (stat_set , stat , default_val )
483
+ )
460
484
else :
461
485
continue
462
486
0 commit comments