@@ -35,8 +35,9 @@ def parse_args():
     parser.add_argument(
         '--number_iteration',
         type=int,
-        default=10,
+        default=100,
         help='total batch num for per_gpu_batch_size')
+    parser.add_argument('--display_step', type=int, default=1, help='')
 
     args = parser.parse_args()
     return args
@@ -167,7 +168,7 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=152):
 
 
 def net_conf(image, label, class_dim):
-    out = SE_ResNeXt152(input=image, class_dim=class_dim)
+    out = SE_ResNeXt(input=image, class_dim=class_dim)
     cost = fluid.layers.cross_entropy(input=out, label=label)
     avg_cost = fluid.layers.mean(x=cost)
     #accuracy = fluid.evaluator.Accuracy(input=out, label=label)
@@ -195,16 +196,18 @@ def train():
     startup = fluid.Program()
 
     with fluid.program_guard(main, startup):
-        data_file = fluid.layers.open_recordio_file(
-            filename='./resnet_152.recordio_batch_size_12_3_224_224',  # ./resnet_152.recordio_batch_size_2
+        reader = fluid.layers.open_recordio_file(
+            filename='./flowers.recordio',
             shapes=[[-1, 3, 224, 224], [-1, 1]],
             lod_levels=[0, 0],
             dtypes=['float32', 'int64'])
-        image, label = fluid.layers.read_file(data_file)
+        # currently, double buffer only supports one device.
+        #data_file = fluid.layers.create_double_buffer_reader(reader=data_file, place='CUDA:0')
+        image, label = fluid.layers.read_file(reader)
 
         prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label,
                                                              class_dim)
-
+        #optimizer = fluid.optimizer.SGD(learning_rate=0.002)
         optimizer = fluid.optimizer.Momentum(
             learning_rate=fluid.layers.piecewise_decay(
                 boundaries=[100], values=[0.1, 0.2]),
@@ -217,20 +220,26 @@ def train():
 
     exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True)
 
-    batch_id = 0
+    batch_id = -1
     time_record = []
-    # with profiler.profiler('All', 'total', '/tmp/profile') as prof:
+
     for i in xrange(args.number_iteration):
+        batch_id += 1
+        if batch_id >= 5 and batch_id < 7:
+            with profiler.profiler('All', 'total', '/tmp/profile') as prof:
+                exe.run([])
+            continue
+
         t1 = time.time()
-        exe.run([avg_cost.name] if batch_id % 10 == 0 else [])
+        cost_val = exe.run([avg_cost.name]
+                           if batch_id % args.display_step == 0 else [])
         t2 = time.time()
         period = t2 - t1
         time_record.append(period)
 
-        if batch_id % 10 == 0:
-            print("trainbatch {0}, time{1}".format(batch_id,
-                                                   "%2.2f sec" % period))
-        batch_id += 1
+        if batch_id % args.display_step == 0:
+            print("iter=%d, elapse=%f, cost=%s" %
+                  (batch_id, period, np.array(cost_val[0])))
 
     del time_record[0]
     for ele in time_record: