Skip to content

Commit b9dd776

Browse files
committed
add mem_opt argument
1 parent daea299 commit b9dd776

File tree

1 file changed

+27
-22
lines changed

1 file changed

+27
-22
lines changed

fluid/SE-ResNeXt-152/train_parallel_do.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
def parse_args():
2727
parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel profile.')
2828
parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='')
29+
parser.add_argument(
30+
'--use_mem_opt',
31+
type=distutils.util.strtobool,
32+
default=True,
33+
help='use memory optimize')
2934
parser.add_argument(
3035
'--skip_first_steps',
3136
type=int,
@@ -212,7 +217,8 @@ def train():
212217
regularization=fluid.regularizer.L2Decay(1e-4))
213218
opts = optimizer.minimize(avg_cost)
214219

215-
fluid.memory_optimize(fluid.default_main_program())
220+
if args.use_mem_opt:
221+
fluid.memory_optimize(fluid.default_main_program())
216222

217223
place = fluid.CUDAPlace(0)
218224
# place = fluid.CPUPlace()
@@ -227,27 +233,26 @@ def train():
227233
feed_dict = feeder.feed(data)
228234

229235
for pass_id in range(1):
230-
with profiler.profiler('All', 'total', '/tmp/profile') as prof:
231-
train_time = 0.0
232-
233-
for step_id in range(step_num):
234-
train_start = time.time()
235-
exe.run(fluid.default_main_program(),
236-
feed=feeder.feed(train_reader_iter.next())
237-
if args.use_python_reader else feed_dict,
238-
fetch_list=[],
239-
use_program_cache=True)
240-
train_stop = time.time()
241-
step_time = train_stop - train_start
242-
if step_id >= args.skip_first_steps:
243-
train_time += step_time
244-
print("step_id=" + str(step_id) + " step_time=" + str(
245-
step_time))
246-
print("\n\n\n")
247-
calc_step_num = step_num - args.skip_first_steps
248-
print("calc_step_num=" + str(calc_step_num) + " total_train_time=" +
249-
str(train_time) + " ave_step_time=" + str(
250-
float(train_time) / calc_step_num))
236+
#with profiler.profiler('All', 'total', '/tmp/profile') as prof:
237+
train_time = 0.0
238+
239+
for step_id in range(step_num):
240+
train_start = time.time()
241+
exe.run(fluid.default_main_program(),
242+
feed=feeder.feed(train_reader_iter.next())
243+
if args.use_python_reader else feed_dict,
244+
fetch_list=[],
245+
use_program_cache=True)
246+
train_stop = time.time()
247+
step_time = train_stop - train_start
248+
if step_id >= args.skip_first_steps:
249+
train_time += step_time
250+
print("step_id=" + str(step_id) + " step_time=" + str(step_time))
251+
print("\n\n\n")
252+
calc_step_num = step_num - args.skip_first_steps
253+
print("calc_step_num=" + str(calc_step_num) + " total_train_time=" +
254+
str(train_time) + " ave_step_time=" + str(
255+
float(train_time) / calc_step_num))
251256

252257

253258
if __name__ == '__main__':

0 commit comments

Comments
 (0)