
Commit 5065077

Myle Ott authored and facebook-github-bot committed
Use cross entropy from apex for improved memory efficiency (#1122)
Summary:
Pull Request resolved: fairinternal/fairseq-py#1122
Reviewed By: ngoyal2707
Differential Revision: D20745717
Pulled By: myleott
fbshipit-source-id: 877a1185f17952461ef204d8ad7f05b8d37b1fd9

1 parent 4d2efae

4 files changed: 61 additions & 8 deletions


README.md

Lines changed: 6 additions & 1 deletion
```diff
@@ -89,7 +89,12 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example
 * [PyTorch](http://pytorch.org/) version >= 1.4.0
 * Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` and `--deprecated_fused_adam` options
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./
+```
 
 To install fairseq:
 ```bash
```
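If the build succeeds, the `apex.contrib.xentropy` extension becomes importable, which is exactly what fairseq probes for (see `fairseq/modules/cross_entropy.py` below). A minimal sanity check, assuming it runs in the same Python environment apex was installed into:

```python
# Sanity check: is the fused cross entropy extension available here?
try:
    from apex.contrib import xentropy  # built by the --xentropy option above
    print('fused cross entropy available:', hasattr(xentropy, 'SoftmaxCrossEntropyLoss'))
except ImportError:
    print('apex xentropy extension missing; fairseq will use the PyTorch fallback')
```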

fairseq/criterions/masked_lm.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -8,7 +8,7 @@
 import torch
 import torch.nn.functional as F
 
-from fairseq import metrics, utils
+from fairseq import metrics, modules, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 
 
@@ -47,12 +47,8 @@ def forward(self, model, sample, reduce=True):
         targets = model.get_targets(sample, [logits])
         targets = targets[masked_tokens]
 
-        loss = F.nll_loss(
-            F.log_softmax(
-                logits.view(-1, logits.size(-1)),
-                dim=-1,
-                dtype=torch.float32,
-            ),
+        loss = modules.cross_entropy(
+            logits.view(-1, logits.size(-1)),
             targets.view(-1),
             reduction='sum',
             ignore_index=self.padding_idx,
```
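The motivation for the swap: the old path materializes a full float32 log-probability tensor of shape `(num_masked_tokens, vocab_size)` via `F.log_softmax` before `F.nll_loss` reduces it, which is costly for large vocabularies. `modules.cross_entropy` dispatches to apex's fused kernel when it is available, avoiding that intermediate, and otherwise falls back to the same PyTorch computation. A small illustrative call; the sizes, device, and padding index below are hypothetical, not from this commit:

```python
import torch
from fairseq import modules

# Hypothetical sizes for illustration: 4096 masked tokens, a 50k-entry vocab.
logits = torch.randn(4096, 50000, device='cuda', dtype=torch.half, requires_grad=True)
targets = torch.randint(0, 50000, (4096,), device='cuda')

# Same call shape as the new masked_lm.py code; targets equal to
# ignore_index contribute nothing to the sum.
loss = modules.cross_entropy(logits, targets, reduction='sum', ignore_index=1)
loss.backward()
```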

fairseq/modules/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
+from .cross_entropy import cross_entropy
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
 from .dynamic_crf_layer import DynamicCRF
@@ -36,6 +37,7 @@
     'BeamableMM',
     'CharacterTokenEmbedder',
     'ConvTBC',
+    'cross_entropy',
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
     'DynamicConv',
```
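With this export, the new function is reachable both as a package attribute and from its submodule; a trivial check of the wiring:

```python
# Both import styles resolve to the same function after this change.
from fairseq import modules
from fairseq.modules import cross_entropy

assert modules.cross_entropy is cross_entropy
```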

fairseq/modules/cross_entropy.py

Lines changed: 50 additions & 0 deletions
```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn.functional as F


logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'):
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs, target, ignore_index=ignore_index, reduction=reduction,
    )


try:
    from apex.contrib import xentropy

    logger.info('using fused cross entropy')

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        if logits.device == torch.device('cpu'):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            half_to_float = (logits.dtype == torch.half)
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits, target, 0.0, ignore_index, half_to_float,
            )
            if reduction == 'sum':
                return losses.sum()
            elif reduction == 'mean':
                if ignore_index >= 0:
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == 'none':
                return losses
            else:
                raise NotImplementedError

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
```
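The fused kernel returns unreduced per-token losses, so the wrapper implements the `sum`/`mean`/`none` reductions itself and degrades to the plain PyTorch path for CPU tensors or when apex is missing. A quick parity check between the two paths (a sketch, assuming a CUDA device and an apex build with `--xentropy`):

```python
import torch
from fairseq.modules.cross_entropy import _cross_entropy_pytorch, cross_entropy

logits = torch.randn(8, 100, device='cuda')
target = torch.randint(0, 100, (8,), device='cuda')

fused = cross_entropy(logits, target, ignore_index=-100, reduction='mean')
reference = _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction='mean')
print(torch.allclose(fused, reference, atol=1e-4))  # should print True
```

Note that for `mean` with a non-negative `ignore_index`, the wrapper divides by the number of non-ignored targets, matching `F.nll_loss` semantics.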
