@@ -105,24 +105,35 @@ def __call__(
105105class Encoder (tf .keras .layers .Layer ):
def __init__(self, rnn_size, rnn_type="gru", rnn_dropout=0, dense_size=32, return_state=False, **kwargs):
    """Store encoder hyper-parameters; sub-layers are created lazily in ``build``.

    Args:
        rnn_size: Number of units in the recurrent layer.
        rnn_type: "gru" or "lstm" (case-insensitive; lower-cased on storage).
        rnn_dropout: Dropout rate passed to the RNN layer.
        dense_size: Units of the dense projection applied to the final state.
        return_state: Stored flag; not consulted in the code visible here.
        **kwargs: Forwarded to the Keras ``Layer`` base class.
    """
    super().__init__(**kwargs)
    self.rnn_size = rnn_size
    self.rnn_type = rnn_type.lower()
    self.rnn_dropout = rnn_dropout
    self.dense_size = dense_size
    self.return_state = return_state
def build(self, input_shape):
    """Instantiate the RNN and dense sub-layers for this encoder.

    Args:
        input_shape: Shape of the encoder input; forwarded to the base build.

    Raises:
        ValueError: If ``self.rnn_type`` is neither "gru" nor "lstm".
    """
    # Keyword arguments shared by both supported recurrent layer types.
    rnn_kwargs = dict(
        units=self.rnn_size,
        activation="tanh",
        return_state=True,
        return_sequences=True,
        dropout=self.rnn_dropout,
    )
    if self.rnn_type == "gru":
        # reset_after=False pins the GRU variant that applies the reset gate
        # before the matrix multiplication.
        self.rnn = GRU(reset_after=False, **rnn_kwargs)
    elif self.rnn_type == "lstm":
        self.rnn = LSTM(**rnn_kwargs)
    else:
        raise ValueError(f"No supported RNN type: {self.rnn_type}")

    # Projects the final RNN state down/up to dense_size in call().
    self.dense = Dense(units=self.dense_size, activation="tanh")
    super().build(input_shape)
127138 def call (self , inputs ):
128139 """Process input through the encoder RNN and dense layers.
@@ -138,7 +149,8 @@ def call(self, inputs):
138149 - For LSTM: tuple of (batch_size, dense_size), (batch_size, dense_size)
139150 """
140151 if self .rnn_type == "gru" :
141- outputs , state = self .rnn (inputs )
152+ rnn_outputs = self .rnn (inputs )
153+ outputs , state = rnn_outputs
142154 state = self .dense (state )
143155 elif self .rnn_type == "lstm" :
144156 outputs , state_h , state_c = self .rnn (inputs )
@@ -151,6 +163,32 @@ def call(self, inputs):
151163 # outputs = self.dense(outputs) # => batch_size * input_seq_length * dense_size
152164 return outputs , state
153165
def get_config(self):
    """Return the serialization config so the layer can be re-created."""
    base = super().get_config()
    base.update(
        rnn_size=self.rnn_size,
        rnn_type=self.rnn_type,
        rnn_dropout=self.rnn_dropout,
        dense_size=self.dense_size,
        return_state=self.return_state,
    )
    return base
def compute_output_shape(self, input_shape):
    """Compute ``(rnn_output_shape, state_shape)`` for a (batch, seq, feat) input.

    The state shape uses ``dense_size`` rather than ``rnn_size`` because the
    final state is passed through the dense projection in ``call``.

    Args:
        input_shape: 3-tuple ``(batch_size, seq_len, features)``.

    Returns:
        Tuple of the RNN sequence-output shape and the (projected) state
        shape — for LSTM, a pair of (h, c) state shapes.

    Raises:
        ValueError: If ``self.rnn_type`` is neither "gru" nor "lstm".
    """
    batch_size, seq_len, _ = input_shape
    rnn_output_shape = (batch_size, seq_len, self.rnn_size)

    # State shape depends on RNN type.
    if self.rnn_type == "gru":
        state_shape = (batch_size, self.dense_size)
    elif self.rnn_type == "lstm":
        # LSTM exposes two states (h, c); both are projected to dense_size.
        state_shape = ((batch_size, self.dense_size), (batch_size, self.dense_size))
    else:
        # Message kept consistent with the validation error raised in build().
        raise ValueError(f"No supported RNN type: {self.rnn_type}")
    return rnn_output_shape, state_shape
154192
155193class DecoderV1 (tf .keras .layers .Layer ):
156194 def __init__ (
@@ -256,6 +294,29 @@ def call(
256294 decoder_outputs = tf .concat (decoder_outputs , axis = - 1 )
257295 return tf .expand_dims (decoder_outputs , - 1 )
258296
def get_config(self):
    """Return the serialization config for re-creating this decoder layer."""
    cfg = super().get_config()
    cfg.update(
        rnn_size=self.rnn_size,
        rnn_type=self.rnn_type,
        predict_sequence_length=self.predict_sequence_length,
        use_attention=self.use_attention,
        attention_size=self.attention_size,
        num_attention_heads=self.num_attention_heads,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
    )
    return cfg
def compute_output_shape(self, input_shape):
    """Output shape is ``(batch_size, predict_sequence_length, 1)``.

    ``input_shape`` is indexed at position 1 for the decoder-init input's
    shape; when that entry is not a list/tuple the batch size is unknown
    (``None``).
    """
    init_shape = input_shape[1]
    batch_size = init_shape[0] if isinstance(init_shape, (list, tuple)) else None
    return (batch_size, self.predict_sequence_length, 1)
259320
260321class DecoderV2 (tf .keras .layers .Layer ):
261322 def __init__ (
0 commit comments