Skip to content

Commit 4bd13fe

Browse files
authored
498 Add logger_handler to LrScheduleHandler (#3570)
* [DLMED] add log handler Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix CI tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix CI test Signed-off-by: Nic Ma <[email protected]> * [DLMED] test CI Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix logging Signed-off-by: Nic Ma <[email protected]> * [DLMED] temp test Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix wrong unit test Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix wrong test cases Signed-off-by: Nic Ma <[email protected]>
1 parent bb4ad5f commit 4bd13fe

File tree

6 files changed

+78
-39
lines changed

6 files changed

+78
-39
lines changed

monai/handlers/lr_schedule_handler.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def __init__(
3636
name: Optional[str] = None,
3737
epoch_level: bool = True,
3838
step_transform: Callable[[Engine], Any] = lambda engine: (),
39+
logger_handler: Optional[logging.Handler] = None,
3940
) -> None:
4041
"""
4142
Args:
@@ -47,6 +48,9 @@ def __init__(
4748
`True` is epoch level, `False` is iteration level.
4849
step_transform: a callable that is used to transform the information from `engine`
4950
to expected input data of lr_scheduler.step() function if necessary.
51+
logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc.
52+
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
53+
the handler should have a logging level of at least `INFO`.
5054
5155
Raises:
5256
TypeError: When ``step_transform`` is not ``callable``.
@@ -59,6 +63,8 @@ def __init__(
5963
if not callable(step_transform):
6064
raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.")
6165
self.step_transform = step_transform
66+
if logger_handler is not None:
67+
self.logger.addHandler(logger_handler)
6268

6369
self._name = name
6470

monai/handlers/stats_handler.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,8 @@ def __init__(
8282
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
8383
key_var_format: a formatting string to control the output string format of key: value.
8484
logger_handler: add additional handler to handle the stats data: save to file, etc.
85-
Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
85+
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
86+
the handler should have a logging level of at least `INFO`.
8687
"""
8788

8889
self.epoch_print_logger = epoch_print_logger

monai/transforms/utility/array.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -570,7 +570,7 @@ def __init__(
570570
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
571571
additional_info: user can define callable function to extract additional info from input data.
572572
logger_handler: add additional handler to output data: save to file, etc.
573-
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
573+
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
574574
the handler should have a logging level of at least `INFO`.
575575
576576
Raises:

monai/transforms/utility/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -795,7 +795,7 @@ def __init__(
795795
additional info from input data. it also can be a sequence of string, each element
796796
corresponds to a key in ``keys``.
797797
logger_handler: add additional handler to output data: save to file, etc.
798-
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
798+
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
799799
the handler should have a logging level of at least `INFO`.
800800
allow_missing_keys: don't raise exception if key is missing.
801801

tests/test_handler_lr_scheduler.py

Lines changed: 37 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,10 @@
1010
# limitations under the License.
1111

1212
import logging
13+
import os
14+
import re
1315
import sys
16+
import tempfile
1417
import unittest
1518

1619
import numpy as np
@@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase):
2427
def test_content(self):
2528
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
2629
data = [0] * 8
30+
test_lr = 0.1
31+
gamma = 0.1
2732

2833
# set up engine
2934
def _train_func(engine, batch):
@@ -41,24 +46,45 @@ def run_validation(engine):
4146
net = torch.nn.PReLU()
4247

4348
def _reduce_lr_on_plateau():
44-
optimizer = torch.optim.SGD(net.parameters(), 0.1)
49+
optimizer = torch.optim.SGD(net.parameters(), test_lr)
4550
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
4651
handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
4752
handler.attach(train_engine)
48-
return lr_scheduler
53+
return handler
4954

50-
def _reduce_on_step():
51-
optimizer = torch.optim.SGD(net.parameters(), 0.1)
52-
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
53-
handler = LrScheduleHandler(lr_scheduler)
54-
handler.attach(train_engine)
55-
return lr_scheduler
55+
with tempfile.TemporaryDirectory() as tempdir:
56+
key_to_handler = "test_log_lr"
57+
key_to_print = "Current learning rate"
58+
filename = os.path.join(tempdir, "test_lr.log")
59+
# test with additional logging handler
60+
file_saver = logging.FileHandler(filename, mode="w")
61+
file_saver.setLevel(logging.INFO)
62+
63+
def _reduce_on_step():
64+
optimizer = torch.optim.SGD(net.parameters(), test_lr)
65+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
66+
handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver)
67+
handler.attach(train_engine)
68+
handler.logger.setLevel(logging.INFO)
69+
return handler
70+
71+
schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
72+
73+
train_engine.run(data, max_epochs=5)
74+
file_saver.close()
75+
schedulers[1].logger.removeHandler(file_saver)
5676

57-
schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
77+
with open(filename) as f:
78+
output_str = f.read()
79+
has_key_word = re.compile(f".*{key_to_print}.*")
80+
content_count = 0
81+
for line in output_str.split("\n"):
82+
if has_key_word.match(line):
83+
content_count += 1
84+
self.assertTrue(content_count > 0)
5885

59-
train_engine.run(data, max_epochs=5)
6086
for scheduler in schedulers:
61-
np.testing.assert_allclose(scheduler._last_lr[0], 0.001)
87+
np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)
6288

6389

6490
if __name__ == "__main__":

tests/test_handler_stats.py

Lines changed: 31 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -45,18 +45,19 @@ def _update_metric(engine):
4545
# set up testing handler
4646
stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler)
4747
stats_handler.attach(engine)
48+
stats_handler.logger.setLevel(logging.INFO)
4849

4950
engine.run(range(3), max_epochs=2)
5051

5152
# check logging output
5253
output_str = log_stream.getvalue()
5354
log_handler.close()
54-
grep = re.compile(f".*{key_to_handler}.*")
5555
has_key_word = re.compile(f".*{key_to_print}.*")
56-
for idx, line in enumerate(output_str.split("\n")):
57-
if grep.match(line):
58-
if idx in [5, 10]:
59-
self.assertTrue(has_key_word.match(line))
56+
content_count = 0
57+
for line in output_str.split("\n"):
58+
if has_key_word.match(line):
59+
content_count += 1
60+
self.assertTrue(content_count > 0)
6061

6162
def test_loss_print(self):
6263
log_stream = StringIO()
@@ -74,18 +75,19 @@ def _train_func(engine, batch):
7475
# set up testing handler
7576
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler)
7677
stats_handler.attach(engine)
78+
stats_handler.logger.setLevel(logging.INFO)
7779

7880
engine.run(range(3), max_epochs=2)
7981

8082
# check logging output
8183
output_str = log_stream.getvalue()
8284
log_handler.close()
83-
grep = re.compile(f".*{key_to_handler}.*")
8485
has_key_word = re.compile(f".*{key_to_print}.*")
85-
for idx, line in enumerate(output_str.split("\n")):
86-
if grep.match(line):
87-
if idx in [1, 2, 3, 6, 7, 8]:
88-
self.assertTrue(has_key_word.match(line))
86+
content_count = 0
87+
for line in output_str.split("\n"):
88+
if has_key_word.match(line):
89+
content_count += 1
90+
self.assertTrue(content_count > 0)
8991

9092
def test_loss_dict(self):
9193
log_stream = StringIO()
@@ -102,21 +104,22 @@ def _train_func(engine, batch):
102104

103105
# set up testing handler
104106
stats_handler = StatsHandler(
105-
name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler
107+
name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler
106108
)
107109
stats_handler.attach(engine)
110+
stats_handler.logger.setLevel(logging.INFO)
108111

109112
engine.run(range(3), max_epochs=2)
110113

111114
# check logging output
112115
output_str = log_stream.getvalue()
113116
log_handler.close()
114-
grep = re.compile(f".*{key_to_handler}.*")
115117
has_key_word = re.compile(f".*{key_to_print}.*")
116-
for idx, line in enumerate(output_str.split("\n")):
117-
if grep.match(line):
118-
if idx in [1, 2, 3, 6, 7, 8]:
119-
self.assertTrue(has_key_word.match(line))
118+
content_count = 0
119+
for line in output_str.split("\n"):
120+
if has_key_word.match(line):
121+
content_count += 1
122+
self.assertTrue(content_count > 0)
120123

121124
def test_loss_file(self):
122125
key_to_handler = "test_logging"
@@ -136,18 +139,19 @@ def _train_func(engine, batch):
136139
# set up testing handler
137140
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
138141
stats_handler.attach(engine)
142+
stats_handler.logger.setLevel(logging.INFO)
139143

140144
engine.run(range(3), max_epochs=2)
141145
handler.close()
142146
stats_handler.logger.removeHandler(handler)
143147
with open(filename) as f:
144148
output_str = f.read()
145-
grep = re.compile(f".*{key_to_handler}.*")
146149
has_key_word = re.compile(f".*{key_to_print}.*")
147-
for idx, line in enumerate(output_str.split("\n")):
148-
if grep.match(line):
149-
if idx in [1, 2, 3, 6, 7, 8]:
150-
self.assertTrue(has_key_word.match(line))
150+
content_count = 0
151+
for line in output_str.split("\n"):
152+
if has_key_word.match(line):
153+
content_count += 1
154+
self.assertTrue(content_count > 0)
151155

152156
def test_exception(self):
153157
# set up engine
@@ -190,17 +194,19 @@ def _update_metric(engine):
190194
name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler
191195
)
192196
stats_handler.attach(engine)
197+
stats_handler.logger.setLevel(logging.INFO)
193198

194199
engine.run(range(3), max_epochs=2)
195200

196201
# check logging output
197202
output_str = log_stream.getvalue()
198203
log_handler.close()
199-
grep = re.compile(f".*{key_to_handler}.*")
200204
has_key_word = re.compile(".*State values.*")
201-
for idx, line in enumerate(output_str.split("\n")):
202-
if grep.match(line) and idx in [5, 10]:
203-
self.assertTrue(has_key_word.match(line))
205+
content_count = 0
206+
for line in output_str.split("\n"):
207+
if has_key_word.match(line):
208+
content_count += 1
209+
self.assertTrue(content_count > 0)
204210

205211

206212
if __name__ == "__main__":

0 commit comments

Comments (0)