Skip to content

Commit 0812689

Browse files
alexzms and github-actions[bot]
authored and committed
[add-new-model-like] Robust search & proper outer '),' in tokenizer mapping (huggingface#38703)
* [add-new-model-like] Robust search & proper outer '),' in tokenizer mapping
* code-style: arrange the imports in add_new_model_like.py
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b190dc2 commit 0812689

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/transformers/commands/add_new_model_like.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# 1. Standard library
1516
import difflib
1617
import json
1718
import os
@@ -28,7 +29,12 @@
2829

2930
from ..models import auto as auto_module
3031
from ..models.auto.configuration_auto import model_type_to_module_name
31-
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
32+
from ..utils import (
33+
is_flax_available,
34+
is_tf_available,
35+
is_torch_available,
36+
logging,
37+
)
3238
from . import BaseTransformersCLICommand
3339
from .add_fast_image_processor import add_fast_image_processor
3440

@@ -1009,10 +1015,11 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
10091015
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
10101016
content = f.read()
10111017

1018+
pattern_tokenizer = re.compile(r"^\s*TOKENIZER_MAPPING_NAMES\s*=\s*OrderedDict\b")
10121019
lines = content.split("\n")
10131020
idx = 0
10141021
# First we get to the TOKENIZER_MAPPING_NAMES block.
1015-
while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
1022+
while not pattern_tokenizer.search(lines[idx]):
10161023
idx += 1
10171024
idx += 1
10181025

@@ -1024,9 +1031,12 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
10241031
# Otherwise it takes several lines until we get to a "),"
10251032
else:
10261033
block = []
1027-
while not lines[idx].startswith(" ),"):
1034+
# should change to " )," instead of " ),"
1035+
while not lines[idx].startswith(" ),"):
10281036
block.append(lines[idx])
10291037
idx += 1
1038+
# if the lines[idx] does start with " )," we still need it in our block
1039+
block.append(lines[idx])
10301040
block = "\n".join(block)
10311041
idx += 1
10321042

0 commit comments

Comments (0)