Commit 0b189de

Merge pull request #674 from sillsdev/punctuation-normalizer-fix
Fix for #656 (mixed-language output)
2 parents: 6db0f5f + 7d3c7d7 · commit 0b189de

File tree

1 file changed: +31 -6 lines changed


silnlp/nmt/hugging_face_config.py

Lines changed: 31 additions & 6 deletions
@@ -9,7 +9,19 @@
 from enum import Enum
 from itertools import repeat
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import datasets.utils.logging as datasets_logging
 import evaluate
@@ -24,7 +36,10 @@
 from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef
 from sacremoses import MosesPunctNormalizer
 from tokenizers import AddedToken, NormalizedString, Regex
-from tokenizers.implementations import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer
+from tokenizers.implementations import (
+    SentencePieceBPETokenizer,
+    SentencePieceUnigramTokenizer,
+)
 from tokenizers.normalizers import Normalizer
 from torch import Tensor, TensorType, nn, optim
 from torch.utils.data import Sampler
@@ -73,7 +88,14 @@
 from ..common.corpus import Term, count_lines, get_terms
 from ..common.environment import SIL_NLP_ENV
 from ..common.translator import DraftGroup, TranslationGroup
-from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict
+from ..common.utils import (
+    NoiseMethod,
+    ReplaceRandomToken,
+    Side,
+    create_noise_methods,
+    get_mt_exp_dir,
+    merge_dict,
+)
 from .config import CheckpointType, Config, DataFile, NMTModel
 from .tokenizer import NullTokenizer, Tokenizer
 
@@ -1185,13 +1207,16 @@ def translate(
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> Iterable[TranslationGroup]:
         tokenizer = self._config.get_tokenizer()
-        if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
-            tokenizer = PunctuationNormalizingTokenizer(tokenizer)
-
         model = self._create_inference_model(ckpt, tokenizer)
         if model.config.max_length is not None and model.config.max_length < 512:
            model.config.max_length = 512
         lang_codes: Dict[str, str] = self._config.data["lang_codes"]
+
+        # The tokenizer isn't wrapped until after calling _create_inference_model,
+        # because the tokenizer's input/output language codes are set there
+        if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+            tokenizer = PunctuationNormalizingTokenizer(tokenizer)
+
         pipeline = TranslationPipeline(
             model=model,
             tokenizer=tokenizer,
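
Why the ordering matters: the new comment notes that _create_inference_model is where the tokenizer's input/output language codes get set, so the underlying NLLB tokenizer has to be configured before it is wrapped for punctuation normalization. The sketch below is a minimal, hypothetical delegating wrapper (it is not silnlp's actual PunctuationNormalizingTokenizer, whose implementation isn't shown in this diff); it only illustrates how a wrapper that forwards attribute reads can still hide later attribute writes such as src_lang/tgt_lang from the tokenizer it wraps, which is the kind of ordering hazard that could surface as the mixed-language output reported in #656.

from sacremoses import MosesPunctNormalizer


class DelegatingPunctTokenizer:
    """Hypothetical illustration only; not silnlp's PunctuationNormalizingTokenizer."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._normalizer = MosesPunctNormalizer()

    def __call__(self, text, **kwargs):
        # Normalize punctuation, then delegate tokenization to the wrapped tokenizer.
        return self._tokenizer(self._normalizer.normalize(text), **kwargs)

    def __getattr__(self, name):
        # Attribute *reads* not defined on the wrapper fall through to the wrapped
        # tokenizer, so the wrapper behaves like the original for most purposes.
        return getattr(self._tokenizer, name)

    # Plain attribute *writes* (e.g. wrapper.src_lang = "fra_Latn") land on the wrapper
    # instance, not on the wrapped tokenizer, unless __setattr__ is forwarded as well.
    # If language codes are assigned after wrapping, the inner NLLB tokenizer may never
    # see them; wrapping only after _create_inference_model sidesteps that risk.

With the commit's ordering, whatever language-code configuration _create_inference_model applies lands on the real NllbTokenizer/NllbTokenizerFast first, and only afterwards is the tokenizer wrapped for punctuation normalization.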

0 commit comments
