2626def parse_args ():
2727 parser = argparse .ArgumentParser ('SE-ResNeXt-152 parallel profile.' )
2828 parser .add_argument ('--per_gpu_batch_size' , type = int , default = 12 , help = '' )
29+ parser .add_argument (
30+ '--use_mem_opt' ,
31+ type = distutils .util .strtobool ,
32+ default = True ,
33+ help = 'use memory optimize' )
2934 parser .add_argument (
3035 '--skip_first_steps' ,
3136 type = int ,
@@ -212,7 +217,8 @@ def train():
212217 regularization = fluid .regularizer .L2Decay (1e-4 ))
213218 opts = optimizer .minimize (avg_cost )
214219
215- fluid .memory_optimize (fluid .default_main_program ())
220+ if args .use_mem_opt :
221+ fluid .memory_optimize (fluid .default_main_program ())
216222
217223 place = fluid .CUDAPlace (0 )
218224 # place = fluid.CPUPlace()
@@ -227,27 +233,26 @@ def train():
227233 feed_dict = feeder .feed (data )
228234
229235 for pass_id in range (1 ):
230- with profiler .profiler ('All' , 'total' , '/tmp/profile' ) as prof :
231- train_time = 0.0
232-
233- for step_id in range (step_num ):
234- train_start = time .time ()
235- exe .run (fluid .default_main_program (),
236- feed = feeder .feed (train_reader_iter .next ())
237- if args .use_python_reader else feed_dict ,
238- fetch_list = [],
239- use_program_cache = True )
240- train_stop = time .time ()
241- step_time = train_stop - train_start
242- if step_id >= args .skip_first_steps :
243- train_time += step_time
244- print ("step_id=" + str (step_id ) + " step_time=" + str (
245- step_time ))
246- print ("\n \n \n " )
247- calc_step_num = step_num - args .skip_first_steps
248- print ("calc_step_num=" + str (calc_step_num ) + " total_train_time=" +
249- str (train_time ) + " ave_step_time=" + str (
250- float (train_time ) / calc_step_num ))
236+ #with profiler.profiler('All', 'total', '/tmp/profile') as prof:
237+ train_time = 0.0
238+
239+ for step_id in range (step_num ):
240+ train_start = time .time ()
241+ exe .run (fluid .default_main_program (),
242+ feed = feeder .feed (train_reader_iter .next ())
243+ if args .use_python_reader else feed_dict ,
244+ fetch_list = [],
245+ use_program_cache = True )
246+ train_stop = time .time ()
247+ step_time = train_stop - train_start
248+ if step_id >= args .skip_first_steps :
249+ train_time += step_time
250+ print ("step_id=" + str (step_id ) + " step_time=" + str (step_time ))
251+ print ("\n \n \n " )
252+ calc_step_num = step_num - args .skip_first_steps
253+ print ("calc_step_num=" + str (calc_step_num ) + " total_train_time=" +
254+ str (train_time ) + " ave_step_time=" + str (
255+ float (train_time ) / calc_step_num ))
251256
252257
253258if __name__ == '__main__' :
0 commit comments