
Commit 11797c6

Groupconv2d (#366)
* [release] 1.8.0rc
* release Grouped Conv #363
* fixed codacy error
* fixed LG's suggestions
* fixed LG's suggestions
1 parent ae7cf2b commit 11797c6

2 files changed (+87 −3)

2 files changed

+87
-3
lines changed

docs/modules/layers.rst

Lines changed: 6 additions & 3 deletions
@@ -270,6 +270,9 @@ Layer list
 Conv2d
 DeConv2d
 DeConv3d
+DepthwiseConv2d
+DeformableConv2d
+GroupConv2d
 
 MaxPool1d
 MeanPool1d
@@ -278,9 +281,6 @@ Layer list
 MaxPool3d
 MeanPool3d
 
-DepthwiseConv2d
-DeformableConv2d
-
 SubpixelConv1d
 SubpixelConv2d
 
@@ -496,6 +496,9 @@ APIs may better for you.
 ^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: DeformableConv2d
 
+2D Grouped Conv
+^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: GroupConv2d
 
 Super-Resolution layer
 ------------------------

tensorlayer/layers/convolution.py

Lines changed: 81 additions & 0 deletions
@@ -1647,6 +1647,87 @@ def __init__(
         self.all_params.extend([W])
 
 
+class GroupConv2d(Layer):
+    """The :class:`GroupConv2d` class is 2D grouped convolution, see `here <https://blog.yani.io/filter-group-tutorial/>`__.
+
+    Parameters
+    --------------
+    layer : :class:`Layer`
+        Previous layer.
+    n_filter : int
+        The number of filters.
+    filter_size : tuple of int
+        The filter size.
+    strides : tuple of int
+        The stride step.
+    n_group : int
+        The number of groups.
+    act : activation function
+        The activation function of this layer.
+    padding : str
+        The padding algorithm type: "SAME" or "VALID".
+    W_init : initializer
+        The initializer for the weight matrix.
+    b_init : initializer or None
+        The initializer for the bias vector. If None, skip biases.
+    W_init_args : dictionary
+        The arguments for the weight matrix initializer.
+    b_init_args : dictionary
+        The arguments for the bias vector initializer.
+    name : str
+        A unique layer name.
+    """
+
+    def __init__(
+            self,
+            layer=None,
+            n_filter=32,
+            filter_size=(3, 3),
+            strides=(2, 2),
+            n_group=2,
+            act=tf.identity,
+            padding='SAME',
+            W_init=tf.truncated_normal_initializer(stddev=0.02),
+            b_init=tf.constant_initializer(value=0.0),
+            W_init_args=None,
+            b_init_args=None,
+            name='groupconv',
+    ):  # Windaway
+        if W_init_args is None:
+            W_init_args = {}
+        if b_init_args is None:
+            b_init_args = {}
+
+        Layer.__init__(self, name=name)
+        self.inputs = layer.outputs
+        groupConv = lambda i, k: tf.nn.conv2d(i, k, strides=[1, strides[0], strides[1], 1], padding=padding)
+        channels = int(self.inputs.get_shape()[-1])
+        with tf.variable_scope(name):
+            We = tf.get_variable(
+                name='W', shape=[filter_size[0], filter_size[1], channels // n_group, n_filter], initializer=W_init, dtype=D_TYPE, trainable=True, **W_init_args)
+            if b_init:
+                bi = tf.get_variable(name='b', shape=n_filter, initializer=b_init, dtype=D_TYPE, trainable=True, **b_init_args)
+            if n_group == 1:
+                conv = groupConv(self.inputs, We)
+            else:
+                inputGroups = tf.split(axis=3, num_or_size_splits=n_group, value=self.inputs)
+                weightsGroups = tf.split(axis=3, num_or_size_splits=n_group, value=We)
+                convGroups = [groupConv(i, k) for i, k in zip(inputGroups, weightsGroups)]
+                conv = tf.concat(axis=3, values=convGroups)
+            if b_init:
+                conv = tf.add(conv, bi, name='add')
+
+        self.outputs = act(conv)
+        self.all_layers = list(layer.all_layers)
+        self.all_params = list(layer.all_params)
+        self.all_drop = dict(layer.all_drop)
+        self.all_layers.append(self.outputs)
+        if b_init:
+            self.all_params.extend([We, bi])
+        else:
+            self.all_params.append(We)
+
+
 # Alias
 AtrousConv1dLayer = atrous_conv1d
 Conv1d = conv1d
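
Usage note: below is a minimal sketch of how the new layer can be called, against the TF1-era TensorLayer API added in this commit; the placeholder shape, layer names, and hyperparameters are illustrative, not part of the commit. Because the implementation splits both the input and the weight tensor along the channel axis with tf.split, the input channel count and n_filter must each be divisible by n_group.

import tensorflow as tf
import tensorlayer as tl

# Illustrative input: 64x64 feature maps with 16 channels, chosen so that
# 16 input channels and 32 filters both divide evenly by n_group=4.
x = tf.placeholder(tf.float32, [None, 64, 64, 16])
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.GroupConv2d(
    net, n_filter=32, filter_size=(3, 3), strides=(2, 2),
    n_group=4, act=tf.nn.relu, name='groupconv1')
# With strides=(2, 2) and SAME padding the output shape is (?, 32, 32, 32).
print(net.outputs.get_shape())

Here each of the 4 groups convolves 4 input channels with 8 of the 32 filters, so the weight tensor has shape [3, 3, 4, 32] rather than the [3, 3, 16, 32] a standard Conv2d would use, i.e. n_group times fewer weights.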
