Skip to content

Commit bad1cfa

Browse files
authored
fix: GRU and Lambda layers for TF 2.18 & 2.19
1 parent 01e39b3 commit bad1cfa

File tree

9 files changed

+148
-36
lines changed

9 files changed

+148
-36
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
strategy:
1616
matrix:
1717
os: [ubuntu-latest, macOS-latest] # add windows-2019 when poetry allows installation with `-f` flag
18-
python-version: [3.9, '3.12']
18+
python-version: [3.9, '3.11']
1919
tf-version: [2.13.1, 2.15.1]
2020

2121
exclude:
@@ -92,7 +92,7 @@ jobs:
9292
- name: Set up Python
9393
uses: actions/setup-python@v5
9494
with:
95-
python-version: '3.12'
95+
python-version: '3.11'
9696

9797
- name: Create pip cache directory manually
9898
run: |

examples/run_prediction_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
def parse_args():
2020
parser = argparse.ArgumentParser()
2121
parser.add_argument("--seed", type=int, default=315, required=False, help="seed")
22-
parser.add_argument("--use_model", type=str, default="rnn", help="model for train")
22+
parser.add_argument("--use_model", type=str, default="bert", help="model for train")
2323
parser.add_argument("--use_data", type=str, default="sine", help="dataset: sine or air passengers")
2424
parser.add_argument("--train_length", type=int, default=24, help="sequence length for train")
2525
parser.add_argument("--predict_sequence_length", type=int, default=12, help="sequence length for predict")

tfts/layers/embed_layer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ def __init__(self, embed_size: int, positional_type: Optional[str] = "positional
6060
self.embed_size = embed_size
6161
self.positional_type = positional_type
6262

63+
def build(self, input_shape: Tuple[int, ...]):
6364
# Value embedding layer: the section below is placed in build, so it is built when DataEmbedding is called
6465
# Otherwise, when loading the weights, the TokenEmbedding is not built
6566
self.value_embedding = TokenEmbedding(self.embed_size)
67+
self.value_embedding.build(input_shape)
6668

6769
# Positional embedding layer based on specified type
6870
if self.positional_type == "positional encoding":
@@ -74,6 +76,10 @@ def __init__(self, embed_size: int, positional_type: Optional[str] = "positional
7476
else:
7577
self.positional_embedding = None
7678

79+
if self.positional_embedding:
80+
self.positional_embedding.build(input_shape)
81+
self.built = True
82+
7783
def call(self, x: tf.Tensor) -> tf.Tensor:
7884
"""
7985
Forward pass of the layer.

tfts/layers/util_layer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,20 @@
44
class ShapeLayer(tf.keras.layers.Layer):
55
"""Layer to handle shape operations in a Keras-compatible way."""
66

7-
def __init__(self):
8-
super().__init__()
7+
def __init__(self, **kwargs):
8+
super().__init__(**kwargs)
99

1010
def call(self, x):
1111
return tf.shape(x)
12+
13+
14+
class CreateDecoderFeature(tf.keras.layers.Layer):
15+
def __init__(self, predict_sequence_length, **kwargs):
16+
super().__init__(**kwargs)
17+
self.predict_sequence_length = predict_sequence_length
18+
19+
def call(self, encoder_feature):
20+
batch_size = tf.shape(encoder_feature)[0]
21+
time_range = tf.range(self.predict_sequence_length)
22+
tiled = tf.tile(tf.reshape(time_range, (1, self.predict_sequence_length, 1)), (batch_size, 1, 1))
23+
return tf.cast(tiled, tf.float32)

tfts/models/autoformer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ def __init__(self, predict_sequence_length: int = 1, config: Optional[AutoFormer
9191
hidden_dropout_prob=self.config.hidden_dropout_prob,
9292
)
9393

94-
self.project1 = Dense(1, activation=None)
9594
self.drop1 = Dropout(self.config.hidden_dropout_prob)
9695
self.dense1 = Dense(512, activation="relu")
9796
self.drop2 = Dropout(self.config.hidden_dropout_prob)
9897
self.dense2 = Dense(1024, activation="relu")
98+
self.project1 = Dense(1, activation=None)
9999

100100
def __call__(
101101
self,
@@ -121,7 +121,7 @@ def __call__(
121121
Otherwise, returns the output tensor.
122122
"""
123123
x, encoder_feature, decoder_feature = self._prepare_3d_inputs(inputs, ignore_decoder_inputs=False)
124-
batch_size, _, n_feature = self.shape_layer(encoder_feature)
124+
# batch_size, _, n_feature = self.shape_layer(encoder_feature)
125125

126126
# Encoder
127127
encoder_output = self.encoder(x)
@@ -198,6 +198,10 @@ def get_config(self):
198198
base_config = super().get_config()
199199
return dict(list(base_config.items()) + list(config.items()))
200200

201+
def compute_output_shape(self, input_shape):
202+
batch_size, time_steps, _ = input_shape
203+
return (batch_size, time_steps, self.hidden_size)
204+
201205

202206
class EncoderLayer(tf.keras.layers.Layer):
203207
"""Encoder Layer for Autoformer architecture."""
@@ -317,6 +321,10 @@ def get_config(self):
317321
base_config = super().get_config()
318322
return dict(list(base_config.items()) + list(config.items()))
319323

324+
def compute_output_shape(self, input_shape):
325+
batch_size, time_steps, _ = input_shape
326+
return (batch_size, time_steps, self.hidden_size)
327+
320328

321329
class DecoderLayer(tf.keras.layers.Layer):
322330
"""Decoder Layer for Autoformer architecture."""

tfts/models/base.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tensorflow.keras.layers import Concatenate, Lambda
1212

1313
from ..constants import CONFIG_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME
14+
from ..layers.util_layer import CreateDecoderFeature
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -95,18 +96,14 @@ def _prepare_3d_inputs(self, inputs, ignore_decoder_inputs=True):
9596
else:
9697
encoder_feature = x = inputs
9798
if not ignore_decoder_inputs:
98-
decoder_feature = Lambda(
99-
lambda encoder_feature: tf.cast(
100-
tf.tile(
101-
tf.reshape(tf.range(self.predict_sequence_length), (1, self.predict_sequence_length, 1)),
102-
(tf.shape(encoder_feature)[0], 1, 1),
103-
),
104-
tf.float32,
105-
),
106-
output_shape=(self.predict_sequence_length, 1),
107-
)(encoder_feature)
99+
decoder_feature = CreateDecoderFeature(self.predict_sequence_length)(encoder_feature)
108100
return x, encoder_feature, decoder_feature
109101

102+
def _create_decoder_feature(batch_size, predict_sequence_length):
103+
time_range = tf.range(predict_sequence_length)
104+
tiled = tf.tile(tf.reshape(time_range, (1, predict_sequence_length, 1)), (batch_size, 1, 1))
105+
return tf.cast(tiled, tf.float32)
106+
110107
def save_pretrained(
111108
self,
112109
save_directory: Union[str, os.PathLike],

tfts/models/rnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def build(self, input_shape):
122122
activation="tanh",
123123
return_sequences=True,
124124
return_state=return_state,
125+
reset_after=False,
125126
dropout=self.rnn_dropout if self.rnn_dropout > 0 else 0.0,
126127
)
127128

@@ -213,7 +214,7 @@ def compute_output_shape(self, input_shape):
213214
elif self.rnn_type == "gru":
214215
# GRU: (output, state)
215216
return ((batch_size, seq_length, rnn_output_size), (batch_size, rnn_output_size))
216-
else: # LSTM
217+
else:
217218
# LSTM: (output, state_h, state_c)
218219
return ((batch_size, seq_length, rnn_output_size), (batch_size, 2 * rnn_output_size))
219220

tfts/models/seq2seq.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,24 +105,35 @@ def __call__(
105105
class Encoder(tf.keras.layers.Layer):
106106
def __init__(self, rnn_size, rnn_type="gru", rnn_dropout=0, dense_size=32, return_state=False, **kwargs):
107107
super().__init__(**kwargs)
108+
self.rnn_size = rnn_size
108109
self.rnn_type = rnn_type.lower()
110+
self.rnn_dropout = rnn_dropout
111+
self.dense_size = dense_size
109112
self.return_state = return_state
110-
if rnn_type == "gru":
113+
114+
def build(self, input_shape):
115+
if self.rnn_type == "gru":
111116
self.rnn = GRU(
112-
units=rnn_size, activation="tanh", return_state=True, return_sequences=True, dropout=rnn_dropout
117+
units=self.rnn_size,
118+
activation="tanh",
119+
return_state=True,
120+
return_sequences=True,
121+
dropout=self.rnn_dropout,
122+
reset_after=False,
113123
)
114-
elif rnn_type == "lstm":
124+
elif self.rnn_type == "lstm":
115125
self.rnn = LSTM(
116-
units=rnn_size,
126+
units=self.rnn_size,
117127
activation="tanh",
118128
return_state=True,
119129
return_sequences=True,
120-
dropout=rnn_dropout,
130+
dropout=self.rnn_dropout,
121131
)
122132
else:
123-
raise ValueError(f"No supported RNN type: {rnn_type}")
133+
raise ValueError(f"No supported RNN type: {self.rnn_type}")
124134

125-
self.dense = Dense(units=dense_size, activation="tanh")
135+
self.dense = Dense(units=self.dense_size, activation="tanh")
136+
super(Encoder, self).build(input_shape)
126137

127138
def call(self, inputs):
128139
"""Process input through the encoder RNN and dense layers.
@@ -138,7 +149,8 @@ def call(self, inputs):
138149
- For LSTM: tuple of (batch_size, dense_size), (batch_size, dense_size)
139150
"""
140151
if self.rnn_type == "gru":
141-
outputs, state = self.rnn(inputs)
152+
rnn_outputs = self.rnn(inputs)
153+
outputs, state = rnn_outputs
142154
state = self.dense(state)
143155
elif self.rnn_type == "lstm":
144156
outputs, state_h, state_c = self.rnn(inputs)
@@ -151,6 +163,32 @@ def call(self, inputs):
151163
# outputs = self.dense(outputs) # => batch_size * input_seq_length * dense_size
152164
return outputs, state
153165

166+
def get_config(self):
167+
config = super().get_config()
168+
config.update(
169+
{
170+
"rnn_size": self.rnn_size,
171+
"rnn_type": self.rnn_type,
172+
"rnn_dropout": self.rnn_dropout,
173+
"dense_size": self.dense_size,
174+
"return_state": self.return_state,
175+
}
176+
)
177+
return config
178+
179+
def compute_output_shape(self, input_shape):
180+
batch_size, seq_len, _ = input_shape
181+
rnn_output_shape = (batch_size, seq_len, self.rnn_size)
182+
183+
# State shape depends on RNN type
184+
if self.rnn_type == "gru":
185+
state_shape = (batch_size, self.dense_size)
186+
elif self.rnn_type == "lstm":
187+
state_shape = ((batch_size, self.dense_size), (batch_size, self.dense_size))
188+
else:
189+
raise ValueError(f"No supported rnn type of {self.rnn_type}")
190+
return rnn_output_shape, state_shape
191+
154192

155193
class DecoderV1(tf.keras.layers.Layer):
156194
def __init__(
@@ -256,6 +294,29 @@ def call(
256294
decoder_outputs = tf.concat(decoder_outputs, axis=-1)
257295
return tf.expand_dims(decoder_outputs, -1)
258296

297+
def get_config(self):
298+
config = super().get_config()
299+
config.update(
300+
{
301+
"rnn_size": self.rnn_size,
302+
"rnn_type": self.rnn_type,
303+
"predict_sequence_length": self.predict_sequence_length,
304+
"use_attention": self.use_attention,
305+
"attention_size": self.attention_size,
306+
"num_attention_heads": self.num_attention_heads,
307+
"attention_probs_dropout_prob": self.attention_probs_dropout_prob,
308+
}
309+
)
310+
return config
311+
312+
def compute_output_shape(self, input_shape):
313+
decoder_init_input_shape = input_shape[1]
314+
if isinstance(decoder_init_input_shape, (list, tuple)):
315+
batch_size = decoder_init_input_shape[0]
316+
else:
317+
batch_size = None
318+
return (batch_size, self.predict_sequence_length, 1)
319+
259320

260321
class DecoderV2(tf.keras.layers.Layer):
261322
def __init__(

tfts/models/transformer.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,22 @@ def call(self, inputs: tf.Tensor, mask: Optional[tf.Tensor] = None):
218218
return x
219219

220220
def get_config(self):
221-
config = {
222-
"num_hidden_layers": self.num_hidden_layers,
223-
"hidden_size": self.hidden_size,
224-
"num_attention_heads": self.num_attention_heads,
225-
"attention_probs_dropout_prob": self.attention_probs_dropout_prob,
226-
"ffn_intermediate_size": self.ffn_intermediate_size,
227-
"hidden_dropout_prob": self.hidden_dropout_prob,
228-
}
229-
base_config = super(Encoder, self).get_config()
230-
return dict(list(base_config.items()) + list(config.items()))
221+
config = super().get_config()
222+
config.update(
223+
{
224+
"num_hidden_layers": self.num_hidden_layers,
225+
"hidden_size": self.hidden_size,
226+
"num_attention_heads": self.num_attention_heads,
227+
"attention_probs_dropout_prob": self.attention_probs_dropout_prob,
228+
"ffn_intermediate_size": self.ffn_intermediate_size,
229+
"hidden_dropout_prob": self.hidden_dropout_prob,
230+
"layer_norm_eps": self.layer_norm_eps,
231+
}
232+
)
233+
return config
234+
235+
def compute_output_shape(self, input_shape):
236+
return input_shape
231237

232238

233239
class Decoder(tf.keras.layers.Layer):
@@ -332,6 +338,24 @@ def get_causal_attention_mask(self, sequence_length: int) -> tf.Tensor:
332338
mask = tf.cast(i >= j, dtype="int32")
333339
return tf.reshape(mask, (1, sequence_length, sequence_length))
334340

341+
def get_config(self):
342+
config = super().get_config()
343+
config.update(
344+
{
345+
"num_decoder_layers": self.num_decoder_layers,
346+
"hidden_size": self.hidden_size,
347+
"num_attention_heads": self.num_attention_heads,
348+
"attention_probs_dropout_prob": self.attention_probs_dropout_prob,
349+
"ffn_intermediate_size": self.ffn_intermediate_size,
350+
"hidden_dropout_prob": self.hidden_dropout_prob,
351+
"layer_norm_eps": self.layer_norm_eps,
352+
}
353+
)
354+
return config
355+
356+
def compute_output_shape(self, input_shape):
357+
return input_shape
358+
335359

336360
class DecoderLayer(tf.keras.layers.Layer):
337361
def __init__(
@@ -399,6 +423,9 @@ def get_config(self):
399423
base_config = super(DecoderLayer, self).get_config()
400424
return dict(list(base_config.items()) + list(config.items()))
401425

426+
def compute_output_shape(self, input_shape):
427+
return input_shape
428+
402429

403430
class TransformerBlock(tf.keras.layers.Layer):
404431
"""Basic Transformer block with attention and feed-forward layers."""

0 commit comments

Comments
 (0)