Commit b8b3d9c

modified optimizer and learning rate to match the Inception V3 paper
1 parent 0e915f6 commit b8b3d9c

File tree

1 file changed: +46 -21 lines

example/tutorial_imagenet_inceptionV3_distributed.py

Lines changed: 46 additions & 21 deletions
@@ -1,6 +1,11 @@
 #! /usr/bin/python
 # -*- coding: utf8 -*-
 
+# Example of training an Inception V3 model with ImageNet. The parameters are set as in the
+# best results of the paper: https://arxiv.org/abs/1512.00567
+# The dataset can be downloaded from http://www.image-net.org/ or from the Kaggle competition:
+# https://www.kaggle.com/c/imagenet-object-localization-challenge/data
+
 import os
 import time
 import multiprocessing
@@ -206,7 +211,7 @@ def end(self, session):
 ########## METRICS ##########
 
 def calculate_metrics(predicted_batch, real_batch, threshold=0.5, is_training=False, ema_decay=0.9):
-    with tf.variable_scope('metrics'):
+    with tf.variable_scope('metric'):
         threshold_graph = tf.constant(threshold, name='threshold')
         zero_point_five = tf.constant(0.5)
         predicted_bool = tf.greater_equal(predicted_batch, threshold_graph)
@@ -265,12 +270,22 @@ def calculate_metrics(predicted_batch, real_batch, threshold=0.5, is_training=False, ema_decay=0.9):
         tf.summary.scalar('recall', recall)
         tf.summary.scalar('fall-out', fall_out)
         tf.summary.scalar('f1-score', f1_score)
-
-        metrics_ops = {#'accuracy' : accuracy,
-                       'precision': precision,
-                       'recall'   : recall,
-                       'fall-out' : fall_out,
-                       'f1-score' : f1_score}
+        # tf.summary.scalar('true_positive', tp)
+        # tf.summary.scalar('true_negative', tn)
+        # tf.summary.scalar('false_positive', fp)
+        # tf.summary.scalar('false_negative', fn)
+
+        metrics_ops = {
+            # 'accuracy'      : accuracy,
+            'precision'     : precision,
+            'recall'        : recall,
+            'fall-out'      : fall_out,
+            'f1-score'      : f1_score,
+            # 'true positive' : tp,
+            # 'true negative' : tn,
+            # 'false positive': fp,
+            # 'false negative': fn,
+        }
     return init_op, average_ops, metrics_ops
 
 
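Note on this hunk: precision, recall, fall-out and f1-score all reduce to ratios of the four confusion-matrix counts that the new commented-out entries (tp, tn, fp, fn) expose. A minimal NumPy sketch of the same arithmetic, leaving out the script's EMA smoothing (ema_decay) and TensorFlow plumbing; the name batch_metrics is illustrative, not from the file:

import numpy as np

def batch_metrics(predicted, real, threshold=0.5):
    # Boolean confusion-matrix counts for one batch (assumes no count sum is zero).
    pred = predicted >= threshold
    true = real >= 0.5
    tp = float(np.sum(pred & true))
    fp = float(np.sum(pred & ~true))
    tn = float(np.sum(~pred & ~true))
    fn = float(np.sum(~pred & true))
    precision = tp / (tp + fp)    # predicted positives that are real
    recall = tp / (tp + fn)       # real positives that were found
    fall_out = fp / (fp + tn)     # false positive rate
    f1_score = 2 * precision * recall / (precision + recall)
    return {'precision': precision, 'recall': recall,
            'fall-out': fall_out, 'f1-score': f1_score}

print(batch_metrics(np.array([0.9, 0.8, 0.3, 0.1]),
                    np.array([1.0, 0.0, 1.0, 0.0])))
# tp = fp = tn = fn = 1 here, so every metric comes out 0.5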
@@ -338,20 +353,28 @@ def run_worker(task_spec, checkpoints_path, batch_size=32, epochs=10):
338353
target=one_hot_classes,
339354
name='loss')
340355
steps_per_epoch = dataset_size / batch_size
341-
learning_rate = tf.train.exponential_decay(learning_rate=0.1,
356+
learning_rate = tf.train.exponential_decay(learning_rate=0.045,
342357
global_step=global_step,
343-
decay_steps=steps_per_epoch, # 1 epochs
344-
decay_rate=0.5,
358+
decay_steps=steps_per_epoch * 2, # 2 epochs
359+
decay_rate=0.94,
345360
staircase=True,
346361
name='learning_rate')
347362
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
348363
decay=0.9,
349364
epsilon=1.0)
350-
train_op = optimizer.minimize(loss=loss,
351-
var_list=network.all_params,
352-
global_step=global_step)
365+
# clip and apply gradients
366+
gvs = optimizer.compute_gradients(loss=loss,
367+
var_list=network.all_params)
368+
capped_gvs = []
369+
for grad, var in gvs:
370+
if grad is not None:
371+
grad = tf.clip_by_value(grad, -2., 2.)
372+
capped_gvs.append((grad,var))
373+
train_op = optimizer.apply_gradients(grads_and_vars=capped_gvs,
374+
global_step=global_step)
353375
# metrics
354-
tf.summary.scalar('loss', loss)
376+
tf.summary.scalar('learning_rate/value', learning_rate)
377+
tf.summary.scalar('loss/logits', loss)
355378
_, metrics_average_ops, metrics_ops = calculate_metrics(predicted_batch=predictions,
356379
real_batch=one_hot_classes,
357380
is_training=True)
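Note on this hunk: the new values follow the training methodology of the Inception V3 paper: RMSProp with decay 0.9 and epsilon 1.0, an initial learning rate of 0.045 decayed by a factor of 0.94 every two epochs, and gradient clipping, here to [-2, 2] via the compute_gradients / clip_by_value / apply_gradients pattern that replaces the single minimize() call. With staircase=True the decay uses integer division, so the rate is flat within each two-epoch window. A pure-Python sketch of the resulting schedule (the steps_per_epoch figure is illustrative):

def decayed_lr(step, steps_per_epoch, base_lr=0.045, decay_rate=0.94):
    # staircase=True: integer division keeps the rate constant inside each
    # 2-epoch window and drops it by a factor of 0.94 between windows
    return base_lr * decay_rate ** (step // (steps_per_epoch * 2))

steps_per_epoch = 40036  # ~1.28M ImageNet training images / batch size 32 (illustrative)
for epoch in (0, 2, 4, 100):
    print(epoch, decayed_lr(epoch * steps_per_epoch, steps_per_epoch))
# 0 -> 0.045, 2 -> 0.0423, 4 -> 0.039762, 100 -> 0.045 * 0.94**50, roughly 0.00204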
@@ -364,7 +387,8 @@ def run_worker(task_spec, checkpoints_path, batch_size=32, epochs=10):
                                                hooks=hooks,
                                                checkpoint_dir=checkpoints_path,
                                                save_summaries_secs=None,
-                                               save_summaries_steps=100) as sess:
+                                               save_summaries_steps=300,
+                                               save_checkpoint_secs=60 * 60) as sess:
             # print network information
             if task_spec is None or task_spec.is_master():
                 network.print_params(False, session=sess)
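Note on this hunk: summaries move from every 100 steps to every 300, and checkpoints are written at most once per hour, which cuts I/O on long ImageNet runs. A self-contained sketch of the surrounding tf.train.MonitoredTrainingSession call (TF 1.x); the dummy global-step op and StopAtStepHook stand in for the script's real train_op and hooks:

import tensorflow as tf

checkpoints_path = '/tmp/inception_v3_ckpt'   # illustrative path
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)      # placeholder for the real update op
hooks = [tf.train.StopAtStepHook(last_step=1000)]

with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoints_path,
                                       hooks=hooks,
                                       save_summaries_secs=None,     # disable time-based summaries
                                       save_summaries_steps=300,     # write them every 300 steps
                                       save_checkpoint_secs=60 * 60  # checkpoint hourly
                                       ) as sess:
    while not sess.should_stop():
        sess.run(train_op)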
@@ -375,17 +399,18 @@ def run_worker(task_spec, checkpoints_path, batch_size=32, epochs=10):
             last_log_time = time.time()
             next_log_time = last_log_time + 60
             while not sess.should_stop():
-                step, loss_val, _, metrics = \
-                    sess.run([global_step, loss, train_op, metrics_ops])
+                step, loss_val, learning_rate_val, _, metrics = \
+                    sess.run([global_step, loss, learning_rate, train_op, metrics_ops])
                 if task_spec is None or task_spec.is_master():
                     now = time.time()
                     if now > next_log_time:
                         last_log_time = now
                         next_log_time = last_log_time + 60
-                        current_epoch = '{:.2}'.format(float(step) / steps_per_epoch)
+                        current_epoch = '{:.3f}'.format(float(step) / steps_per_epoch)
                         max_steps = epochs * steps_per_epoch
-                        logging.info('Epoch: {}/{} Steps: {}/{} Loss: {} Metrics: {}'.format(
-                            current_epoch, epochs, step, max_steps, loss_val, metrics))
+                        m = 'Epoch: {}/{} Steps: {}/{} Loss: {} Learning rate: {} Metrics: {}'
+                        logging.info(m.format(current_epoch, epochs, step, max_steps,
+                                              loss_val, learning_rate_val, metrics))
         except OutOfRangeError:
             pass
 
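Note on this hunk: besides logging the learning rate, the epoch format changes from '{:.2}' to '{:.3f}'. The old spec means two significant digits, not two decimal places, so the displayed epoch barely moved once training passed epoch 1:

>>> '{:.2}'.format(1.0472)   # old: two significant digits
'1.0'
>>> '{:.3f}'.format(1.0472)  # new: three decimal places
'1.047'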
@@ -401,7 +426,7 @@ def run_worker(task_spec, checkpoints_path, batch_size=32, epochs=10):
     parser = argparse.ArgumentParser()
     parser.add_argument('--with_evaluator', dest='with_evaluator', action='store_true')
     parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
-    parser.add_argument('--epochs', dest='epochs', type=int, default=10)
+    parser.add_argument('--epochs', dest='epochs', type=int, default=100)
     parser.set_defaults(with_evaluator=False)
     args = parser.parse_args()
     logging.info('Batch size: {}'.format(args.batch_size))
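With the new default the script targets a full 100-epoch run out of the box; the flags above can still override it. An illustrative single-machine invocation using the flags defined in this parser:

python example/tutorial_imagenet_inceptionV3_distributed.py --batch_size 32 --epochs 100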
