@@ -120,7 +120,7 @@ def encoder_layer_forward(self,
 
     After inference, `disable_faster_encoder` could be called to restore the
     `forward` function of `paddle.nn.TransformerEncoder` and
-    `paddle.nn.TransformerEncoder `.
+    `paddle.nn.TransformerEncoderLayer`.
 
     Args:
         src (Tensor):
@@ -130,14 +130,13 @@ def encoder_layer_forward(self,
         src_mask (Tensor, optional):
            A tensor used in multi-head attention to prevent attention to some
            unwanted positions, usually the paddings or the subsequent
-            positions. It is a tensor with shape broadcasted to
-            `[batch_size, n_head, sequence_length, sequence_length]`. When the
-            data type is bool, the unwanted positions have `False` values and
-            the others have `True` values. When the data type is int, the
-            unwanted positions have 0 values and the others have 1 values. When
-            the data type is float, the unwanted positions have `-INF` values
-            and the others have 0 values. It can be None when nothing wanted or
-            needed to be prevented attention to. Defaults to None.
+            positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is int,
+            the unwanted positions have 0 values and the others have 1 values.
+            When the data type is float, the unwanted positions have `-INF`
+            values and the others have 0 values. It can be None when no
+            positions need to be masked. Defaults to None.
 
     Returns:
         src(Tensor|tuple):
@@ -192,7 +191,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
 
     After inference, `disable_faster_encoder` could be called to restore the
     `forward` function of `paddle.nn.TransformerEncoder` and
-    `paddle.nn.TransformerEncoder `.
+    `paddle.nn.TransformerEncoderLayer`.
 
     Args:
         src (Tensor):
202201 src_mask (Tensor, optional):
203202 A tensor used in multi-head attention to prevents attention to
204203 some unwanted positions, usually the paddings or the subsequent
205- positions. It is a tensor with shape broadcasted to
206- `[batch_size, n_head, sequence_length, sequence_length]`. When the
207- data type is bool, the unwanted positions have `False ` values and
208- the others have `True` values. When the data type is int, the
209- unwanted positions have 0 values and the others have 1 values.
210- When the data type is float, the unwanted positions have `-INF`
211- values and the others have 0 values. It can be None when nothing
212- wanted or needed to be prevented attention to. Default None.
204+ positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
205+ When the data type is bool, the unwanted positions have `False`
206+ values and the others have `True ` values. When the data type is
207+ int, the unwanted positions have 0 values and the others have 1
208+ values. When the data type is float, the unwanted positions have
209+ `-INF` values and the others have 0 values. It can be None when
210+ nothing wanted or needed to be prevented attention to. Defaults
211+ to None.
213212
214213 Returns:
215214 output (Tensor|tuple):
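
Note for readers: both `src_mask` docstrings above describe the same `[batch_size, 1, 1, sequence_length]` mask. Below is a minimal sketch of how such a mask could be built from a padded batch; the pad id and token ids are made-up placeholders for illustration, not part of this patch.

```python
import paddle

# Hypothetical padded batch (pad id 0); these values are placeholders.
pad_id = 0
token_ids = paddle.to_tensor([[5, 8, 3, pad_id, pad_id]])   # [batch_size, sequence_length]

# bool mask: True for positions to attend to, False for the paddings.
bool_mask = (token_ids != pad_id).unsqueeze([1, 2])         # [batch_size, 1, 1, sequence_length]

# float mask: 0 for positions to attend to, a large negative value (-INF) for the paddings.
float_mask = (1.0 - bool_mask.astype("float32")) * -1e9
```
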
@@ -252,35 +251,34 @@ def enable_faster_encoder(self):
             model = disable_faster_encoder(model)
     """
 
-    def check_if_usable(layer):
-        for sub_layer in layer.children():
-            if isinstance(sub_layer,
-                          TransformerEncoderLayer) and sub_layer._config[
-                              'bias_attr'] == False:
+    def init_func(layer):
+        if isinstance(layer, TransformerEncoderLayer):
+            is_usable = True
+            if layer._config['bias_attr'] == False:
                 logger.warning("`False` for paddle.nn.TransformerEncoder's" \
                                " parameter `bias_attr` is not supported in " \
-                               "FasterTransformer by now. Original Paddle API " \
-                               "would be called.")
-                return False
-            elif not check_if_usable(sub_layer):
-                return False
-        return True
-
-    def init_func(layer):
-        if isinstance(layer, (TransformerEncoderLayer, TransformerEncoder)):
+                               "FasterTransformer by now. The original forward" \
+                               " will be involved.")
+                is_usable = False
+            if layer._config['activation'] not in ('relu', 'gelu'):
+                logger.warning("Only 'relu' or 'gelu' is supported by now. " \
+                               "The original forward will be involved.")
+                is_usable = False
+            if is_usable:
+                layer.forward = layer._ft_forward
+        elif isinstance(layer, TransformerEncoder):
             layer.forward = layer._ft_forward
 
     if not self.training:
-        if not check_if_usable(self):
-            return self
         try:
             load("FasterTransformer", verbose=True)
-            for layer in self.children():
-                layer.apply(init_func)
         except Exception:
             logger.warning(
                 "Exception occurs when using Faster Transformer. " \
                 "The original forward will be involved. ")
+            return self
+        for layer in self.children():
+            layer.apply(init_func)
     return self
 
 
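For context, a rough usage sketch of the enable/disable pair after this change. The import path, toy encoder, and shapes below are assumptions for illustration (the helpers are assumed to be exported from `paddlenlp.ops`, as in the repo's examples); only the `model = enable_faster_encoder(model)` / `disable_faster_encoder(model)` call pattern comes from the docstring above.

```python
import paddle
import paddle.nn as nn
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

# Toy encoder with made-up sizes, only to illustrate the call pattern.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256,
                                   activation='gelu')
model = nn.TransformerEncoder(layer, num_layers=2)
model.eval()                               # forwards are only swapped when not training

model = enable_faster_encoder(model)       # supported layers switch to their `_ft_forward`
src = paddle.randn([2, 8, 64])             # [batch_size, sequence_length, d_model]
src_mask = paddle.zeros([2, 1, 1, 8])      # float mask: 0 everywhere, so attend to all positions
with paddle.no_grad():
    out = model(src, src_mask=src_mask)
model = disable_faster_encoder(model)      # restore the original Paddle forwards
```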