Skip to content

Commit 3c63f65

Browse files
committed
transfer tensors in tests to cpu
1 parent 0ed7c3a commit 3c63f65

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/ignite/metrics/nlp/test_bleu.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=No
6464
assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)
6565

6666
bleu.update((candidates, references))
67-
assert pytest.approx(reference) == bleu.compute()
67+
computed = bleu.compute()
68+
if isinstance(computed, torch.Tensor):
69+
computed = computed.cpu().item()
70+
assert pytest.approx(reference) == computed
6871

6972

7073
@pytest.mark.parametrize(*parametrize_args)
@@ -153,7 +156,11 @@ def test_bleu_batch_macro(available_device):
153156
+ sentence_bleu(refs[1], hypotheses[1])
154157
+ sentence_bleu(refs[2], hypotheses[2])
155158
) / 3
156-
assert pytest.approx(bleu.compute()) == reference_bleu_score
159+
computed = bleu.compute()
160+
if isinstance(computed, torch.Tensor):
161+
computed = computed.cpu().item()
162+
163+
assert pytest.approx(computed) == reference_bleu_score
157164

158165
value = 0
159166
for _hypotheses, _refs in zip(hypotheses, refs):

0 commit comments

Comments
 (0)