Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions fluid/object_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def if_exist(var):

fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu, loss_name=loss.name)
if args.parallel:
Copy link
Contributor

@wanghaoshuang wanghaoshuang Apr 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以在判断之前让 train_exe = exe, 然后可以简化下line283~line289
抱歉看错了。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParalleExe和Executor的run接口不同,Executor多了program。这样还是不太行~

train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu, loss_name=loss.name)

train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
Expand Down Expand Up @@ -279,8 +280,13 @@ def test(pass_id, best_map):
prev_start_time = start_time
start_time = time.time()
if len(data) < devices_num: continue
loss_v, = train_exe.run(fetch_list=[loss.name],
feed_dict=feeder.feed(data))
if args.parallel:
loss_v, = train_exe.run(fetch_list=[loss.name],
feed_dict=feeder.feed(data))
else:
loss_v, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
end_time = time.time()
loss_v = np.mean(np.array(loss_v))
if batch_id % 20 == 0:
Expand Down