@@ -45,18 +45,19 @@ def _update_metric(engine):
45
45
# set up testing handler
46
46
stats_handler = StatsHandler (name = key_to_handler , logger_handler = log_handler )
47
47
stats_handler .attach (engine )
48
+ stats_handler .logger .setLevel (logging .INFO )
48
49
49
50
engine .run (range (3 ), max_epochs = 2 )
50
51
51
52
# check logging output
52
53
output_str = log_stream .getvalue ()
53
54
log_handler .close ()
54
- grep = re .compile (f".*{ key_to_handler } .*" )
55
55
has_key_word = re .compile (f".*{ key_to_print } .*" )
56
- for idx , line in enumerate (output_str .split ("\n " )):
57
- if grep .match (line ):
58
- if idx in [5 , 10 ]:
59
- self .assertTrue (has_key_word .match (line ))
56
+ content_count = 0
57
+ for line in output_str .split ("\n " ):
58
+ if has_key_word .match (line ):
59
+ content_count += 1
60
+ self .assertTrue (content_count > 0 )
60
61
61
62
def test_loss_print (self ):
62
63
log_stream = StringIO ()
@@ -74,18 +75,19 @@ def _train_func(engine, batch):
74
75
# set up testing handler
75
76
stats_handler = StatsHandler (name = key_to_handler , tag_name = key_to_print , logger_handler = log_handler )
76
77
stats_handler .attach (engine )
78
+ stats_handler .logger .setLevel (logging .INFO )
77
79
78
80
engine .run (range (3 ), max_epochs = 2 )
79
81
80
82
# check logging output
81
83
output_str = log_stream .getvalue ()
82
84
log_handler .close ()
83
- grep = re .compile (f".*{ key_to_handler } .*" )
84
85
has_key_word = re .compile (f".*{ key_to_print } .*" )
85
- for idx , line in enumerate (output_str .split ("\n " )):
86
- if grep .match (line ):
87
- if idx in [1 , 2 , 3 , 6 , 7 , 8 ]:
88
- self .assertTrue (has_key_word .match (line ))
86
+ content_count = 0
87
+ for line in output_str .split ("\n " ):
88
+ if has_key_word .match (line ):
89
+ content_count += 1
90
+ self .assertTrue (content_count > 0 )
89
91
90
92
def test_loss_dict (self ):
91
93
log_stream = StringIO ()
@@ -102,21 +104,22 @@ def _train_func(engine, batch):
102
104
103
105
# set up testing handler
104
106
stats_handler = StatsHandler (
105
- name = key_to_handler , output_transform = lambda x : {key_to_print : x }, logger_handler = log_handler
107
+ name = key_to_handler , output_transform = lambda x : {key_to_print : x [ 0 ] }, logger_handler = log_handler
106
108
)
107
109
stats_handler .attach (engine )
110
+ stats_handler .logger .setLevel (logging .INFO )
108
111
109
112
engine .run (range (3 ), max_epochs = 2 )
110
113
111
114
# check logging output
112
115
output_str = log_stream .getvalue ()
113
116
log_handler .close ()
114
- grep = re .compile (f".*{ key_to_handler } .*" )
115
117
has_key_word = re .compile (f".*{ key_to_print } .*" )
116
- for idx , line in enumerate (output_str .split ("\n " )):
117
- if grep .match (line ):
118
- if idx in [1 , 2 , 3 , 6 , 7 , 8 ]:
119
- self .assertTrue (has_key_word .match (line ))
118
+ content_count = 0
119
+ for line in output_str .split ("\n " ):
120
+ if has_key_word .match (line ):
121
+ content_count += 1
122
+ self .assertTrue (content_count > 0 )
120
123
121
124
def test_loss_file (self ):
122
125
key_to_handler = "test_logging"
@@ -136,18 +139,19 @@ def _train_func(engine, batch):
136
139
# set up testing handler
137
140
stats_handler = StatsHandler (name = key_to_handler , tag_name = key_to_print , logger_handler = handler )
138
141
stats_handler .attach (engine )
142
+ stats_handler .logger .setLevel (logging .INFO )
139
143
140
144
engine .run (range (3 ), max_epochs = 2 )
141
145
handler .close ()
142
146
stats_handler .logger .removeHandler (handler )
143
147
with open (filename ) as f :
144
148
output_str = f .read ()
145
- grep = re .compile (f".*{ key_to_handler } .*" )
146
149
has_key_word = re .compile (f".*{ key_to_print } .*" )
147
- for idx , line in enumerate (output_str .split ("\n " )):
148
- if grep .match (line ):
149
- if idx in [1 , 2 , 3 , 6 , 7 , 8 ]:
150
- self .assertTrue (has_key_word .match (line ))
150
+ content_count = 0
151
+ for line in output_str .split ("\n " ):
152
+ if has_key_word .match (line ):
153
+ content_count += 1
154
+ self .assertTrue (content_count > 0 )
151
155
152
156
def test_exception (self ):
153
157
# set up engine
@@ -190,17 +194,19 @@ def _update_metric(engine):
190
194
name = key_to_handler , state_attributes = ["test1" , "test2" , "test3" ], logger_handler = log_handler
191
195
)
192
196
stats_handler .attach (engine )
197
+ stats_handler .logger .setLevel (logging .INFO )
193
198
194
199
engine .run (range (3 ), max_epochs = 2 )
195
200
196
201
# check logging output
197
202
output_str = log_stream .getvalue ()
198
203
log_handler .close ()
199
- grep = re .compile (f".*{ key_to_handler } .*" )
200
204
has_key_word = re .compile (".*State values.*" )
201
- for idx , line in enumerate (output_str .split ("\n " )):
202
- if grep .match (line ) and idx in [5 , 10 ]:
203
- self .assertTrue (has_key_word .match (line ))
205
+ content_count = 0
206
+ for line in output_str .split ("\n " ):
207
+ if has_key_word .match (line ):
208
+ content_count += 1
209
+ self .assertTrue (content_count > 0 )
204
210
205
211
206
212
if __name__ == "__main__" :
0 commit comments