Skip to content

Commit 4813107

Browse files
committed
code refine
1 parent b9dd776 commit 4813107

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

fluid/SE-ResNeXt-152/train_parallel_executor.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ def parse_args():
3535
parser.add_argument(
3636
'--number_iteration',
3737
type=int,
38-
default=10,
38+
default=100,
3939
help='total batch num for per_gpu_batch_size')
40+
parser.add_argument('--display_step', type=int, default=1, help='')
4041

4142
args = parser.parse_args()
4243
return args
@@ -167,7 +168,7 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=152):
167168

168169

169170
def net_conf(image, label, class_dim):
170-
out = SE_ResNeXt152(input=image, class_dim=class_dim)
171+
out = SE_ResNeXt(input=image, class_dim=class_dim)
171172
cost = fluid.layers.cross_entropy(input=out, label=label)
172173
avg_cost = fluid.layers.mean(x=cost)
173174
#accuracy = fluid.evaluator.Accuracy(input=out, label=label)
@@ -195,16 +196,16 @@ def train():
195196
startup = fluid.Program()
196197

197198
with fluid.program_guard(main, startup):
198-
data_file = fluid.layers.open_recordio_file(
199-
filename='./resnet_152.recordio_batch_size_12_3_224_224', # ./resnet_152.recordio_batch_size_2
199+
reader = fluid.layers.open_recordio_file(
200+
filename='./flowers.recordio',
200201
shapes=[[-1, 3, 224, 224], [-1, 1]],
201202
lod_levels=[0, 0],
202203
dtypes=['float32', 'int64'])
203-
image, label = fluid.layers.read_file(data_file)
204+
image, label = fluid.layers.read_file(reader)
204205

205206
prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label,
206207
class_dim)
207-
208+
#optimizer = fluid.optimizer.SGD(learning_rate=0.002)
208209
optimizer = fluid.optimizer.Momentum(
209210
learning_rate=fluid.layers.piecewise_decay(
210211
boundaries=[100], values=[0.1, 0.2]),
@@ -217,20 +218,26 @@ def train():
217218

218219
exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True)
219220

220-
batch_id = 0
221+
batch_id = -1
221222
time_record = []
222-
# with profiler.profiler('All', 'total', '/tmp/profile') as prof:
223+
223224
for i in xrange(args.number_iteration):
225+
batch_id += 1
226+
if batch_id >= 5 and batch_id < 7:
227+
with profiler.profiler('All', 'total', '/tmp/profile') as prof:
228+
exe.run([])
229+
continue
230+
224231
t1 = time.time()
225-
exe.run([avg_cost.name] if batch_id % 10 == 0 else [])
232+
cost_val = exe.run([avg_cost.name]
233+
if batch_id % args.display_step == 0 else [])
226234
t2 = time.time()
227235
period = t2 - t1
228236
time_record.append(period)
229237

230-
if batch_id % 10 == 0:
231-
print("trainbatch {0}, time{1}".format(batch_id,
232-
"%2.2f sec" % period))
233-
batch_id += 1
238+
if batch_id % args.display_step == 0:
239+
print("iter=%d, elapse=%f, cost=%s" %
240+
(batch_id, period, np.array(cost_val[0])))
234241

235242
del time_record[0]
236243
for ele in time_record:

0 commit comments

Comments
 (0)