
Commit 718677e

Naman Goyal authored and facebook-github-bot committed
dont project maske tokens for mlm loss (#859)
Summary: This saves ~4-5 GB of GPU memory while training roberta-large with `seq_len=512`. I am able to fit `--max-sentences=16` on `volta32gb` for `roberta-large`.

Pull Request resolved: fairinternal/fairseq-py#859

Differential Revision: D17435814

fbshipit-source-id: 2663909768fac0ef0102107613770ee01b1f8c00
1 parent 31dd13f commit 718677e

2 files changed

Lines changed: 16 additions & 9 deletions
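
Why this helps: in RoBERTa-style masked language modeling only a small fraction of positions (typically ~15%) carry an MLM target, yet the old code projected every position to the full vocabulary before computing the loss. The sketch below is illustrative only and not fairseq code; the shapes and the 15% masking rate are assumptions chosen to roughly match the `--max-sentences=16`, `seq_len=512` roberta-large setting mentioned above.

import torch
import torch.nn as nn

batch, seq_len, hidden, vocab = 16, 512, 1024, 50265   # roberta-large-ish shapes (assumed)
features = torch.randn(batch, seq_len, hidden)
lm_head = nn.Linear(hidden, vocab)

# Old behaviour: logits for every position -> (16, 512, 50265), ~1.6 GB in fp32.
full_logits = lm_head(features)

# New behaviour: logits only where an MLM target exists (~15% of positions).
masked_tokens = torch.rand(batch, seq_len) < 0.15      # bool mask
masked_logits = lm_head(features[masked_tokens, :])    # (num_masked, 50265)

print(full_logits.numel() // masked_logits.numel())    # roughly 6-7x fewer logit activations

The saving comes from never materializing the dense `(batch, seq_len, vocab)` logits tensor (or its gradient), which dominates activation memory for large vocabularies.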


fairseq/criterions/masked_lm.py

Lines changed: 6 additions & 2 deletions
@@ -30,8 +30,11 @@ def forward(self, model, sample, reduce=True):
         3) logging outputs to display while training
         """
         # compute MLM loss
-        logits = model(**sample['net_input'], return_all_hiddens=False)[0]
+        masked_tokens = sample['target'].ne(self.padding_idx)
+        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
         targets = model.get_targets(sample, [logits])
+        targets = targets[masked_tokens]
+
         loss = F.nll_loss(
             F.log_softmax(
                 logits.view(-1, logits.size(-1)),
@@ -43,7 +46,7 @@ def forward(self, model, sample, reduce=True):
             ignore_index=self.padding_idx,
         )
 
-        sample_size = targets.ne(self.padding_idx).int().sum().item()
+        sample_size = masked_tokens.int().sum().item()
 
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
@@ -64,6 +67,7 @@ def aggregate_logging_outputs(logging_outputs):
 
         agg_output = {
             'loss': loss / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
             'ntokens': ntokens,
             'nsentences': nsentences,
             'sample_size': sample_size,
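
Reading the criterion changes end to end: the target tensor already holds `padding_idx` at every position that was not selected for masking, so `sample['target'].ne(self.padding_idx)` doubles as the masked-position mask. The model is asked for logits only at those positions, the targets are gathered down to match, and the loss normalizer becomes the masked-token count. Below is a condensed, self-contained sketch of that flow outside fairseq; `padding_idx`, the toy targets, and the random logits are stand-ins.

import torch
import torch.nn.functional as F

padding_idx = 1                                          # stand-in value
# Targets are padding_idx everywhere except the masked positions.
targets = torch.tensor([[1, 1, 7, 1],
                        [5, 1, 1, 9]])
masked_tokens = targets.ne(padding_idx)                  # bool mask of MLM positions

# The model now returns logits only for the masked positions,
# so the targets must be gathered down to the same shape.
vocab = 32
logits = torch.randn(int(masked_tokens.sum()), vocab)    # stand-in for the model output
targets = targets[masked_tokens]

loss = F.nll_loss(
    F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
    targets.view(-1),
    reduction='sum',
    ignore_index=padding_idx,
)
sample_size = masked_tokens.int().sum().item()           # normalizer = number of masked tokens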

fairseq/models/roberta/model.py

Lines changed: 10 additions & 7 deletions
@@ -201,14 +201,17 @@ def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
         self.weight = weight
         self.bias = nn.Parameter(torch.zeros(output_dim))
 
-    def forward(self, features, **kwargs):
+    def forward(self, features, masked_tokens=None, **kwargs):
+        # Only project the unmasked tokens while training,
+        # saves both memory and computation
+        if masked_tokens is not None:
+            features = features[masked_tokens, :]
+
         x = self.dense(features)
         x = self.activation_fn(x)
         x = self.layer_norm(x)
-
         # project back to size of vocabulary with bias
         x = F.linear(x, self.weight) + self.bias
-
         return x
 
 
@@ -265,7 +268,7 @@ def __init__(self, args, dictionary):
             weight=self.sentence_encoder.embed_tokens.weight,
         )
 
-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
         """
         Args:
             src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -283,7 +286,7 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **u
         """
         x, extra = self.extract_features(src_tokens, return_all_hiddens)
         if not features_only:
-            x = self.output_layer(x)
+            x = self.output_layer(x, masked_tokens=masked_tokens)
         return x, extra
 
     def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
@@ -293,8 +296,8 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
         features = inner_states[-1]
         return features, {'inner_states': inner_states if return_all_hiddens else None}
 
-    def output_layer(self, features, **unused):
-        return self.lm_head(features)
+    def output_layer(self, features, masked_tokens=None, **unused):
+        return self.lm_head(features, masked_tokens)
 
     def max_positions(self):
         """Maximum output length supported by the encoder."""
