
Commit 48d9afb

Myle Ott authored and facebook-github-bot committed
Speed improvements (#531)
Summary:
* Add FusedLayerNorm and FusedAdam
* Softmax and zero grad optimizations

Pull Request resolved: #531
Differential Revision: D14218457
Pulled By: myleott
fbshipit-source-id: 5656b2d0152cd85f77dc21ec0e1439ec04b9fa89
1 parent a24880b commit 48d9afb

14 files changed

Lines changed: 102 additions & 54 deletions

README.md

Lines changed: 10 additions & 4 deletions
@@ -36,12 +36,12 @@ translation and language modeling datasets.
 ![Model](fairseq.gif)
 
 # Requirements and Installation
-* A [PyTorch installation](http://pytorch.org/)
+
+* [PyTorch](http://pytorch.org/) version >= 1.0.0
+* Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* Python version 3.6
 
-Currently fairseq requires PyTorch version >= 1.0.0.
-Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
+Please follow the instructions here to install PyTorch: https://github.com/pytorch/pytorch#installation.
 
 If you use Docker make sure to increase the shared memory size either with
 `--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`.
@@ -60,6 +60,12 @@ cd fairseq
 pip install --editable .
 ```
 
+**Improved training speed**
+
+Training speed can be further improved by installing NVIDIA's
+[apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option.
+fairseq will automatically switch to the faster modules provided by apex.
+
 # Getting Started
 
 The [full documentation](https://fairseq.readthedocs.io/) contains instructions

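The new README note describes the apex integration only at a high level: when apex is installed with its CUDA extensions, fairseq picks up the fused kernels automatically and otherwise keeps using plain PyTorch. A minimal sketch of that optional-dependency pattern for the FusedAdam optimizer mentioned in the commit summary (the `build_adam` helper is illustrative, not fairseq's actual optimizer wiring):

```
# Sketch of the optional-apex pattern, assuming apex was built with --cuda_ext.
try:
    from apex.optimizers import FusedAdam as _Adam   # fused CUDA Adam kernel
except ImportError:
    from torch.optim import Adam as _Adam            # pure-PyTorch fallback


def build_adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """Return FusedAdam when apex is available, torch.optim.Adam otherwise."""
    return _Adam(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```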
fairseq/distributed_utils.py

Lines changed: 9 additions & 7 deletions
@@ -122,19 +122,23 @@ def all_gather_list(data, group=None, max_size=16384):
     if not hasattr(all_gather_list, '_buffer') or \
             all_gather_list._buffer.numel() < buffer_size:
         all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
     buffer = all_gather_list._buffer
     buffer.zero_()
+    cpu_buffer = all_gather_list._cpu_buffer
 
     enc = pickle.dumps(data)
     enc_size = len(enc)
     if enc_size + 2 > max_size:
         raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
     assert max_size < 255*256
 
-    buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
-    buffer_rank[0] = enc_size // 255  # this encoding works for max_size < 65k
-    buffer_rank[1] = enc_size % 255
-    buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
+    cpu_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
+    cpu_buffer[1] = enc_size % 255
+    cpu_buffer[2 : enc_size + 2] = torch.ByteTensor(list(enc))
+    start = rank * max_size
+    size = enc_size + 2
+    buffer[start : start + size].copy_(cpu_buffer[:size])
 
     all_reduce(buffer, group=group)
 
@@ -144,9 +148,7 @@ def all_gather_list(data, group=None, max_size=16384):
             out_buffer = buffer[i * max_size : (i + 1) * max_size]
             size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
             if size > 0:
-                result.append(
-                    pickle.loads(bytes(out_buffer[2:size+2].tolist()))
-                )
+                result.append(pickle.loads(bytes(out_buffer[2 : size + 2].tolist())))
         return result
     except pickle.UnpicklingError:
         raise Exception(

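The `all_gather_list` change above replaces many small element-wise writes into a CUDA tensor with a single staged copy: the pickled payload is first written into a page-locked (pinned) CPU buffer, then transferred to the device buffer in one contiguous `copy_`. A minimal sketch of that staging pattern in isolation (payload and sizes are illustrative; a CUDA device is assumed):

```
import pickle

import torch

max_size = 16384
cpu_buffer = torch.ByteTensor(max_size).pin_memory()  # page-locked staging buffer
cuda_buffer = torch.cuda.ByteTensor(max_size)         # device-side gather buffer

enc = pickle.dumps({'step': 1, 'loss': 2.5})           # example payload
size = len(enc) + 2
cpu_buffer[0] = len(enc) // 255                        # 2-byte length header
cpu_buffer[1] = len(enc) % 255
cpu_buffer[2:size] = torch.ByteTensor(list(enc))

# one contiguous host-to-device copy instead of many small writes
cuda_buffer[:size].copy_(cpu_buffer[:size])
```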
fairseq/models/fairseq_decoder.py

Lines changed: 9 additions & 4 deletions
@@ -6,7 +6,8 @@
 # can be found in the PATENTS file in the same directory.
 
 import torch.nn as nn
-import torch.nn.functional as F
+
+from fairseq import utils
 
 
 class FairseqDecoder(nn.Module):
@@ -15,6 +16,7 @@ class FairseqDecoder(nn.Module):
     def __init__(self, dictionary):
         super().__init__()
         self.dictionary = dictionary
+        self.onnx_trace = False
 
     def forward(self, prev_output_tokens, encoder_out):
         """
@@ -33,6 +35,9 @@ def forward(self, prev_output_tokens, encoder_out):
         """
         raise NotImplementedError
 
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def get_normalized_probs(self, net_output, log_probs, sample):
         """Get normalized probabilities (or log probs) from a net's output."""
 
@@ -45,11 +50,11 @@ def get_normalized_probs(self, net_output, log_probs, sample):
             out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
             return out.exp_() if not log_probs else out
 
-        logits = net_output[0].float()
+        logits = net_output[0]
         if log_probs:
-            return F.log_softmax(logits, dim=-1)
+            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
         else:
-            return F.softmax(logits, dim=-1)
+            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
 
     def max_positions(self):
         """Maximum input length supported by the decoder."""

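`get_normalized_probs` now routes through `utils.softmax` / `utils.log_softmax` with an `onnx_trace` flag, but those helpers are not among the files shown in this commit. A hedged sketch of what such a wrapper plausibly looks like, assuming the non-traced path upcasts to float32 for FP16 stability and the traced path avoids the extra cast so the exported ONNX graph stays simple:

```
import torch.nn.functional as F


def softmax(x, dim, onnx_trace=False):
    # Sketch only: fairseq's real utils.softmax is not part of this diff.
    if onnx_trace:
        return F.softmax(x, dim=dim)          # keep the traced graph simple
    return F.softmax(x.float(), dim=dim)      # float32 for FP16 numerical stability


def log_softmax(x, dim, onnx_trace=False):
    if onnx_trace:
        return F.log_softmax(x, dim=dim)
    return F.log_softmax(x.float(), dim=dim)
```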
fairseq/models/fconv_self_att.py

Lines changed: 6 additions & 6 deletions
@@ -13,8 +13,8 @@
 import torch.nn.functional as F
 
 from fairseq.modules import (
-    DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
-    LinearizedConvolution,
+    DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
+    LearnedPositionalEmbedding, LinearizedConvolution,
 )
 from fairseq import utils
 
@@ -351,13 +351,13 @@ def expand_bool_array(val):
             # pretrained and trained models are joined
             self.joining = nn.Sequential(
                 Linear(out_embed_dim*2, out_embed_dim*2),
-                nn.LayerNorm(out_embed_dim*2),
+                LayerNorm(out_embed_dim*2),
                 nn.GLU(),
                 Linear(out_embed_dim, out_embed_dim*2),
-                nn.LayerNorm(out_embed_dim*2),
+                LayerNorm(out_embed_dim*2),
                 nn.GLU(),
                 Linear(out_embed_dim, out_embed_dim),
-                nn.LayerNorm(out_embed_dim)
+                LayerNorm(out_embed_dim)
             )
             # pretrained model contains an output layer that is nhid -> vocab size
             # but the models are combined in their hidden state
@@ -470,7 +470,7 @@ def __init__(self, out_channels, embed_dim, num_heads, project_input=False, gate
         self.in_proj_q = Linear(out_channels, embed_dim)
         self.in_proj_k = Linear(out_channels, embed_dim)
         self.in_proj_v = Linear(out_channels, embed_dim)
-        self.ln = nn.LayerNorm(out_channels)
+        self.ln = LayerNorm(out_channels)
 
     def forward(self, x):
         residual = x

fairseq/models/lightconv.py

Lines changed: 6 additions & 12 deletions
@@ -11,17 +11,16 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fairseq import options
-from fairseq import utils
-
+from fairseq import options, utils
 from fairseq.modules import (
-    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
-    SinusoidalPositionalEmbedding, DynamicConv1dTBC, LightweightConv1dTBC
+    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
+    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
+    DynamicConv1dTBC, LightweightConv1dTBC,
 )
 
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
-    register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
+    FairseqModel, register_model, register_model_architecture,
 )
 
 
@@ -771,11 +770,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def LayerNorm(embedding_dim):
-    m = nn.LayerNorm(embedding_dim)
-    return m
-
-
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
     nn.init.xavier_uniform_(m.weight)

fairseq/models/transformer.py

Lines changed: 5 additions & 12 deletions
@@ -11,17 +11,15 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fairseq import options
-from fairseq import utils
-
+from fairseq import options, utils
 from fairseq.modules import (
-    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
-    SinusoidalPositionalEmbedding
+    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
+    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
 )
 
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
-    register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
+    FairseqModel, register_model, register_model_architecture,
 )
 
 
@@ -766,11 +764,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def LayerNorm(embedding_dim):
-    m = nn.LayerNorm(embedding_dim)
-    return m
-
-
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
     nn.init.xavier_uniform_(m.weight)

fairseq/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from .dynamic_convolution import DynamicConv1dTBC
 from .grad_multiply import GradMultiply
 from .highway import Highway
+from .layer_norm import LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
@@ -34,6 +35,7 @@
     'DynamicConv1dTBC',
     'GradMultiply',
     'Highway',
+    'LayerNorm',
     'LearnedPositionalEmbedding',
     'LightweightConv1dTBC',
     'LinearizedConvolution',

fairseq/modules/layer_norm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the LICENSE file in
5+
# the root directory of this source tree. An additional grant of patent rights
6+
# can be found in the PATENTS file in the same directory.
7+
8+
import torch
9+
10+
11+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
12+
if torch.cuda.is_available():
13+
try:
14+
from apex.normalization import FusedLayerNorm
15+
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
16+
except ImportError:
17+
pass
18+
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

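The new `LayerNorm` factory is a drop-in replacement for `nn.LayerNorm`: it returns apex's `FusedLayerNorm` when a GPU is present and apex (built with `--cuda_ext`) can be imported, and falls back to `torch.nn.LayerNorm` otherwise. A short usage example; which class you get depends on the environment:

```
import torch

from fairseq.modules import LayerNorm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ln = LayerNorm(512).to(device)                 # normalized_shape = model dim
x = torch.randn(10, 32, 512, device=device)    # (time, batch, channels)
y = ln(x)                                      # same call signature as nn.LayerNorm
print(type(ln).__name__)                       # 'FusedLayerNorm' with apex + CUDA, else 'LayerNorm'
```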
fairseq/modules/lightweight_convolution.py

Lines changed: 8 additions & 2 deletions
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
 import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -121,6 +122,8 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
 
         self.reset_parameters()
 
+        self.onnx_trace = False
+
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.weight)
         if self.bias is not None:
@@ -144,6 +147,9 @@ def forward(self, x, incremental_state=None, unfold=False):
             output = output + self.bias.view(1, 1, -1)
         return output
 
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def _forward_unfolded(self, x, incremental_state):
         '''The conventional implementation of convolutions.
         Unfolding the input by having a window shifting to the right.'''
@@ -167,7 +173,7 @@ def _forward_unfolded(self, x, incremental_state):
         x_unfold = x_unfold.view(T*B*H, R, K)
 
         if self.weight_softmax:
-            weight = F.softmax(weight.float(), dim=1).type_as(weight)
+            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
 
         if incremental_state is not None:
             weight = weight[:, -x_unfold.size(2):]
@@ -192,7 +198,7 @@ def _forward_expanded(self, x, incremental_state):
 
         weight = self.weight.view(H, K)
         if self.weight_softmax:
-            weight = F.softmax(weight.float(), dim=1).type_as(weight)
+            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
         weight = weight.view(1, H, K).expand(T*B, H, K).contiguous()
         weight = weight.view(T, B*H, K).transpose(0, 1)

fairseq/modules/multihead_attention.py

Lines changed: 3 additions & 1 deletion
@@ -184,7 +184,9 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
             ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
+        attn_weights = utils.softmax(
+            attn_weights, dim=-1, onnx_trace=self.onnx_trace,
+        ).type_as(attn_weights)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
 
         attn = torch.bmm(attn_weights, v)

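Several modules in this commit gain a `prepare_for_onnx_export_()` method that flips a per-module `onnx_trace` flag so the softmax helpers take the trace-friendly path. A hedged sketch of how a caller might switch an entire model over before export (illustrative helper, not the fairseq export entry point, which is not part of this commit):

```
import torch.nn as nn


def prepare_model_for_onnx_export_(model: nn.Module):
    # Visit every submodule and flip its onnx_trace flag if it supports it.
    for module in model.modules():
        if module is not model and hasattr(module, 'prepare_for_onnx_export_'):
            module.prepare_for_onnx_export_()
```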