Skip to content

Commit d08fe6d

Browse files
Softly deprecate the get_str=False flag.
Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller. Reviewed By: shapovalov Differential Revision: D45356240 fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
1 parent 297020a commit d08fe6d

File tree

2 files changed

+58
-34
lines changed

2 files changed

+58
-34
lines changed

projects/implicitron_trainer/impl/training_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def load_stats(
256256
list(log_vars),
257257
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
258258
visdom_env=visdom_env_charts,
259-
verbose=False,
260259
visdom_server=self.visdom_server,
261260
visdom_port=self.visdom_port,
262261
)
@@ -382,7 +381,8 @@ def _training_or_validation_epoch(
382381

383382
# print textual status update
384383
if it % self.metric_print_interval == 0 or last_iter:
385-
stats.print(stat_set=trainmode, max_it=n_batches)
384+
std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
385+
logger.info(std_out)
386386

387387
# visualize results
388388
if (

pytorch3d/implicitron/tools/stats.py

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import gzip
88
import json
9+
import logging
910
import time
1011
import warnings
1112
from collections.abc import Iterable
@@ -17,6 +18,8 @@
1718
from matplotlib import colors as mcolors
1819
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
1920

21+
logger = logging.getLogger(__name__)
22+
2023

2124
class AverageMeter(object):
2225
"""Computes and stores the average and current value"""
@@ -91,7 +94,9 @@ class Stats(object):
9194
# stats.update() automatically parses the 'objective' and 'top1e' from
9295
# the "output" dict and stores this into the db
9396
stats.update(output)
94-
stats.print() # prints the averages over given epoch
97+
# prints the metric averages over given epoch
98+
std_out = stats.get_status_string()
99+
logger.info(str_out)
95100
# stores the training plots into '/tmp/epoch_stats.pdf'
96101
# and plots into a visdom server running at localhost (if running)
97102
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
@@ -101,7 +106,6 @@ class Stats(object):
101106
def __init__(
102107
self,
103108
log_vars,
104-
verbose=False,
105109
epoch=-1,
106110
visdom_env="main",
107111
do_plot=True,
@@ -110,7 +114,6 @@ def __init__(
110114
visdom_port=8097,
111115
):
112116

113-
self.verbose = verbose
114117
self.log_vars = log_vars
115118
self.visdom_env = visdom_env
116119
self.visdom_server = visdom_server
@@ -156,32 +159,29 @@ def __exit__(self, type, value, traceback):
156159
iserr = type is not None and issubclass(type, Exception)
157160
iserr = iserr or (type is KeyboardInterrupt)
158161
if iserr:
159-
print("error inside 'with' block")
162+
logger.error("error inside 'with' block")
160163
return
161164
if self.do_plot:
162165
self.plot_stats(self.visdom_env)
163166

164167
def reset(self): # to be called after each epoch
165168
stat_sets = list(self.stats.keys())
166-
if self.verbose:
167-
print("stats: epoch %d - reset" % self.epoch)
169+
logger.debug(f"stats: epoch {self.epoch} - reset")
168170
self.it = {k: -1 for k in stat_sets}
169171
for stat_set in stat_sets:
170172
for stat in self.stats[stat_set]:
171173
self.stats[stat_set][stat].reset()
172174

173175
def hard_reset(self, epoch=-1): # to be called during object __init__
174176
self.epoch = epoch
175-
if self.verbose:
176-
print("stats: epoch %d - hard reset" % self.epoch)
177+
logger.debug(f"stats: epoch {self.epoch} - hard reset")
177178
self.stats = {}
178179

179180
# reset
180181
self.reset()
181182

182183
def new_epoch(self):
183-
if self.verbose:
184-
print("stats: new epoch %d" % (self.epoch + 1))
184+
logger.debug(f"stats: new epoch {(self.epoch + 1)}")
185185
self.epoch += 1
186186
self.reset() # zero the stats + increase epoch counter
187187

@@ -193,18 +193,17 @@ def gather_value(self, val):
193193
val = float(val.sum())
194194
return val
195195

196-
def add_log_vars(self, added_log_vars, verbose=True):
196+
def add_log_vars(self, added_log_vars):
197197
for add_log_var in added_log_vars:
198198
if add_log_var not in self.stats:
199-
if verbose:
200-
print(f"Adding {add_log_var}")
199+
logger.debug(f"Adding {add_log_var}")
201200
self.log_vars.append(add_log_var)
202201

203202
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
204203

205204
if self.epoch == -1: # uninitialized
206-
print(
207-
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
205+
logger.warning(
206+
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
208207
)
209208
self.new_epoch()
210209

@@ -284,6 +283,12 @@ def print(
284283
skip_nan=False,
285284
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
286285
):
286+
"""
287+
stats.print() is deprecated. Please use get_status_string() instead.
288+
example:
289+
std_out = stats.get_status_string()
290+
logger.info(str_out)
291+
"""
287292

288293
epoch = self.epoch
289294
stats = self.stats
@@ -311,8 +316,30 @@ def print(
311316
if get_str:
312317
return str_out
313318
else:
319+
warnings.warn(
320+
"get_str=False is deprecated."
321+
"Please enable this flag to get receive the output string.",
322+
DeprecationWarning,
323+
)
314324
print(str_out)
315325

326+
def get_status_string(
327+
self,
328+
max_it=None,
329+
stat_set="train",
330+
vars_print=None,
331+
skip_nan=False,
332+
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
333+
):
334+
return self.print(
335+
max_it=max_it,
336+
stat_set=stat_set,
337+
vars_print=vars_print,
338+
get_str=True,
339+
skip_nan=skip_nan,
340+
stat_format=stat_format,
341+
)
342+
316343
def plot_stats(
317344
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
318345
):
@@ -329,16 +356,15 @@ def plot_stats(
329356

330357
stat_sets = list(self.stats.keys())
331358

332-
print(
333-
"printing charts to visdom env '%s' (%s:%d)"
334-
% (visdom_env, visdom_server, visdom_port)
359+
logger.debug(
360+
f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
335361
)
336362

337363
novisdom = False
338364

339365
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
340366
if viz is None or not viz.check_connection():
341-
print("no visdom server! -> skipping visdom plots")
367+
logger.info("no visdom server! -> skipping visdom plots")
342368
novisdom = True
343369

344370
lines = []
@@ -385,7 +411,7 @@ def plot_stats(
385411
)
386412

387413
if plot_file:
388-
print("exporting stats to %s" % plot_file)
414+
logger.info(f"plotting stats to {plot_file}")
389415
ncol = 3
390416
nrow = int(np.ceil(float(len(lines)) / ncol))
391417
matplotlib.rcParams.update({"font.size": 5})
@@ -423,15 +449,15 @@ def plot_stats(
423449
except PermissionError:
424450
warnings.warn("Cant dump stats due to insufficient permissions!")
425451

426-
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
452+
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
427453

428454
stat_sets = list(self.stats.keys())
429455

430456
# remove the additional log_vars
431457
for stat_set in stat_sets:
432458
for stat in self.stats[stat_set].keys():
433459
if stat not in log_vars:
434-
print("additional stat %s:%s -> removing" % (stat_set, stat))
460+
logger.warning(f"additional stat {stat_set}:{stat} -> removing")
435461

436462
self.stats[stat_set] = {
437463
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
@@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr
442468
for stat_set in stat_sets:
443469
for stat in log_vars:
444470
if stat not in self.stats[stat_set]:
445-
if verbose:
446-
print(
447-
"missing stat %s:%s -> filling with default values (%1.2f)"
448-
% (stat_set, stat, default_val)
449-
)
471+
logger.info(
472+
"missing stat %s:%s -> filling with default values (%1.2f)"
473+
% (stat_set, stat, default_val)
474+
)
450475
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
451476
h = self.stats[stat_set][stat].history
452477
if len(h) == 0: # just never updated stat ... skip
453478
continue
454479
else:
455-
if verbose:
456-
print(
457-
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
458-
% (stat_set, stat, default_val)
459-
)
480+
logger.info(
481+
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
482+
% (stat_set, stat, default_val)
483+
)
460484
else:
461485
continue
462486

0 commit comments

Comments
 (0)