@@ -58,23 +58,32 @@ def is_compatible_with(x, Type):
5858
5959class HookAttribute (object ):
6060 """
61- Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
62- during training process of a layer with parameters, such as img_conv layer, fc layer.
63-
64- :param type: Hook type, currently supported types:
65- 'pruning' : user specify a sparsity_ratio before training started, and the
66- network will prune the parameters based on the sparsity_ratio.
67- eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
68- The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
69- num_channels=3, num_filters=64,
70- param_attr=ParameterAttribute(update_hooks=hk) )
71- The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf
72- :type type: string
61+ Hook Attribute object. As a member of ParameterAttribute class,
62+ the hook is an auxiliary operation that occurs during training process of
63+ a layer with parameters, such as img_conv layer, fc layer.
64+
65+ Reference:
66+ Learning both Weights and Connections for Efficient Neural Networks
67+ https://arxiv.org/pdf/1506.02626.pdf
68+
69+ The example usage is:
70+
71+ .. code-block:: python
7372
74- :param sparsity_ratio: Must be specified if hook type is 'pruning',
75- it represents the ratio of the zero elements to be set by the Parameter.
73+ paddle.layer.img_conv(input=img, filter_size=3,
74+ num_channels=3, num_filters=64,
75+ param_attr=ParameterAttribute(update_hooks=hk) )
76+
77+
78+ :param type: Hook type, currently supported types:
79+ 'pruning' : user specify a sparsity_ratio before training started, and the
80+ network will prune the parameters based on the sparsity_ratio.
81+ eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
82+ :type type: string
83+ :param sparsity_ratio: Must be specified if hook type is 'pruning',
84+ it represents the ratio of the zero elements to be set by the Parameter.
7685 :type sparsity_ratio: float or None
77-
86+
7887 """
7988
8089 def __init__ (self , type , sparsity_ratio = None ):
@@ -84,7 +93,8 @@ def __init__(self, type, sparsity_ratio=None):
8493 assert is_compatible_with (
8594 self .sparsity_ratio ,
8695 float ), 'sparisity_ratio must be float type'
87- assert self .sparsity_ratio <= 1 and self .sparsity_ratio >= 0 , 'sparsity_ratio must be a float between [0, 1] '
96+ assert self .sparsity_ratio <= 1 and self .sparsity_ratio >= 0 , \
97+ 'sparsity_ratio must be a float between [0, 1] '
8898
8999 def __call__ (self ):
90100 return ParameterHook (self .type , sparsity_ratio = self .sparsity_ratio )
@@ -139,6 +149,7 @@ class ParameterAttribute(object):
139149 def __init__ (self ,
140150 name = None ,
141151 is_static = False ,
152+ initial_smart = None ,
142153 initial_std = None ,
143154 initial_mean = None ,
144155 initial_max = None ,
@@ -152,32 +163,35 @@ def __init__(self,
152163 update_hooks = None ,
153164 initializer = None ):
154165 self .attr = {}
155-
156- if is_static :
157- self .attr ['is_static' ] = True
158-
159- if initial_std is None and initial_mean is None and initial_max \
160- is None and initial_min is None :
161- self .attr ['initial_smart' ] = True
162- elif is_compatible_with (initial_std , float ) or \
163- is_compatible_with (initial_mean , float ):
164- if initial_std is not None :
165- self .attr ['initial_std' ] = initial_std
166- if initial_mean is not None :
166+ self .attr ['is_static' ] = is_static
167+
168+ if initial_smart is not None :
169+ self .attr ['initial_smart' ] = initial_smart
170+
171+ if initial_std is not None or initial_mean is not None or \
172+ initial_max is not None or initial_min is not None :
173+ # smart initalization will be ignored, because user customizes
174+ # parameters related to initialization distribution
175+ self .attr ['initial_smart' ] = False
176+ if is_compatible_with (initial_std , float ) or \
177+ is_compatible_with (initial_mean , float ):
178+ if initial_std is not None :
179+ self .attr ['initial_std' ] = initial_std
180+ if initial_mean is not None :
181+ self .attr ['initial_mean' ] = initial_mean
182+ self .attr ['initial_strategy' ] = 0 # Gauss Random
183+ elif is_compatible_with (initial_max , float ) and \
184+ is_compatible_with (initial_min , float ):
185+ initial_max = initial_max
186+ initial_min = initial_min
187+ assert initial_min < initial_max
188+ initial_mean = (initial_max + initial_min ) / 2
189+ initial_std = initial_mean - initial_min
167190 self .attr ['initial_mean' ] = initial_mean
168- self .attr ['initial_strategy' ] = 0 # Gauss Random
169- elif is_compatible_with (initial_max , float ) and \
170- is_compatible_with (initial_min , float ):
171- initial_max = initial_max
172- initial_min = initial_min
173- assert initial_min < initial_max
174- initial_mean = (initial_max + initial_min ) / 2
175- initial_std = initial_mean - initial_min
176- self .attr ['initial_mean' ] = initial_mean
177- self .attr ['initial_std' ] = initial_std
178- self .attr ['initial_strategy' ] = 1 # Uniform Random
179- else :
180- raise RuntimeError ("Unexpected branch." )
191+ self .attr ['initial_std' ] = initial_std
192+ self .attr ['initial_strategy' ] = 1 # Uniform Random
193+ else :
194+ raise RuntimeError ("Unexpected branch." )
181195
182196 if not is_static and is_compatible_with (l1_rate , float ):
183197 self .attr ['decay_rate_l1' ] = l1_rate
0 commit comments