Skip to content

Commit 5d1dc2f

Browse files
authored
Support training with single device. (#869)
1 parent e7684f0 commit 5d1dc2f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

fluid/object_detection/train.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ def if_exist(var):
239239

240240
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
241241

242-
train_exe = fluid.ParallelExecutor(
243-
use_cuda=args.use_gpu, loss_name=loss.name)
242+
if args.parallel:
243+
train_exe = fluid.ParallelExecutor(
244+
use_cuda=args.use_gpu, loss_name=loss.name)
244245

245246
train_reader = paddle.batch(
246247
reader.train(data_args, train_file_list), batch_size=batch_size)
@@ -279,8 +280,13 @@ def test(pass_id, best_map):
279280
prev_start_time = start_time
280281
start_time = time.time()
281282
if len(data) < devices_num: continue
282-
loss_v, = train_exe.run(fetch_list=[loss.name],
283-
feed_dict=feeder.feed(data))
283+
if args.parallel:
284+
loss_v, = train_exe.run(fetch_list=[loss.name],
285+
feed_dict=feeder.feed(data))
286+
else:
287+
loss_v, = exe.run(fluid.default_main_program(),
288+
feed=feeder.feed(data),
289+
fetch_list=[loss])
284290
end_time = time.time()
285291
loss_v = np.mean(np.array(loss_v))
286292
if batch_id % 20 == 0:

0 commit comments

Comments
 (0)