diff --git a/CHANGELOG.md b/CHANGELOG.md index c867dd7fe..19ac5ab5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,7 @@ To release a new version, please update the changelog as followed: - Support nested layer customization (#PR 1015) - Support string dtype in InputLayer (#PR 1017) - Support Dynamic RNN in RNN (#PR 1023) +- Add ResNet50 static model (#PR 1030) ### Changed @@ -120,7 +121,7 @@ To release a new version, please update the changelog as followed: ### Contributors - @zsdonghao -- @ChrisWu1997: #1010 #1015 #1025 +- @ChrisWu1997: #1010 #1015 #1025 #1030 - @warshallrho: #1017 #1021 #1026 #1029 - @ArnoldLIULJ: #1023 - @JingqingZ: #1023 diff --git a/docs/modules/models.rst b/docs/modules/models.rst index 46b8d7e1b..272f1d9c6 100644 --- a/docs/modules/models.rst +++ b/docs/modules/models.rst @@ -13,6 +13,7 @@ TensorLayer provides many pretrained models, you can easily use the whole or a p VGG19 SqueezeNetV1 MobileNetV1 + ResNet50 Seq2seq Seq2seqLuongAttention @@ -41,6 +42,11 @@ MobileNetV1 .. autofunction:: MobileNetV1 +ResNet50 +---------------- + +.. autofunction:: ResNet50 + Seq2seq ------------------------ diff --git a/examples/pretrained_cnn/tutorial_models_resnet50.py b/examples/pretrained_cnn/tutorial_models_resnet50.py new file mode 100644 index 000000000..b5055cee3 --- /dev/null +++ b/examples/pretrained_cnn/tutorial_models_resnet50.py @@ -0,0 +1,34 @@ +#! 
# -*- coding: utf-8 -*-
"""
ResNet50 for ImageNet using TL models

Loads the pretrained ResNet50, classifies a single image and prints the
five most probable ImageNet classes together with the elapsed time.
"""

import time

import numpy as np

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.models.imagenet_classes import class_names

# tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

# get the whole model
resnet = tl.models.ResNet50(pretrained=True)

# Pre-process the image the way the pretrained weights expect it:
# resize to 224x224, reverse the channel axis (RGB -> BGR, assuming
# ``read_image`` returns RGB -- TODO confirm) and subtract the per-channel
# ImageNet mean.
img1 = tl.vis.read_image('data/tiger.jpeg')
img1 = tl.prepro.imresize(img1, (224, 224))[:, :, ::-1]
img1 = img1 - np.array([103.939, 116.779, 123.68]).reshape((1, 1, 3))

# Add the leading batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
img1 = img1.astype(np.float32)[np.newaxis, ...]

start_time = time.time()
output = resnet(img1, is_train=False)
prob = tf.nn.softmax(output)[0].numpy()
# ``%.5f`` formats the elapsed seconds as a float. The previous ``%.5s``
# truncated ``str(elapsed)`` to five characters, which misreports values
# rendered in scientific notation (e.g. ``1.2345e-05`` became ``1.234``).
print(" End time : %.5fs" % (time.time() - start_time))

# Indices of the five highest probabilities, most probable first.
preds = (np.argsort(prob)[::-1])[0:5]
for p in preds:
    print(class_names[p], prob[p])
# Reference:
# - [Deep Residual Learning for Image Recognition](
#   https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)

import os

import tensorflow as tf
from tensorlayer import logging
from tensorlayer.files import (assign_weights, load_npz, maybe_download_and_extract)
from tensorlayer.layers import (BatchNorm, Conv2d, Elementwise, GlobalMeanPool2d, MaxPool2d, Input, Dense)
from tensorlayer.models import Model

__all__ = [
    'ResNet50',
]


def identity_block(input, kernel_size, n_filters, stage, block):
    """The identity block where there is no conv layer at shortcut.

    The main path is conv(1x1) -> bn/relu -> conv(kernel_size) -> bn/relu ->
    conv(1x1) -> bn; its result is added to the unmodified ``input`` and passed
    through relu.

    Parameters
    ----------
    input : tf tensor
        Input tensor from above layer.
    kernel_size : int
        The kernel size of middle conv layer at main path.
    n_filters : list of integers
        The numbers of filters for 3 conv layer at main path.
    stage : int
        Current stage label.
    block : str
        Current block label.

    Returns
    -------
        Output tensor of this block.

    """
    filters1, filters2, filters3 = n_filters
    # Layer names follow the 'res<stage><block>_branch<id>' / 'bn<stage><block>_branch<id>'
    # scheme; ``restore_params`` looks weights up in the .h5 file by these names.
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # A fresh initializer object is created for every conv layer, as in the
    # original code, so no initializer instance is shared between layers.
    x = Conv2d(filters1, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input)
    x = BatchNorm(name=bn_name_base + '2a', act='relu')(x)

    ks = (kernel_size, kernel_size)
    x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x)
    x = BatchNorm(name=bn_name_base + '2b', act='relu')(x)

    x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x)
    x = BatchNorm(name=bn_name_base + '2c')(x)

    # Residual connection: main path + untouched input, then relu.
    x = Elementwise(tf.add, act='relu')([x, input])
    return x


def conv_block(input, kernel_size, n_filters, stage, block, strides=(2, 2)):
    """The conv block where there is a conv layer at shortcut.

    Same main path as ``identity_block``, except that the first conv uses
    ``strides``. The shortcut is a strided 1x1 conv followed by batch norm,
    producing ``n_filters[2]`` channels so it can be added to the main path.

    Parameters
    ----------
    input : tf tensor
        Input tensor from above layer.
    kernel_size : int
        The kernel size of middle conv layer at main path.
    n_filters : list of integers
        The numbers of filters for 3 conv layer at main path.
    stage : int
        Current stage label.
    block : str
        Current block label.
    strides : tuple
        Strides for the first conv layer in the block (also used by the shortcut conv).

    Returns
    -------
        Output tensor of this block.

    """
    filters1, filters2, filters3 = n_filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2d(filters1, (1, 1), strides=strides, W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input)
    x = BatchNorm(name=bn_name_base + '2a', act='relu')(x)

    ks = (kernel_size, kernel_size)
    x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x)
    x = BatchNorm(name=bn_name_base + '2b', act='relu')(x)

    x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x)
    x = BatchNorm(name=bn_name_base + '2c')(x)

    # Projection shortcut ('_branch1'): uses the same strides as the first
    # main-path conv and the same channel count as the last one.
    shortcut = Conv2d(filters3, (1, 1), strides=strides, W_init=tf.initializers.he_normal(),
                      name=conv_name_base + '1')(input)
    shortcut = BatchNorm(name=bn_name_base + '1')(shortcut)

    x = Elementwise(tf.add, act='relu')([x, shortcut])
    return x


# Ordered end points of the network. Two-character entries are
# '<stage><block letter>' residual blocks; the last two are the pooling and
# classification layers.
block_names = ['2a', '2b', '2c', '3a', '3b', '3c', '3d', '4a', '4b', '4c', '4d', '4e', '4f', '5a', '5b', '5c'
              ] + ['avg_pool', 'fc1000']
# Filter counts of the three main-path convs for stages 2..5 (indexed by ``stage - 2``).
block_filters = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]


def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None):
    """Pre-trained ResNet50 model (static mode). Input shape [?, 224, 224, 3].
    To use pretrained model, input should be in BGR format with the ImageNet mean
    [103.939, 116.779, 123.68] subtracted.

    Parameters
    ----------
    pretrained : boolean
        Whether to load pretrained weights. Default False.
    end_with : str
        The end point of the model [2a, 2b, 2c, 3a, 3b, 3c, 3d, 4a, 4b, 4c, 4d, 4e, 4f,
        5a, 5b, 5c, avg_pool, fc1000]. Default ``fc1000`` i.e. the whole model.
        A value that matches none of these also builds the whole model.
    n_classes : int
        Number of classes in final prediction.
    name : None or str
        Name for this model.

    Examples
    ---------
    Classify ImageNet classes, see `tutorial_models_resnet50.py`

    >>> # get the whole model with pretrained weights
    >>> resnet = tl.models.ResNet50(pretrained=True)
    >>> # use for inferencing
    >>> output = resnet(img1, is_train=False)
    >>> prob = tf.nn.softmax(output)[0].numpy()

    Extract the features before fc layer

    >>> resnet = tl.models.ResNet50(pretrained=True, end_with='5c')
    >>> output = resnet(img1, is_train=False)

    Returns
    -------
        ResNet50 model.

    """
    ni = Input([None, 224, 224, 3], name="input")
    n = Conv2d(64, (7, 7), strides=(2, 2), padding='SAME', W_init=tf.initializers.he_normal(), name='conv1')(ni)
    n = BatchNorm(name='bn_conv1', act='relu')(n)
    n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n)

    # The loop variable is deliberately NOT called ``name``: reusing that
    # identifier would overwrite the ``name`` argument and the model would be
    # named after the last block built instead of what the caller asked for.
    for block_name in block_names:
        if len(block_name) == 2:
            stage = int(block_name[0])
            block = block_name[1]
            filters = block_filters[stage - 2]
            if block == 'a':
                # The first block of every stage changes the channel count and
                # therefore needs a projection shortcut; stage 2 keeps the
                # spatial size, later stages halve it.
                strides = (1, 1) if stage == 2 else (2, 2)
                n = conv_block(n, 3, filters, stage=stage, block=block, strides=strides)
            else:
                n = identity_block(n, 3, filters, stage=stage, block=block)
        elif block_name == 'avg_pool':
            n = GlobalMeanPool2d(name='avg_pool')(n)
        elif block_name == 'fc1000':
            n = Dense(n_classes, name='fc1000')(n)

        if block_name == end_with:
            break

    network = Model(inputs=ni, outputs=n, name=name)

    if pretrained:
        restore_params(network)

    return network


def restore_params(network, path='models'):
    """Download the Keras ResNet50 weight file if needed and load it into ``network``.

    Every layer of ``network`` that owns weights is looked up in the .h5 file by
    its layer name, so only the layers actually present in a truncated model are
    restored.

    Parameters
    ----------
    network : Model
        Model built by ``ResNet50`` whose layers will receive the weights.
    path : str
        Directory the weight file is downloaded to and read from.

    Raises
    ------
    ImportError
        If ``h5py`` cannot be imported.

    """
    logging.info("Restore pre-trained parameters")
    weights_file = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
    maybe_download_and_extract(
        weights_file,
        path,
        'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
    )
    try:
        import h5py
    except Exception as err:
        # Kept broad on purpose so callers keep seeing ImportError for any
        # failure while importing h5py; the original cause is chained.
        raise ImportError('h5py not imported') from err

    # The context manager closes the file even if a layer lookup or a weight
    # assignment raises.
    with h5py.File(os.path.join(path, weights_file), 'r') as f:
        for layer in network.all_layers:
            if len(layer.all_weights) == 0:
                continue
            group = f[layer.name]
            # NOTE(review): relies on the dataset order inside the group
            # matching the order of ``layer.all_weights`` -- confirm.
            params = [group[w_name][:] for w_name in group]
            if 'bn' in layer.name:
                # Batch-norm parameters are stored flat; reshape them to
                # (1, 1, 1, C) before assigning.
                params = [x.reshape(1, 1, 1, -1) for x in params]
            assign_weights(params, layer)