Skip to content

Commit d02be83

Browse files
rundiwuzsdonghao
authored and committed
Fix layer node bugs for list of outputs (STN case) and upgrade model weights property (#956) (#1010)
* Update examples database to tl2, unfinished
* fix bugs for layer_node, particularly in the usage of STN
* fix model weights property, return .copy() now
* yapf fix for changes
* update CHANGELOG, yapf fix
1 parent d79130e commit d02be83

File tree

9 files changed

+90
-55
lines changed

9 files changed

+90
-55
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,14 @@ To release a new version, please update the changelog as followed:
7575

7676
- `SpatialTransform2dAffine` auto `in_channels`
7777
- support TensorFlow 2.0.0-beta1
78+
- Update model weights properties; they now return a copy (PR #1010)
7879

7980
### Dependencies Update
8081

8182
### Deprecated
8283

8384
### Fixed
85+
- Fix `tl.models.Model._construct_graph` for list of outputs, e.g. STN case (PR #1010)
8486

8587
### Removed
8688

@@ -89,6 +91,7 @@ To release a new version, please update the changelog as followed:
8991
### Contributors
9092

9193
- @zsdonghao
94+
- @ChrisWu1997: #1010
9295

9396
## [2.1.0]
9497

examples/database/dispatch_tasks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,5 @@
4646

4747
# get the best model
4848
print("all tasks finished")
49-
sess = tf.InteractiveSession()
50-
net = db.find_top_model(sess=sess, model_name='mlp', sort=[("test_accuracy", -1)])
49+
net = db.find_top_model(model_name='mlp', sort=[("test_accuracy", -1)])
5150
print("the best accuracy {} is from model {}".format(net._test_accuracy, net._name))

examples/database/task_script.py

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,57 @@
33
import tensorflow as tf
44
import tensorlayer as tl
55

6-
tf.logging.set_verbosity(tf.logging.DEBUG)
6+
# tf.logging.set_verbosity(tf.logging.DEBUG)
77
tl.logging.set_verbosity(tl.logging.DEBUG)
88

9-
sess = tf.InteractiveSession()
10-
119
# connect to database
1210
db = tl.db.TensorHub(ip='localhost', port=27017, dbname='temp', project_name='tutorial')
1311

1412
# load dataset from database
1513
X_train, y_train, X_val, y_val, X_test, y_test = db.find_top_dataset('mnist')
1614

17-
# define placeholder
18-
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
19-
y_ = tf.placeholder(tf.int64, shape=[None], name='y_')
20-
21-
2215
# define the network
23-
def mlp(x, is_train=True, reuse=False):
24-
with tf.variable_scope("MLP", reuse=reuse):
25-
net = tl.layers.InputLayer(x, name='input')
26-
net = tl.layers.DropoutLayer(net, keep=0.8, is_fix=True, is_train=is_train, name='drop1')
27-
net = tl.layers.DenseLayer(net, n_units=n_units1, act=tf.nn.relu, name='relu1')
28-
net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
29-
net = tl.layers.DenseLayer(net, n_units=n_units2, act=tf.nn.relu, name='relu2')
30-
net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop3')
31-
net = tl.layers.DenseLayer(net, n_units=10, act=None, name='output')
32-
return net
33-
34-
35-
# define inferences
36-
net_train = mlp(x, is_train=True, reuse=False)
37-
net_test = mlp(x, is_train=False, reuse=True)
38-
39-
# cost for training
40-
y = net_train.outputs
41-
cost = tl.cost.cross_entropy(y, y_, name='xentropy')
42-
correct_prediction = tf.equal(tf.argmax(y, 1), y_)
43-
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
44-
45-
# cost and accuracy for evalution
46-
y2 = net_test.outputs
47-
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
48-
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
49-
acc_test = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
16+
def mlp():
17+
ni = tl.layers.Input([None, 784], name='input')
18+
net = tl.layers.Dropout(keep=0.8, name='drop1')(ni)
19+
net = tl.layers.Dense(n_units=n_units1, act=tf.nn.relu, name='relu1')(net)
20+
net = tl.layers.Dropout(keep=0.5, name='drop2')(net)
21+
net = tl.layers.Dense(n_units=n_units2, act=tf.nn.relu, name='relu2')(net)
22+
net = tl.layers.Dropout(keep=0.5, name='drop3')(net)
23+
net = tl.layers.Dense(n_units=10, act=None, name='output')(net)
24+
M = tl.models.Model(inputs=ni, outputs=net)
25+
return M
26+
27+
network = mlp()
28+
29+
# cost and accuracy
30+
cost = tl.cost.cross_entropy
31+
32+
def acc(y, y_):
33+
correct_prediction = tf.equal(tf.argmax(y, 1), tf.convert_to_tensor(y_, tf.int64))
34+
return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
5035

5136
# define the optimizer
52-
train_params = tl.layers.get_variables_with_name('MLP', True, False)
53-
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
54-
55-
# initialize all variables in the session
56-
sess.run(tf.global_variables_initializer())
37+
train_op = tf.optimizers.Adam(learning_rate=0.0001)
5738

5839
# train the network
40+
# tl.utils.fit(
41+
# network, train_op, cost, X_train, y_train, acc=acc, batch_size=500, n_epoch=20, print_freq=5,
42+
# X_val=X_val, y_val=y_val, eval_train=False
43+
# )
44+
5945
tl.utils.fit(
60-
sess, net_train, train_op, cost, X_train, y_train, x, y_, acc=acc, batch_size=500, n_epoch=1, print_freq=5,
61-
X_val=X_val, y_val=y_val, eval_train=False
46+
network, train_op=tf.optimizers.Adam(learning_rate=0.0001), cost=tl.cost.cross_entropy, X_train=X_train,
47+
y_train=y_train, acc=acc, batch_size=256, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=False,
6248
)
6349

6450
# evaluation and save result that match the result_key
65-
test_accuracy = tl.utils.test(sess, net_test, acc_test, X_test, y_test, x, y_, batch_size=None, cost=cost_test)
51+
test_accuracy = tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=cost)
6652
test_accuracy = float(test_accuracy)
6753

6854
# save model into database
69-
db.save_model(net_train, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
55+
db.save_model(network, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
7056
# in another script, you can load the model as follows
7157
# net = db.find_model(sess=sess, model_name=str(n_units1)+'-'+str(n_units2)
58+
59+
tf.python.keras.layers.BatchNormalization

examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def get_model(inputs_shape):
6666

6767
## 2. Spatial transformer module (sampler)
6868
stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
69-
s = stn((nn, ni))
7069
nn = stn((nn, ni))
70+
s = nn
7171

7272
## 3. Classifier
7373
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)

tensorlayer/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def run_top_task(self, task_name=None, sort=None, **kwargs):
641641
logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
642642
_script = _script.decode('utf-8')
643643
with tf.Graph().as_default(): # # as graph: # clear all TF graphs
644-
exec (_script, globals())
644+
exec(_script, globals())
645645

646646
# set status to finished
647647
_ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

tensorlayer/layers/spatial_transformer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,6 @@ def __repr__(self):
257257
return s.format(classname=self.__class__.__name__, **self.__dict__)
258258

259259
def build(self, inputs_shape):
260-
print("inputs_shape ", inputs_shape)
261260
if self.in_channels is None and len(inputs_shape) != 2:
262261
raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it")
263262
if self.in_channels:
@@ -267,7 +266,6 @@ def build(self, inputs_shape):
267266
# shape = [inputs_shape[1], 6]
268267
self.in_channels = inputs_shape[0][-1] # zsdonghao
269268
shape = [self.in_channels, 6]
270-
print("shape", shape)
271269
self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros())
272270
identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), newshape=(6, ))
273271
self.b = self._get_weights("biases", shape=(6, ), init=tl.initializers.Constant(identity))
@@ -282,7 +280,6 @@ def forward(self, inputs):
282280
n_channels is identical to that of U.
283281
"""
284282
theta_input, U = inputs
285-
print("inputs", inputs)
286283
theta = tf.nn.tanh(tf.matmul(theta_input, self.W) + self.b)
287284
outputs = transformer(U, theta, out_size=self.out_size)
288285
# automatically set batch_size and channels

tensorlayer/models/core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def trainable_weights(self):
401401
if layer.trainable_weights is not None:
402402
self._trainable_weights.extend(layer.trainable_weights)
403403

404-
return self._trainable_weights
404+
return self._trainable_weights.copy()
405405

406406
@property
407407
def nontrainable_weights(self):
@@ -415,7 +415,7 @@ def nontrainable_weights(self):
415415
if layer.nontrainable_weights is not None:
416416
self._nontrainable_weights.extend(layer.nontrainable_weights)
417417

418-
return self._nontrainable_weights
418+
return self._nontrainable_weights.copy()
419419

420420
@property
421421
def all_weights(self):
@@ -429,7 +429,7 @@ def all_weights(self):
429429
if layer.all_weights is not None:
430430
self._all_weights.extend(layer.all_weights)
431431

432-
return self._all_weights
432+
return self._all_weights.copy()
433433

434434
@property
435435
def config(self):
@@ -669,6 +669,8 @@ def _construct_graph(self):
669669

670670
visited_node_names = set()
671671
for out_node in output_nodes:
672+
if out_node.visited:
673+
continue
672674
queue_node.put(out_node)
673675

674676
while not queue_node.empty():

tests/layers/test_layernode.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
3-
43
import os
54
import unittest
65

@@ -193,6 +192,44 @@ def MyModel():
193192
self.assertEqual(net.all_layers[1].model._nodes_fixed, True)
194193
self.assertEqual(net.all_layers[1].model.all_layers[0]._nodes_fixed, True)
195194

195+
def test_STN(self):
196+
print('-' * 20, 'test STN', '-' * 20)
197+
198+
def get_model(inputs_shape):
199+
ni = Input(inputs_shape)
200+
201+
## 1. Localisation network
202+
# use MLP as the localisation net
203+
nn = Flatten()(ni)
204+
nn = Dense(n_units=20, act=tf.nn.tanh)(nn)
205+
nn = Dropout(keep=0.8)(nn)
206+
# you can also use a CNN instead of an MLP as the localisation net
207+
208+
## 2. Spatial transformer module (sampler)
209+
stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
210+
# s = stn((nn, ni))
211+
nn = stn((nn, ni))
212+
s = nn
213+
214+
## 3. Classifier
215+
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
216+
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
217+
nn = Flatten()(nn)
218+
nn = Dense(n_units=1024, act=tf.nn.relu)(nn)
219+
nn = Dense(n_units=10, act=tf.identity)(nn)
220+
221+
M = Model(inputs=ni, outputs=[nn, s])
222+
return M
223+
224+
net = get_model([None, 40, 40, 1])
225+
226+
inputs = np.random.randn(2, 40, 40, 1).astype(np.float32)
227+
o1, o2 = net(inputs, is_train=True)
228+
self.assertEqual(o1.shape, (2, 10))
229+
self.assertEqual(o2.shape, (2, 40, 40, 1))
230+
231+
self.assertEqual(len(net._node_by_depth), 10)
232+
196233

197234
if __name__ == '__main__':
198235

tests/models/test_model_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,15 @@ def test_get_layer(self):
392392
except Exception as e:
393393
print(e)
394394

395+
def test_model_weights_copy(self):
396+
print('-' * 20, 'test_model_weights_copy', '-' * 20)
397+
model_basic = basic_static_model()
398+
model_weights = model_basic.trainable_weights
399+
ori_len = len(model_weights)
400+
model_weights.append(np.arange(5))
401+
new_len = len(model_weights)
402+
self.assertEqual(new_len - 1, ori_len)
403+
395404

396405
if __name__ == '__main__':
397406

0 commit comments

Comments
 (0)