# See the License for the specific language governing permissions and
# limitations under the License.

+# 1. Standard library
import difflib
import json
import os
@@ -28,7 +29,12 @@

from ..models import auto as auto_module
from ..models.auto.configuration_auto import model_type_to_module_name
-from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
+from ..utils import (
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+    logging,
+)
from . import BaseTransformersCLICommand
from .add_fast_image_processor import add_fast_image_processor

@@ -1009,10 +1015,11 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
    with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
        content = f.read()

+    pattern_tokenizer = re.compile(r"^\s*TOKENIZER_MAPPING_NAMES\s*=\s*OrderedDict\b")
    lines = content.split("\n")
    idx = 0
    # First we get to the TOKENIZER_MAPPING_NAMES block.
-    while not lines[idx].startswith("    TOKENIZER_MAPPING_NAMES = OrderedDict("):
+    while not pattern_tokenizer.search(lines[idx]):
        idx += 1
    idx += 1

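A quick standalone check of the new matcher (illustrative, not part of the diff): the compiled pattern finds the TOKENIZER_MAPPING_NAMES assignment at any indentation and with any spacing around "=", whereas the old hardcoded startswith() prefix matched only one exact layout.

import re

pattern_tokenizer = re.compile(r"^\s*TOKENIZER_MAPPING_NAMES\s*=\s*OrderedDict\b")

for line in [
    "TOKENIZER_MAPPING_NAMES = OrderedDict(",      # module level: matches
    "    TOKENIZER_MAPPING_NAMES = OrderedDict(",  # indented: matches
    "# TOKENIZER_MAPPING_NAMES is built below",    # comment: no match
]:
    print(bool(pattern_tokenizer.search(line)), line)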
@@ -1024,9 +1031,12 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
        # Otherwise it takes several lines until we get to a "),"
        else:
            block = []
-            while not lines[idx].startswith("        ),"):
+            # Entries now close with "    )," (four-space indent), so match that.
+            while not lines[idx].startswith("    ),"):
                block.append(lines[idx])
                idx += 1
+            # The closing "    )," line itself also belongs to the block.
+            block.append(lines[idx])
            block = "\n".join(block)
        idx += 1
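A minimal self-contained sketch of the fixed collection loop, using illustrative entry data rather than the real contents of tokenization_auto.py: it gathers a multi-line entry up to and including its closing "    )," line, which the old loop silently dropped.

lines = [
    "    (",
    '        "albert",',
    '        ("AlbertTokenizer", "AlbertTokenizerFast"),',
    "    ),",
    "    (",  # first line of the next entry
]

idx = 0
block = []
while not lines[idx].startswith("    ),"):
    block.append(lines[idx])
    idx += 1
# The closing line is the current lines[idx]; keep it in the block.
block.append(lines[idx])
block = "\n".join(block)
idx += 1

print(block)  # the full entry, closing "    )," included
print(idx)    # 4 -> index of the next entry's first line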