@@ -691,76 +691,84 @@ Below is a simple usage example for an image classification use case.
``` python
with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
-  optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
+  model = tf.keras.applications.ResNet50(weights=None)
+  optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)

def input_fn(ctx):
  return imagenet.ImageNet(ctx.get_per_replica_batch_size(effective_batch_size))

-def step_fn(inputs):
-  image, label = inputs
+input_iterator = strategy.make_input_iterator(input_fn)

-  logits = model(image)
-  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-      logits=logits, labels=label)
-  loss = tf.reduce_mean(cross_entropy)
-  train_op = optimizer.minimize(loss)
-  with tf.control_dependencies([train_op]):
-    return tf.identity(loss)
+@tf.function
+def train_step():
+  def step_fn(inputs):
+    image, label = inputs

-input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_losses = strategy.run(step_fn, input_iterator)
-mean_loss = strategy.reduce(per_replica_losses)
-
-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    loss = session.run(mean_loss)
-  session.run(strategy.finalize())
+    with tf.GradientTape() as tape:
+      logits = model(image)
+      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          logits=logits, labels=label)
+      loss = tf.reduce_mean(cross_entropy)
+
+    grads = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
+    return loss
+
+  per_replica_losses = strategy.run(step_fn, input_iterator)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
+  return mean_loss
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  loss = train_step()
+strategy.finalize()
```
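
For readers who want to try the same pattern against stock TensorFlow rather than the proposed `Strategy` API, here is a minimal sketch using `tf.distribute` as it ships in TF 2.2+ (`MirroredStrategy`, `strategy.run`, and `ReduceOp.MEAN` are the stock counterparts; the tiny random dataset is a stand-in for the ImageNet pipeline above). It illustrates the pattern and is not part of the proposal.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.applications.ResNet50(weights=None)
  optimizer = tf.keras.optimizers.SGD(0.1, momentum=0.9)

# Toy stand-in for the ImageNet input pipeline used above.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([8, 224, 224, 3]),
     tf.random.uniform([8], maxval=1000, dtype=tf.int64))).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
  def step_fn(inputs):
    image, label = inputs
    with tf.GradientTape() as tape:
      logits = model(image, training=True)
      loss = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=label, logits=logits))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(inputs,))
  # Stock analogue of strategy.reduce(AggregationType.MEAN, ...) above.
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

for batch in dist_dataset:
  loss = train_step(batch)
```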

#### Evaluation

``` python
with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
+  model = tf.keras.applications.ResNet50(weights=None)

def eval_input_fn(ctx):
  del ctx  # Unused.
  return imagenet.ImageNet(
      eval_batch_size, subset="valid", shuffle=False, num_epochs=1)

-def eval_top1_accuracy(inputs):
-  image, label = inputs
-  logits = model(image)
-  predicted_label = tf.argmax(logits, axis=1)
-  top1_acc = tf.reduce_mean(
-      tf.cast(tf.equal(predicted_label, label), tf.float32))
-  return top1_acc
-
eval_input_iterator = strategy.make_input_iterator(
    eval_input_fn, input_replication_mode=InputReplicationMode.SINGLE)
-per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
-mean_top1_acc = strategy.reduce(per_replica_top1_accs)

-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
+@tf.function
+def eval():
+  def eval_top1_accuracy(inputs):
+    image, label = inputs
+    logits = model(image)
+    predicted_label = tf.argmax(logits, axis=1)
+    top1_acc = tf.reduce_mean(
+        tf.cast(tf.equal(predicted_label, label), tf.float32))
+    return top1_acc
+
+  per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
+  mean_top1_acc = strategy.reduce(AggregationType.MEAN, per_replica_top1_accs)
+  return mean_top1_acc
+
+strategy.initialize()
+while True:
+  while not has_new_checkpoint():
+    sleep(60)
+
+  load_checkpoint()
+
+  # Do a sweep over the entire validation set.
+  eval_input_iterator.initialize()
  while True:
-    while not has_new_checkpoint():
-      sleep(60)
-
-    load_checkpoint()
-
-    # Do a sweep over the entire validation set.
-    session.run(eval_input_iterator.initialize())
-    while True:
-      try:
-        top1_acc = session.run(mean_top1_acc)
-        ...
-      except tf.errors.OutOfRangeError:
-        break
-  session.run(strategy.finalize())
+    try:
+      top1_acc = eval()
+      ...
+    except tf.errors.OutOfRangeError:
+      break
+strategy.finalize()
```
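
`has_new_checkpoint()` and `load_checkpoint()` are left undefined in the example above. One way to realize that polling loop with stock TensorFlow is `tf.train.checkpoints_iterator`, sketched below; `checkpoint_dir`, `eval_dataset`, and `eval_step` are placeholder names, not part of the proposal.

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(model=model)

# checkpoints_iterator blocks until a checkpoint newer than the last
# yielded one appears, replacing the manual
# has_new_checkpoint()/sleep(60) polling above.
for ckpt_path in tf.train.checkpoints_iterator(
    checkpoint_dir, min_interval_secs=60):
  ckpt.restore(ckpt_path).expect_partial()
  for batch in eval_dataset:  # One full sweep over the validation set.
    top1_acc = eval_step(batch)
```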
#### Sharded Input Pipeline
@@ -801,42 +809,43 @@ with strategy.scope():
  discriminator = GoodfellowDiscriminator(DefaultDiscriminator2D())
  generator = DefaultGenerator2D()
  gan = GAN(discriminator, generator)
-  disc_optimizer = tf.train.AdamOptimizer(disc_learning_rate, beta1=0.5, beta2=0.9)
-  gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate, beta1=0.5, beta2=0.9)
+  disc_optimizer = tf.keras.optimizers.Adam(disc_learning_rate)
+  gen_optimizer = tf.keras.optimizers.Adam(gen_learning_rate)

def discriminator_step(inputs):
  image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
-  disc_train_op = disc_optimizer.minimize(disc_loss, var_list=disc_vars)
-
-  with tf.control_dependencies([disc_train_op]):
-    return tf.identity(disc_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
+
+  grads = tape.gradient(disc_loss, disc_vars)
+  disc_optimizer.apply_gradients(list(zip(grads, disc_vars)))
+  return disc_loss

def generator_step(inputs):
  image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  gen_loss, gen_vars = gan_output.generator_loss_and_vars()
-  gen_train_op = gen_optimizer.minimize(gen_loss, var_list=gen_vars)
-
-  with tf.control_dependencies([gen_train_op]):
-    return tf.identity(gen_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    gen_loss, gen_vars = gan_output.generator_loss_and_vars()
+
+  grads = tape.gradient(gen_loss, gen_vars)
+  gen_optimizer.apply_gradients(list(zip(grads, gen_vars)))
+  return gen_loss

input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
-per_replica_gen_losses = strategy.run(generator_step, input_iterator)
-mean_disc_loss = strategy.reduce(per_replica_disc_losses)
-mean_gen_loss = strategy.reduce(per_replica_gen_losses)
-
-with tf.Session() as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    for _ in range(num_disc_steps):
-      disc_loss = session.run(mean_disc_loss)
-    for _ in range(num_gen_steps):
-      gen_loss = session.run(mean_gen_loss)
-  session.run(strategy.finalize())
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  for _ in range(num_disc_steps):
+    per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
+    mean_disc_loss = strategy.reduce(AggregationType.MEAN, per_replica_disc_losses)
+  for _ in range(num_gen_steps):
+    per_replica_gen_losses = strategy.run(generator_step, input_iterator)
+    mean_gen_loss = strategy.reduce(AggregationType.MEAN, per_replica_gen_losses)
+strategy.finalize()
```
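
Note that `discriminator_step` and `generator_step` each call `gan.connect()`, so every update pays for its own forward pass. When the alternating `num_disc_steps`/`num_gen_steps` schedule is not needed, a persistent `tf.GradientTape` can drive both updates from one forward pass. A sketch reusing the `gan` and optimizers defined above; this performs simultaneous rather than alternating updates, so the two losses are always trained in lock-step.

```python
def combined_step(inputs):
  image, noise = inputs
  # persistent=True lets us call tape.gradient() more than once.
  with tf.GradientTape(persistent=True) as tape:
    gan_output = gan.connect(image, noise)
    disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
    gen_loss, gen_vars = gan_output.generator_loss_and_vars()

  disc_grads = tape.gradient(disc_loss, disc_vars)
  gen_grads = tape.gradient(gen_loss, gen_vars)
  del tape  # Release the tape's held resources once both gradients are taken.

  disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))
  gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))
  return disc_loss, gen_loss
```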
### Reinforcement Learning
@@ -846,11 +855,9 @@ This is an example of
Reinforcement Learning system, converted to eager style.

``` python
-tf.enable_eager_execution()
-
with strategy.scope():
  agent = Agent(num_actions, hidden_size, entropy_cost, baseline_cost)
-  optimizer = tf.train.RMSPropOptimizer(learning_rate)
+  optimizer = tf.keras.optimizers.RMSprop(learning_rate)

# Queues of trajectories from actors.
queues = []
@@ -867,9 +874,12 @@ def learner_input(ctx):
  return dequeue_batch

def learner_step(trajectories):
-  loss = tf.reduce_sum(agent.compute_loss(trajectories))
+  with tf.GradientTape() as tape:
+    loss = tf.reduce_sum(agent.compute_loss(trajectories))
+
  agent_vars = agent.get_all_variables()
-  optimizer.minimize(loss, var_list=agent_vars)
+  grads = tape.gradient(loss, agent_vars)
+  optimizer.apply_gradients(list(zip(grads, agent_vars)))
  return loss, agent_vars

# Create learner inputs.
@@ -893,7 +903,7 @@ strategy.initialize()
for _ in range(num_train_steps):
  per_replica_outputs = strategy.run(learner_step, learner_inputs)
  per_replica_losses, updated_agent_var_copies = zip(*per_replica_outputs)
-  mean_loss = strategy.reduce(per_replica_losses)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)

strategy.finalize()
```
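
One caveat on the `zip(*per_replica_outputs)` unpacking: with stock `tf.distribute`, `strategy.run` returns `PerReplica` values rather than a Python tuple per replica, so the equivalent unpacking goes through `experimental_local_results`. A rough sketch, assuming `learner_step` returns `(loss, vars)` and `batch` is already distributed:

```python
per_replica_loss, per_replica_vars = strategy.run(
    learner_step, args=(batch,))

# One value per local device, e.g. for logging or shipping to actors.
local_losses = strategy.experimental_local_results(per_replica_loss)

mean_loss = strategy.reduce(
    tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
```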