@@ -31,9 +31,10 @@ def __init__(self, args, infer=False):
 
         self.cell = cell
 
-        self.input_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3])
-        self.target_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3])
-        self.initial_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32)
+        self.input_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3], name='data_in')
+        self.target_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3], name='targets')
+        zero_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32)
+        self.state_in = tf.identity(zero_state, name='state_in')
 
         self.num_mixture = args.num_mixture
         NOUT = 1 + self.num_mixture * 6  # end_of_stroke + prob + 2*(mu + sig) + corr
@@ -45,10 +46,10 @@ def __init__(self, args, infer=False):
         inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data)
         inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
 
-        outputs, last_state = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=None, scope='rnnlm')
+        outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')
         output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, args.rnn_size])
         output = tf.nn.xw_plus_b(output, output_w, output_b)
-        self.final_state = last_state
+        self.state_out = tf.identity(state_out, name='state_out')
 
         # reshape target data so that it is compatible with prediction shape
         flat_target_data = tf.reshape(self.target_data, [-1, 3])
@@ -113,6 +114,20 @@ def get_mixture_coef(output):
             return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos]
 
         [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos] = get_mixture_coef(output)
+
+        # I could put all of these in a single tensor for reading out, but this is more human-readable
+        data_out_pi = tf.identity(o_pi, "data_out_pi")
+        data_out_mu1 = tf.identity(o_mu1, "data_out_mu1")
+        data_out_mu2 = tf.identity(o_mu2, "data_out_mu2")
+        data_out_sigma1 = tf.identity(o_sigma1, "data_out_sigma1")
+        data_out_sigma2 = tf.identity(o_sigma2, "data_out_sigma2")
+        data_out_corr = tf.identity(o_corr, "data_out_corr")
+        data_out_eos = tf.identity(o_eos, "data_out_eos")
+
+        # sticking them all in one op anyway; it makes freezing the graph easier later
+        # IMPORTANT: this needs to stack the named ops above (data_out_XXX), not the prev ops (o_XXX),
+        # otherwise when I freeze the graph up to this point, the named versions will be cut
+        data_out = tf.stack([data_out_pi, data_out_mu1, data_out_mu2, data_out_sigma1, data_out_sigma2, data_out_corr, data_out_eos], name="data_out")
 
         self.pi = o_pi
         self.mu1 = o_mu1
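
For context, the freezing step these comments anticipate prunes every op that does not feed the requested output nodes, which is exactly why `data_out` must stack the named `data_out_XXX` identities rather than the raw `o_XXX` tensors. A minimal sketch with the TF 1.x `graph_util` API, assuming a live session; the function name and output path are illustrative, not part of this commit:

```python
import tensorflow as tf

def freeze_graph(sess, path='model_frozen.pb'):
    # convert_variables_to_constants keeps only the subgraph that feeds the
    # listed nodes; ops not on a path to them would be cut from the result
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['data_out', 'state_out'])
    with tf.gfile.GFile(path, 'wb') as f:
        f.write(frozen.SerializeToString())
```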
@@ -161,9 +176,9 @@ def sample_gaussian_2d(mu1, mu2, s1, s2, rho):
 
         for i in range(num):
 
-            feed = {self.input_data: prev_x, self.initial_state: prev_state}
+            feed = {self.input_data: prev_x, self.state_in: prev_state}
 
-            [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.eos, self.final_state], feed)
+            [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.eos, self.state_out], feed)
 
             idx = get_pi_idx(random.random(), o_pi[0])
 
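
For completeness, a sketch of how the named endpoints this commit introduces (`data_in`, `state_in`, `data_out`, `state_out`) could be used once the graph is frozen: load the GraphDef, look the tensors up by name, and run one step. The `.pb` path and the single-step shape (a model built with `seq_length` of 1 for sampling, i.e. the infer case) are assumptions, not part of the diff:

```python
import numpy as np
import tensorflow as tf

# load the frozen GraphDef and import it into a fresh graph
graph_def = tf.GraphDef()
with tf.gfile.GFile('model_frozen.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    data_in = graph.get_tensor_by_name('data_in:0')
    state_in = graph.get_tensor_by_name('state_in:0')
    data_out = graph.get_tensor_by_name('data_out:0')
    state_out = graph.get_tensor_by_name('state_out:0')

    x = np.zeros((1, 1, 3), dtype=np.float32)  # one (dx, dy, eos) step
    state = sess.run(state_in)                 # zero state via the named op
    out, state = sess.run([data_out, state_out],
                          feed_dict={data_in: x, state_in: state})
```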