Skip to content

Commit ddcf4c0

Browse files
committed
organize graph to have easy access via named ops
1 parent d11dd75 commit ddcf4c0

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

model.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def __init__(self, args, infer=False):
3131

3232
self.cell = cell
3333

34-
self.input_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3])
35-
self.target_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3])
36-
self.initial_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32)
34+
self.input_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3], name='data_in')
35+
self.target_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3], name='targets')
36+
zero_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32)
37+
self.state_in = tf.identity(zero_state, name='state_in')
3738

3839
self.num_mixture = args.num_mixture
3940
NOUT = 1 + self.num_mixture * 6 # end_of_stroke + prob + 2*(mu + sig) + corr
@@ -45,10 +46,10 @@ def __init__(self, args, infer=False):
4546
inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data)
4647
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
4748

48-
outputs, last_state = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=None, scope='rnnlm')
49+
outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')
4950
output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, args.rnn_size])
5051
output = tf.nn.xw_plus_b(output, output_w, output_b)
51-
self.final_state = last_state
52+
self.state_out = tf.identity(state_out, name='state_out')
5253

5354
# reshape target data so that it is compatible with prediction shape
5455
flat_target_data = tf.reshape(self.target_data,[-1, 3])
@@ -113,6 +114,20 @@ def get_mixture_coef(output):
113114
return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos]
114115

115116
[o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos] = get_mixture_coef(output)
117+
118+
# I could put all of these in a single tensor for reading out, but this is more human readable
119+
data_out_pi = tf.identity(o_pi, "data_out_pi");
120+
data_out_mu1 = tf.identity(o_mu1, "data_out_mu1");
121+
data_out_mu2 = tf.identity(o_mu2, "data_out_mu2");
122+
data_out_sigma1 = tf.identity(o_sigma1, "data_out_sigma1");
123+
data_out_sigma2 = tf.identity(o_sigma2, "data_out_sigma2");
124+
data_out_corr = tf.identity(o_corr, "data_out_corr");
125+
data_out_eos = tf.identity(o_eos, "data_out_eos");
126+
127+
# sticking them all in one op anyway, makes it easier for freezing the graph later
128+
# IMPORTANT, this needs to stack the named ops above (data_out_XXX), not the prev ops (o_XXX)
129+
# otherwise when I freeze the graph up to this point, the named versions will be cut
130+
data_out = tf.stack([data_out_pi, data_out_mu1, data_out_mu2, data_out_sigma1, data_out_sigma2, data_out_corr, data_out_eos], name="data_out")
116131

117132
self.pi = o_pi
118133
self.mu1 = o_mu1
@@ -161,9 +176,9 @@ def sample_gaussian_2d(mu1, mu2, s1, s2, rho):
161176

162177
for i in range(num):
163178

164-
feed = {self.input_data: prev_x, self.initial_state:prev_state}
179+
feed = {self.input_data: prev_x, self.state_in:prev_state}
165180

166-
[o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.eos, self.final_state],feed)
181+
[o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.eos, self.state_out],feed)
167182

168183
idx = get_pi_idx(random.random(), o_pi[0])
169184

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def train(args):
6363
sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
6464
data_loader.reset_batch_pointer()
6565
v_x, v_y = data_loader.validation_data()
66-
valid_feed = {model.input_data: v_x, model.target_data: v_y, model.initial_state: model.initial_state.eval()}
67-
state = model.initial_state.eval()
66+
valid_feed = {model.input_data: v_x, model.target_data: v_y, model.state_in: model.state_in.eval()}
67+
state = model.state_in.eval()
6868
for b in range(data_loader.num_batches):
6969
start = time.time()
7070
x, y = data_loader.next_batch()
71-
feed = {model.input_data: x, model.target_data: y, model.initial_state: state}
72-
summary, train_loss, state, _ = sess.run([merged_summaries, model.cost, model.final_state, model.train_op], feed)
71+
feed = {model.input_data: x, model.target_data: y, model.state_in: state}
72+
summary, train_loss, state, _ = sess.run([merged_summaries, model.cost, model.state_out, model.train_op], feed)
7373
summary_writer.add_summary(summary, e * data_loader.num_batches + b)
7474
valid_loss, = sess.run([model.cost], valid_feed)
7575
end = time.time()

0 commit comments

Comments
 (0)