@@ -35,8 +35,9 @@ def parse_args():
     parser.add_argument(
         '--number_iteration',
         type=int,
-        default=10,
+        default=100,
         help='total batch num for per_gpu_batch_size')
+    parser.add_argument('--display_step', type=int, default=1, help='')
 
     args = parser.parse_args()
     return args
@@ -167,7 +168,7 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=152):
 
 
 def net_conf(image, label, class_dim):
-    out = SE_ResNeXt152(input=image, class_dim=class_dim)
+    out = SE_ResNeXt(input=image, class_dim=class_dim)
     cost = fluid.layers.cross_entropy(input=out, label=label)
     avg_cost = fluid.layers.mean(x=cost)
     #accuracy = fluid.evaluator.Accuracy(input=out, label=label)
@@ -195,16 +196,16 @@ def train():
     startup = fluid.Program()
 
     with fluid.program_guard(main, startup):
-        data_file = fluid.layers.open_recordio_file(
-            filename='./resnet_152.recordio_batch_size_12_3_224_224',  # ./resnet_152.recordio_batch_size_2
+        reader = fluid.layers.open_recordio_file(
+            filename='./flowers.recordio',
             shapes=[[-1, 3, 224, 224], [-1, 1]],
             lod_levels=[0, 0],
             dtypes=['float32', 'int64'])
-        image, label = fluid.layers.read_file(data_file)
+        image, label = fluid.layers.read_file(reader)
 
         prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label,
                                                              class_dim)
-
+        #optimizer = fluid.optimizer.SGD(learning_rate=0.002)
         optimizer = fluid.optimizer.Momentum(
             learning_rate=fluid.layers.piecewise_decay(
                 boundaries=[100], values=[0.1, 0.2]),
@@ -217,20 +218,26 @@ def train():
 
     exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True)
 
-    batch_id = 0
+    batch_id = -1
     time_record = []
-    # with profiler.profiler('All', 'total', '/tmp/profile') as prof:
+
     for i in xrange(args.number_iteration):
+        batch_id += 1
+        if batch_id >= 5 and batch_id < 7:
+            with profiler.profiler('All', 'total', '/tmp/profile') as prof:
+                exe.run([])
+            continue
+
         t1 = time.time()
-        exe.run([avg_cost.name] if batch_id % 10 == 0 else [])
+        cost_val = exe.run([avg_cost.name]
+                           if batch_id % args.display_step == 0 else [])
         t2 = time.time()
         period = t2 - t1
         time_record.append(period)
 
-        if batch_id % 10 == 0:
-            print("trainbatch {0}, time{1}".format(batch_id,
-                                                   "%2.2f sec" % period))
-        batch_id += 1
+        if batch_id % args.display_step == 0:
+            print("iter=%d, elapse=%f, cost=%s" %
+                  (batch_id, period, np.array(cost_val[0])))
 
     del time_record[0]
     for ele in time_record:
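
Not part of the patch: a minimal sketch of the measurement loop this diff introduces, written as a standalone helper so the fetch-on-display_step and profiler-window pattern is easier to read. It assumes the same legacy fluid APIs the script already uses (ParallelExecutor.run taking a list of fetch variable names, paddle.fluid.profiler) and that exe and avg_cost were built as in train() above; the function name run_timing_loop and its default arguments are illustrative only.

import time
import numpy as np
import paddle.fluid.profiler as profiler  # assumption: same profiler module the script imports


def run_timing_loop(exe, avg_cost, number_iteration=100, display_step=1):
    # exe: fluid.ParallelExecutor built as in train(); avg_cost: loss variable from net_conf().
    time_record = []
    for batch_id in range(number_iteration):
        # Profile only a short window of iterations so the trace stays small,
        # mirroring the batch_id >= 5 and < 7 window in the patch.
        if 5 <= batch_id < 7:
            with profiler.profiler('All', 'total', '/tmp/profile'):
                exe.run([])
            continue

        t1 = time.time()
        # Fetch the loss only on display steps; fetching every batch adds overhead.
        fetch = [avg_cost.name] if batch_id % display_step == 0 else []
        result = exe.run(fetch)
        time_record.append(time.time() - t1)

        if batch_id % display_step == 0:
            print("iter=%d, elapse=%f, cost=%s" %
                  (batch_id, time_record[-1], np.array(result[0])))
    return time_record

The first recorded time is usually dropped before averaging (as the script does with del time_record[0]), since the first iteration includes one-off setup cost.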