
Commit 05e08e1

chr1sj0nes authored and ewilderj committed
Switch DistStrat revised API examples to TensorFlow 2 style. (#63)
1 parent acf19fa · commit 05e08e1

File tree: 1 file changed (+92, −82)

rfcs/20181016-replicator.md

Lines changed: 92 additions & 82 deletions

````diff
@@ -691,76 +691,84 @@ Below is a simple usage example for an image classification use case.
 
 ```python
 with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
-  optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
+  model = tf.keras.applications.ResNet50(weights=None)
+  optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
 
 def input_fn(ctx):
   return imagenet.ImageNet(ctx.get_per_replica_batch_size(effective_batch_size))
 
-def step_fn(inputs):
-  image, label = inputs
+input_iterator = strategy.make_input_iterator(input_fn)
 
-  logits = model(images)
-  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-      logits=logits, labels=label)
-  loss = tf.reduce_mean(cross_entropy)
-  train_op = optimizer.minimize(loss)
-  with tf.control_dependencies([train_op]):
-    return tf.identity(loss)
+@tf.function
+def train_step():
+  def step_fn(inputs):
+    image, label = inputs
 
-input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_losses = strategy.run(step_fn, input_iterator)
-mean_loss = strategy.reduce(per_replica_losses)
-
-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    loss = session.run(mean_loss)
-  session.run(strategy.finalize())
+    with tf.GradientTape() as tape:
+      logits = model(images)
+      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          logits=logits, labels=label)
+      loss = tf.reduce_mean(cross_entropy)
+
+    grads = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
+    return loss
+
+  per_replica_losses = strategy.run(step_fn, input_iterator)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
+  return mean_loss
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  loss = train_step()
+strategy.finalize()
 ```
 
 #### Evaluation
 
 ```python
 with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
+  model = tf.keras.applications.ResNet50(weights=None)
 
 def eval_input_fn(ctx):
   del ctx  # Unused.
   return imagenet.ImageNet(
       eval_batch_size, subset="valid", shuffle=False, num_epochs=1)
 
-def eval_top1_accuracy(inputs):
-  image, label = inputs
-  logits = model(images)
-  predicted_label = tf.argmax(logits, axis=1)
-  top_1_acc = tf.reduce_mean(
-      tf.cast(tf.equal(predicted_label, label), tf.float32))
-  return top1_acc
-
 eval_input_iterator = strategy.make_input_iterator(
     eval_input_fn, input_replication_mode=InputReplicationMode.SINGLE)
-per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
-mean_top1_acc = strategy.reduce(per_replica_top1_accs)
 
-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
+@tf.function
+def eval():
+  def eval_top1_accuracy(inputs):
+    image, label = inputs
+    logits = model(images)
+    predicted_label = tf.argmax(logits, axis=1)
+    top_1_acc = tf.reduce_mean(
+        tf.cast(tf.equal(predicted_label, label), tf.float32))
+    return top1_acc
+
+  per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
+  mean_top1_acc = strategy.reduce(AggregationType.MEAN, per_replica_top1_accs)
+  return mean_top1_acc
+
+strategy.initialize()
+while True:
+  while not has_new_checkpoint():
+    sleep(60)
+
+  load_checkpoint()
+
+  # Do a sweep over the entire validation set.
+  eval_input_iterator.initialize()
   while True:
-    while not has_new_checkpoint():
-      sleep(60)
-
-    load_checkpoint()
-
-    # Do a sweep over the entire validation set.
-    session.run(eval_input_iterator.initialize())
-    while True:
-      try:
-        top1_acc = session.run(mean_top1_acc)
-        ...
-      except tf.errors.OutOfRangeError:
-        break
-  session.run(strategy.finalize())
+    try:
+      top1_acc = eval()
+      ...
+    except tf.errors.OutOfRangeError:
+      break
+strategy.finalize()
 ```
 
 #### Sharded Input Pipeline
````
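
For readers comparing this proposal with current TensorFlow, the sketch below shows the same TF2-style training step written against the `tf.distribute` API that shipped in TF 2.x (`MirroredStrategy`, `strategy.run`, `tf.distribute.ReduceOp.MEAN`), rather than the `make_input_iterator` / `AggregationType` calls proposed above. The tiny model and synthetic dataset are illustrative stand-ins for the RFC's `resnet` and `imagenet.ImageNet` helpers.

```python
# Sketch only: the tf.distribute API in shipped TF 2.x, for comparison with
# the Replicator-style calls proposed in the diff above.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
num_classes, global_batch_size = 10, 16

with strategy.scope():
  # Stand-in for ResNet50: a tiny classifier so the sketch runs anywhere.
  model = tf.keras.Sequential([
      tf.keras.layers.Flatten(input_shape=(8, 8, 3)),
      tf.keras.layers.Dense(num_classes)])
  optimizer = tf.keras.optimizers.SGD(0.1, momentum=0.9)

# Stand-in for imagenet.ImageNet; the strategy splits each global batch
# across the available replicas.
images = tf.random.uniform([64, 8, 8, 3])
labels = tf.random.uniform([64], maxval=num_classes, dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
  def step_fn(inputs):
    image, label = inputs
    with tf.GradientTape() as tape:
      logits = model(image, training=True)
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=label, logits=logits)
      loss = tf.reduce_mean(cross_entropy)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(inputs,))
  # ReduceOp.MEAN plays the role of AggregationType.MEAN in the proposal.
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

for batch in dist_dataset:
  loss = train_step(batch)
```

The shape of the loop is unchanged from the proposed example: a per-replica `step_fn` traced inside a `tf.function`, followed by a cross-replica mean of the per-replica losses.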

````diff
@@ -801,42 +809,43 @@ with strategy.scope():
   discriminator = GoodfellowDiscriminator(DefaultDiscriminator2D())
   generator = DefaultGenerator2D()
   gan = GAN(discriminator, generator)
-  disc_optimizer = tf.train.AdamOptimizer(disc_learning_rate, beta1=0.5, beta2=0.9)
-  gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate, beta1=0.5, beta2=0.9)
+  disc_optimizer = tf.keras.optimizers.Adam(disc_learning_rate)
+  gen_optimizer = tf.keras.optimizers.Adam(gen_learning_rate)
 
 def discriminator_step(inputs):
   image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
-  disc_train_op = disc_optimizer.minimize(disc_loss, var_list=disc_vars)
-
-  with tf.control_dependencies([disc_train_op]):
-    return tf.identity(disc_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
+
+  grads = tape.gradients(disc_loss, disc_vars)
+  disc_optimizer.apply_gradients(list(zip(grads, disc_vars)))
+  return disc_loss
 
 def generator_step(inputs):
   image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  gen_loss, gen_vars = gan_output.generator_loss_and_vars()
-  gen_train_op = gen_optimizer.minimize(gen_loss, var_list=gen_vars)
-
-  with tf.control_dependencies([gen_train_op]):
-    return tf.identity(gen_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    gen_loss, gen_vars = gan_output.generator_loss_and_vars()
+
+  grads = tape.gradient(gen_loss, gen_vars)
+  gen_optimizer.apply_gradients(list(zip(grads, gen_vars)))
+  return gen_loss
 
 input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
-per_replica_gen_losses = strategy.run(generator_step, input_iterator)
-mean_disc_loss = strategy.reduce(per_replica_disc_losses)
-mean_gen_loss = strategy.reduce(per_replica_gen_losses)
-
-with tf.Session() as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    for _ in range(num_disc_steps):
-      disc_loss = session.run(mean_disc_loss)
-    for _ in range(num_gen_steps):
-      gen_loss = session.run(mean_gen_loss)
-  session.run(strategy.finalize())
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  for _ in range(num_disc_steps):
+    per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
+    mean_disc_loss = strategy.reduce(AggregationType.MEAN, per_replica_disc_losses)
+  for _ in range(num_gen_steps):
+    per_replica_gen_losses = strategy.run(generator_step, input_iterator)
+    mean_gen_loss = strategy.reduce(AggregationType.MEAN, per_replica_gen_losses)
+strategy.finalize()
 ```
 
 ### Reinforcement Learning
````
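
The same comparison for the GAN example: below is a minimal sketch of alternating discriminator/generator updates under the shipped `tf.distribute` API, where tiny `Dense` networks and random batches stand in for the `GAN`, `GoodfellowDiscriminator`, and `DefaultGenerator2D` helpers used above.

```python
# Sketch only: alternating GAN updates with the shipped tf.distribute API.
# The Dense networks and random batches are stand-ins for the modules in the
# RFC example.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
num_train_steps, num_disc_steps, num_gen_steps = 100, 2, 1

with strategy.scope():
  discriminator = tf.keras.Sequential(
      [tf.keras.layers.Dense(16, activation="relu", input_shape=(2,)),
       tf.keras.layers.Dense(1)])
  generator = tf.keras.Sequential(
      [tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
       tf.keras.layers.Dense(2)])
  disc_optimizer = tf.keras.optimizers.Adam(1e-4)
  gen_optimizer = tf.keras.optimizers.Adam(1e-4)

def bce(labels, logits):
  # Plain sigmoid cross-entropy, reduced per replica.
  return tf.reduce_mean(
      tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

def make_batch():
  # Toy 2-D "images" and 4-D noise; the RFC gets these from input_fn.
  return tf.random.normal([8, 2]), tf.random.normal([8, 4])

@tf.function
def discriminator_step(image, noise):
  def step_fn(image, noise):
    with tf.GradientTape() as tape:
      fake = generator(noise)
      loss = (bce(tf.ones([8, 1]), discriminator(image)) +
              bce(tf.zeros([8, 1]), discriminator(fake)))
    grads = tape.gradient(loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(
        zip(grads, discriminator.trainable_variables))
    return loss
  # For simplicity every replica sees the same batch here.
  per_replica = strategy.run(step_fn, args=(image, noise))
  return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

@tf.function
def generator_step(noise):
  def step_fn(noise):
    with tf.GradientTape() as tape:
      loss = bce(tf.ones([8, 1]), discriminator(generator(noise)))
    grads = tape.gradient(loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return loss
  per_replica = strategy.run(step_fn, args=(noise,))
  return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

for _ in range(num_train_steps):
  for _ in range(num_disc_steps):
    image, noise = make_batch()
    disc_loss = discriminator_step(image, noise)
  for _ in range(num_gen_steps):
    _, noise = make_batch()
    gen_loss = generator_step(noise)
```

As in the diff above, each optimizer gets its own step function, and `strategy.run` moves inside the training loop so that one call performs exactly one update.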

````diff
@@ -846,11 +855,9 @@ This is an example of
 Reinforcement Learning system, converted to eager style.
 
 ```python
-tf.enable_eager_execution()
-
 with strategy.scope():
   agent = Agent(num_actions, hidden_size, entropy_cost, baseline_cost)
-  optimizer = tf.train.RMSPropOptimizer(learning_rate)
+  optimizer = tf.keras.optimizers.RMSprop(learning_rate)
 
 # Queues of trajectories from actors.
 queues = []
@@ -867,9 +874,12 @@ def learner_input(ctx):
   return dequeue_batch
 
 def learner_step(trajectories):
-  loss = tf.reduce_sum(agent.compute_loss(trajectories))
+  with tf.GradientTape() as tape:
+    loss = tf.reduce_sum(agent.compute_loss(trajectories))
+
   agent_vars = agent.get_all_variables()
-  optimizer.minimize(loss, var_list=agent_vars)
+  grads = tape.gradient(loss, agent_vars)
+  optimizer.apply_gradients(list(zip(grads, agent_vars)))
   return loss, agent_vars
 
 # Create learner inputs.
@@ -893,7 +903,7 @@ strategy.initialize()
 for _ in range(num_train_steps):
   per_replica_outputs = strategy.run(learner_step, learner_inputs)
   per_replica_losses, updated_agent_var_copies = zip(*per_replica_outputs)
-  mean_loss = strategy.reduce(per_replica_losses)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
 
 strategy.finalize()
 ```
````
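
And for the reinforcement-learning learner, a comparable sketch under the shipped API; a one-layer network and a random tensor stand in for the `Agent` module and its actor-fed trajectory queues. The point of interest is the `GradientTape` over an explicit variable list followed by a mean reduction of the per-replica losses.

```python
# Sketch only: the learner_step pattern under the shipped tf.distribute API.
# A Dense layer and random "trajectories" stand in for the RFC's Agent module
# and its actor-fed queues.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
num_train_steps = 100

with strategy.scope():
  agent = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(16,))])
  optimizer = tf.keras.optimizers.RMSprop(1e-3)

@tf.function
def learner_step(trajectories):
  def step_fn(trajectories):
    with tf.GradientTape() as tape:
      # Stand-in for agent.compute_loss(trajectories).
      loss = tf.reduce_sum(tf.square(agent(trajectories)))
    agent_vars = agent.trainable_variables
    grads = tape.gradient(loss, agent_vars)
    optimizer.apply_gradients(zip(grads, agent_vars))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(trajectories,))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

for _ in range(num_train_steps):
  loss = learner_step(tf.random.normal([8, 16]))
```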
