Skip to content
This repository was archived by the owner on Mar 20, 2026. It is now read-only.

Commit 9b19ede

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Fix keyword arguments in translation_moe task
Summary: Pull Request resolved: #1546 Differential Revision: D19225548 Pulled By: myleott fbshipit-source-id: 43240cb90ca477ab7a790386ab2d9f4fd14e2625
1 parent 4333437 commit 9b19ede

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

fairseq/tasks/translation_moe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,19 @@ def _get_loss(self, sample, model, criterion):
121121
bsz = sample['target'].size(0)
122122

123123
def get_lprob_y(encoder_out, prev_output_tokens_k):
124-
net_output = model.decoder(prev_output_tokens_k, encoder_out)
124+
net_output = model.decoder(
125+
prev_output_tokens=prev_output_tokens_k,
126+
encoder_out=encoder_out,
127+
)
125128
loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
126129
loss = loss.view(bsz, -1)
127130
return -loss.sum(dim=1, keepdim=True) # -> B x 1
128131

129132
def get_lprob_yz(winners=None):
130-
encoder_out = model.encoder(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'])
133+
encoder_out = model.encoder(
134+
src_tokens=sample['net_input']['src_tokens'],
135+
src_lengths=sample['net_input']['src_lengths'],
136+
)
131137

132138
if winners is None:
133139
lprob_y = []

0 commit comments

Comments
 (0)