@@ -80,8 +80,9 @@ def train_network():
8080 label = fluid .layers .data (name = 'label' , shape = [1 ], dtype = 'int64' )
8181 cost = fluid .layers .cross_entropy (input = predict , label = label )
8282 avg_cost = fluid .layers .mean (cost )
83- accuracy = fluid .layers .accuracy (input = predict , label = label )
84- return avg_cost , accuracy
83+ # accuracy = fluid.layers.accuracy(input=predict, label=label)
84+ # return avg_cost, accuracy
85+ return avg_cost
8586
8687
8788def train (use_cuda , save_path ):
@@ -97,12 +98,12 @@ def train(use_cuda, save_path):
9798 paddle .dataset .cifar .test10 (), batch_size = BATCH_SIZE )
9899
99100 def event_handler (event ):
100- if isinstance (event , fluid .EndIteration ):
101- if (event .batch_id % 10 ) == 0 :
101+ if isinstance (event , fluid .EndEpochEvent ):
102+ if (event .epoch % 10 ) == 0 :
102103 avg_cost , accuracy = trainer .test (reader = test_reader )
103104
104105 print ('BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}' .format (
105- event .batch_id + 1 , avg_cost , accuracy ))
106+ event .epoch + 1 , avg_cost , accuracy ))
106107
107108 if accuracy > 0.01 : # Low threshold for speeding up CI
108109 trainer .params .save (save_path )
@@ -114,13 +115,18 @@ def event_handler(event):
114115 optimizer = fluid .optimizer .Adam (learning_rate = 0.001 ),
115116 place = place ,
116117 event_handler = event_handler )
117- trainer .train (train_reader , EPOCH_NUM , event_handler = event_handler )
118+
119+ trainer .train (
120+ num_epochs = EPOCH_NUM ,
121+ event_handler = event_handler ,
122+ reader = train_reader ,
123+ feed_order = ['pixel' , 'label' ])
118124
119125
120126def infer (use_cuda , save_path ):
121- params = fluid .Params (save_path )
122127 place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
123- inferencer = fluid .Inferencer (inference_network , params , place = place )
128+ inferencer = fluid .Inferencer (
129+ inference_network , save_path = save_path , place = place )
124130
125131 # The input's dimension of conv should be 4-D or 5-D.
126132 # Use normilized image pixels as input data, which should be in the range
0 commit comments