@@ -9,7 +9,19 @@
 from enum import Enum
 from itertools import repeat
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import datasets.utils.logging as datasets_logging
 import evaluate
@@ -24,7 +36,10 @@
 from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef
 from sacremoses import MosesPunctNormalizer
 from tokenizers import AddedToken, NormalizedString, Regex
-from tokenizers.implementations import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer
+from tokenizers.implementations import (
+    SentencePieceBPETokenizer,
+    SentencePieceUnigramTokenizer,
+)
 from tokenizers.normalizers import Normalizer
 from torch import Tensor, TensorType, nn, optim
 from torch.utils.data import Sampler
@@ -73,7 +88,14 @@
 from ..common.corpus import Term, count_lines, get_terms
 from ..common.environment import SIL_NLP_ENV
 from ..common.translator import DraftGroup, TranslationGroup
-from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict
+from ..common.utils import (
+    NoiseMethod,
+    ReplaceRandomToken,
+    Side,
+    create_noise_methods,
+    get_mt_exp_dir,
+    merge_dict,
+)
 from .config import CheckpointType, Config, DataFile, NMTModel
 from .tokenizer import NullTokenizer, Tokenizer
 
@@ -1185,13 +1207,16 @@ def translate(
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> Iterable[TranslationGroup]:
         tokenizer = self._config.get_tokenizer()
-        if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
-            tokenizer = PunctuationNormalizingTokenizer(tokenizer)
-
         model = self._create_inference_model(ckpt, tokenizer)
         if model.config.max_length is not None and model.config.max_length < 512:
             model.config.max_length = 512
         lang_codes: Dict[str, str] = self._config.data["lang_codes"]
+
+        # The tokenizer isn't wrapped until after calling _create_inference_model,
+        # because the tokenizer's input/output language codes are set there.
+        if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+            tokenizer = PunctuationNormalizingTokenizer(tokenizer)
+
         pipeline = TranslationPipeline(
             model=model,
             tokenizer=tokenizer,