@@ -75,11 +75,15 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
7575 self .train_program = framework .Program ()
7676
7777 with framework .program_guard (self .train_program , self .startup_program ):
78- loss = program_func ()
78+ program_func_outs = program_func ()
79+ self .test_outputs = program_func_outs if isinstance (
80+ program_func_outs , list ) else [program_func_outs ]
81+ self .test_program = self .train_program .clone ()
7982 if not isinstance (optimizer , opt_module .Optimizer ):
8083 raise TypeError (
8184 "The optimizer should be an instance of Optimizer" )
82-
85+ # The fisrt element of program_func_outs is loss.
86+ loss = self .test_outputs [0 ]
8387 optimize_ops , params_grads = optimizer .minimize (loss )
8488
8589 self .place = Trainer ._check_and_get_place (place )
@@ -168,8 +172,17 @@ def train(self,
168172
169173 self ._train_by_executor (num_epochs , event_handler , reader , feed_order )
170174
171- def test (self , reader ):
172- pass
175+ def test (self , reader , feed_order = None ):
176+ """
177+ Test the model on given test data
178+
179+ Args:
180+ reader: The reader that yields test data.
181+ feed_order: Feeding order of reader. None will following the defining
182+ order in program
183+ """
184+
185+ return self ._test_by_executor (reader , feed_order , self .test_outputs )
173186
174187 def save_params (self , param_path ):
175188 # reference: save_persistables in io.py
@@ -225,26 +238,59 @@ def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
225238
226239 """
227240 with self ._prog_and_scope_guard ():
228- exe = executor .Executor (self .place )
229- if feed_order is None :
230- feed_var_list = [
231- var
232- for var in self .train_program .global_block (
233- ).vars .itervalues ()
234- if hasattr (var , 'is_data' ) and var .is_data
235- ]
236- else :
237- feed_var_list = [
238- self .train_program .global_block ().var (var_name )
239- for var_name in feed_order
240- ]
241-
241+ feed_var_list = build_feed_var_list (self .train_program , feed_order )
242242 feeder = data_feeder .DataFeeder (
243243 feed_list = feed_var_list , place = self .place )
244+ exe = executor .Executor (self .place )
244245 for epoch_id in range (num_epochs ):
245246 event_handler (BeginEpochEvent (epoch_id ))
246247 for step_id , data in enumerate (reader ()):
247248 event_handler (BeginStepEvent (epoch_id , step_id ))
248249 exe .run (feed = feeder .feed (data ), fetch_list = [])
249250 event_handler (EndStepEvent (epoch_id , step_id ))
250251 event_handler (EndEpochEvent (epoch_id ))
252+
253+ def _test_by_executor (self , reader , feed_order , fetch_list ):
254+ with executor .scope_guard (self .scope ):
255+ feed_var_list = build_feed_var_list (self .test_program , feed_order )
256+ feeder = data_feeder .DataFeeder (
257+ feed_list = feed_var_list , place = self .place )
258+ exe = executor .Executor (self .place )
259+ accumulated = len (fetch_list ) * [0 ]
260+ count = 0
261+ for data in reader ():
262+ outs = exe .run (program = self .test_program ,
263+ feed = feeder .feed (data ),
264+ fetch_list = fetch_list )
265+ accumulated = [x [0 ] + x [1 ][0 ] for x in zip (accumulated , outs )]
266+ count += 1
267+
268+ return [x / count for x in accumulated ]
269+
270+
271+ def build_feed_var_list (program , feed_order ):
272+ if not isinstance (program , framework .Program ):
273+ raise TypeError ("The 'program' should be an object of Program" )
274+
275+ if feed_order is None :
276+ feed_var_list = [
277+ var for var in program .global_block ().vars .itervalues ()
278+ if var .is_data
279+ ]
280+ elif isinstance (feed_order , list ):
281+ feed_var_list = [
282+ program .global_block ().var (var_name ) for var_name in feed_order
283+ ]
284+ else :
285+ if not isinstance (feed_order , dict ):
286+ raise TypeError (
287+ "The 'feed_order' should be either None, list or dict." )
288+ if not sorted (feed_order .values ()) == range (len (feed_order )):
289+ raise ValueError (
290+ "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
291+ )
292+ sorted_pair_list = sorted (feed_order .items (), key = lambda item : item [1 ])
293+ feed_var_list = [
294+ program .global_block ().var (pair [0 ]) for pair in sorted_pair_list
295+ ]
296+ return feed_var_list
0 commit comments