Skip to content

Resnet50 #1030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ To release a new version, please update the changelog as followed:
- Support nested layer customization (#PR 1015)
- Support string dtype in InputLayer (#PR 1017)
- Support Dynamic RNN in RNN (#PR 1023)
- Add ResNet50 static model (#PR 1030)

### Changed

Expand Down Expand Up @@ -120,7 +121,7 @@ To release a new version, please update the changelog as followed:
### Contributors

- @zsdonghao
- @ChrisWu1997: #1010 #1015 #1025
- @ChrisWu1997: #1010 #1015 #1025 #1030
- @warshallrho: #1017 #1021 #1026 #1029
- @ArnoldLIULJ: #1023
- @JingqingZ: #1023
Expand Down
6 changes: 6 additions & 0 deletions docs/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ TensorLayer provides many pretrained models, you can easily use the whole or a p
VGG19
SqueezeNetV1
MobileNetV1
ResNet50
Seq2seq
Seq2seqLuongAttention

Expand Down Expand Up @@ -41,6 +42,11 @@ MobileNetV1

.. autofunction:: MobileNetV1

ResNet50
----------------

.. autofunction:: ResNet50

Seq2seq
------------------------

Expand Down
34 changes: 34 additions & 0 deletions examples/pretrained_cnn/tutorial_models_resnet50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""
ResNet50 for ImageNet using TL models

"""

import time

import numpy as np

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.models.imagenet_classes import class_names

# tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

# get the whole model
resnet = tl.models.ResNet50(pretrained=True)

img1 = tl.vis.read_image('data/tiger.jpeg')
img1 = tl.prepro.imresize(img1, (224, 224))[:, :, ::-1]
img1 = img1 - np.array([103.939, 116.779, 123.68]).reshape((1, 1, 3))

img1 = img1.astype(np.float32)[np.newaxis, ...]

start_time = time.time()
output = resnet(img1, is_train=False)
prob = tf.nn.softmax(output)[0].numpy()
print(" End time : %.5ss" % (time.time() - start_time))
preds = (np.argsort(prob)[::-1])[0:5]
for p in preds:
print(class_names[p], prob[p])
1 change: 1 addition & 0 deletions tensorlayer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# """A collections of pre-defined well known models."""

from .core import *
from .resnet import ResNet50
from .mobilenetv1 import MobileNetV1
from .squeezenetv1 import SqueezeNetV1
from .vgg import *
Expand Down
202 changes: 202 additions & 0 deletions tensorlayer/models/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""ResNet for ImageNet.

# Reference:
- [Deep Residual Learning for Image Recognition](
https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)

"""

import os

import tensorflow as tf
from tensorlayer import logging
from tensorlayer.files import (assign_weights, load_npz, maybe_download_and_extract)
from tensorlayer.layers import (BatchNorm, Conv2d, Elementwise, GlobalMeanPool2d, MaxPool2d, Input, Dense)
from tensorlayer.models import Model

__all__ = [
'ResNet50',
]


def identity_block(input, kernel_size, n_filters, stage, block):
"""The identity block where there is no conv layer at shortcut.

Parameters
----------
input : tf tensor
Input tensor from above layer.
kernel_size : int
The kernel size of middle conv layer at main path.
n_filters : list of integers
The numbers of filters for 3 conv layer at main path.
stage : int
Current stage label.
block : str
Current block label.

Returns
-------
Output tensor of this block.

"""
filters1, filters2, filters3 = n_filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'

x = Conv2d(filters1, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input)
x = BatchNorm(name=bn_name_base + '2a', act='relu')(x)

ks = (kernel_size, kernel_size)
x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x)
x = BatchNorm(name=bn_name_base + '2b', act='relu')(x)

x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x)
x = BatchNorm(name=bn_name_base + '2c')(x)

x = Elementwise(tf.add, act='relu')([x, input])
return x


def conv_block(input, kernel_size, n_filters, stage, block, strides=(2, 2)):
"""The conv block where there is a conv layer at shortcut.

Parameters
----------
input : tf tensor
Input tensor from above layer.
kernel_size : int
The kernel size of middle conv layer at main path.
n_filters : list of integers
The numbers of filters for 3 conv layer at main path.
stage : int
Current stage label.
block : str
Current block label.
strides : tuple
Strides for the first conv layer in the block.

Returns
-------
Output tensor of this block.

"""
filters1, filters2, filters3 = n_filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'

x = Conv2d(filters1, (1, 1), strides=strides, W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input)
x = BatchNorm(name=bn_name_base + '2a', act='relu')(x)

ks = (kernel_size, kernel_size)
x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x)
x = BatchNorm(name=bn_name_base + '2b', act='relu')(x)

x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x)
x = BatchNorm(name=bn_name_base + '2c')(x)

shortcut = Conv2d(filters3, (1, 1), strides=strides, W_init=tf.initializers.he_normal(),
name=conv_name_base + '1')(input)
shortcut = BatchNorm(name=bn_name_base + '1')(shortcut)

x = Elementwise(tf.add, act='relu')([x, shortcut])
return x


block_names = ['2a', '2b', '2c', '3a', '3b', '3c', '3d', '4a', '4b', '4c', '4d', '4e', '4f', '5a', '5b', '5c'
] + ['avg_pool', 'fc1000']
block_filters = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]


def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None):
"""Pre-trained MobileNetV1 model (static mode). Input shape [?, 224, 224, 3].
To use pretrained model, input should be in BGR format and subtracted from ImageNet mean [103.939, 116.779, 123.68].

Parameters
----------
pretrained : boolean
Whether to load pretrained weights. Default False.
end_with : str
The end point of the model [conv, depth1, depth2 ... depth13, globalmeanpool, out].
Default ``out`` i.e. the whole model.
n_classes : int
Number of classes in final prediction.
name : None or str
Name for this model.

Examples
---------
Classify ImageNet classes, see `tutorial_models_resnet50.py`

>>> # get the whole model with pretrained weights
>>> resnet = tl.models.ResNet50(pretrained=True)
>>> # use for inferencing
>>> output = resnet(img1, is_train=False)
>>> prob = tf.nn.softmax(output)[0].numpy()

Extract the features before fc layer
>>> resnet = tl.models.ResNet50(pretrained=True, end_with='5c')
>>> output = resnet(img1, is_train=False)

Returns
-------
ResNet50 model.

"""
ni = Input([None, 224, 224, 3], name="input")
n = Conv2d(64, (7, 7), strides=(2, 2), padding='SAME', W_init=tf.initializers.he_normal(), name='conv1')(ni)
n = BatchNorm(name='bn_conv1', act='relu')(n)
n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n)

for i, name in enumerate(block_names):
if len(name) == 2:
stage = int(name[0])
block = name[1]
if block == 'a':
strides = (1, 1) if stage == 2 else (2, 2)
n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides)
else:
n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block)
elif name == 'avg_pool':
n = GlobalMeanPool2d(name='avg_pool')(n)
elif name == 'fc1000':
n = Dense(n_classes, name='fc1000')(n)

if name == end_with:
break

network = Model(inputs=ni, outputs=n, name=name)

if pretrained:
restore_params(network)

return network


def restore_params(network, path='models'):
logging.info("Restore pre-trained parameters")
maybe_download_and_extract(
'resnet50_weights_tf_dim_ordering_tf_kernels.h5',
path,
'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
) # ls -al
try:
import h5py
except Exception:
raise ImportError('h5py not imported')

f = h5py.File(os.path.join(path, 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'), 'r')

for layer in network.all_layers:
if len(layer.all_weights) == 0:
continue
w_names = list(f[layer.name])
params = [f[layer.name][n][:] for n in w_names]
if 'bn' in layer.name:
params = [x.reshape(1, 1, 1, -1) for x in params]
assign_weights(params, layer)
del params

f.close()