Merged
Changes from 5 commits
83 changes: 81 additions & 2 deletions src/transformers/convert_slow_tokenizer.py
@@ -20,7 +20,8 @@

import warnings
from typing import Optional

from tqdm import tqdm
from functools import lru_cache
from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
@@ -1691,6 +1692,81 @@ def converted(self) -> Tokenizer:

return tokenizer

class MistralConverter:
def __init__(
self,
vocab_file=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
additional_special_tokens=None,
**kwargs,
):
self.vocab_file = vocab_file
self.pattern = pattern
self.add_prefix_space = add_prefix_space
self.additional_special_tokens = (
additional_special_tokens.keys()
if isinstance(additional_special_tokens, dict)
else additional_special_tokens
)

def extract_vocab_merges_from_model(self, tiktoken_url: str):
import json
import base64

with open(self.vocab_file, "r", encoding="utf-8") as f:
untyped = json.load(f)
self.pattern = untyped['config']['pattern']
self.additional_special_tokens = [AddedToken(k['token_str'], special=k["is_control"]) for k in untyped['special_tokens']]
bpe_ranks = untyped["vocab"]
byte_encoder = bytes_to_unicode()

@lru_cache
def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

merges = []
vocab = {}
for idx, token in enumerate(self.additional_special_tokens):
vocab[token.content] = idx
bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
rank_set = set(bpe_ranks)
for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
vocab[token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
local = []
for index in range(1, len(token)):
piece_l, piece_r = token[:index], token[index:]
if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
local.append((piece_l, piece_r, rank))
local = sorted(local, key=lambda x: (bpe_ranks.index(x[0]), bpe_ranks.index(x[1])), reverse=False)
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=False)
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
return vocab, merges

def tokenizer(self):
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
if hasattr(tokenizer.model, "ignore_merges"):
tokenizer.model.ignore_merges = True
return tokenizer

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer()
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
]
)
tokenizer.decoder = decoders.ByteLevel()

tokenizer.add_tokens(self.additional_special_tokens)
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

return tokenizer

SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
@@ -1771,7 +1847,10 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(transformer_tokenizer).converted()

elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
transformer_tokenizer.original_tokenizer = transformer_tokenizer
logger.info("Converting from Mistral tekken.json")
return MistralConverter(transformer_tokenizer.vocab_file).converted()
else:
try:
logger.info("Converting from Tiktoken")
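For reference, the new MistralConverter reads a tekken.json file, rebuilds the BPE vocab and merges from the base64-encoded token bytes, and returns a `tokenizers` Tokenizer. A minimal usage sketch, assuming this branch of transformers is installed; the tekken.json path is illustrative only:

from transformers.convert_slow_tokenizer import MistralConverter

# Convert a local tekken.json into a fast tokenizer and save it as tokenizer.json.
fast_tokenizer = MistralConverter(vocab_file="path/to/tekken.json").converted()
fast_tokenizer.save("tokenizer.json")
print(fast_tokenizer.encode("Hello Mistral!").tokens)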
4 changes: 2 additions & 2 deletions src/transformers/models/auto/tokenization_auto.py
@@ -750,8 +750,8 @@
(
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
"PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
Member
Usually the pattern is (slow, fast); here it's (fast, fast), not sure if that's intended.
Maybe it should instead be:

(None, "MistralCommonTokenizer" if is_mistral_common_available() else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None))

so that we never have a slow one anyway?

),
),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
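For context on the comment above: entries in TOKENIZER_MAPPING_NAMES are (slow class, fast class) pairs. A sketch of what the reviewer's suggestion would look like, with the model key left as a placeholder since it is not visible in this hunk:

(
    "<model_type>",  # placeholder: the actual mapping key is not shown in this hunk
    (
        None,  # never register a slow tokenizer
        "MistralCommonTokenizer"
        if is_mistral_common_available()
        else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
    ),
),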
27 changes: 25 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -31,6 +31,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union, overload
from huggingface_hub import list_repo_files

import numpy as np
from packaging import version
@@ -150,7 +151,7 @@ def __str__(self):

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json")
_re_tokenizer_file = re.compile(r"(tokenizer|tekken)\.(.*)\.json")


class TruncationStrategy(ExplicitEnum):
@@ -2098,7 +2099,13 @@ def from_pretrained(
template = template.removesuffix(".jinja")
vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"

# Get files from url, cache, or disk depending on the case
remote_files = list_repo_files(pretrained_model_name_or_path)
if not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
# Mistral tokenizer file names are different, but we can still convert them if
# mistral-common is not installed
other_pattern = "tekken.json|tokenizer.model.*"
vocab_files["vocab_file"] = re.search(other_pattern, "".join(remote_files)).group()

resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
@@ -2417,6 +2424,22 @@ def _from_pretrained(
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
" fine-tuned or trained."
)
if tokenizer.vocab_size > 100000 and getattr(tokenizer.backend_tokenizer, "pre_tokenizer", None) is not None:
from huggingface_hub import model_info
def is_base_mistral(model_id: str) -> bool:
model = model_info(model_id)
if model.tags is not None:
if re.search("base_model:.*mistralai", "".join(model.tags)):
Contributor
Hmm, so that's only for the mistral org, no? Should we directly check `model_type in ["mistral", ...]` so that it also works for other orgs?

Collaborator Author
We can't do that until we download the config / config is there

return True
return False

if is_base_mistral(pretrained_model_name_or_path) and not kwargs.get("fix_regex"):
logger.warning(
f"The tokenizer you are loading from '{pretrained_model_name_or_path}' was saved"
f" with an old regex pattern. This would lead to incorrect tokenization; the"
f" pre-tokenizer's regex is being updated."
)
import tokenizers
tokenizer.backend_tokenizer.pre_tokenizer[0] = tokenizers.pre_tokenizers.Split(pattern=tokenizers.Regex(r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"), behavior="isolated")
return tokenizer

@staticmethod
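The from_pretrained changes above fall back to Mistral-style file names when no tokenizer.json is found on the Hub. A standalone sketch of that lookup, assuming huggingface_hub is installed; the repo id is illustrative only:

import re

from huggingface_hub import list_repo_files

repo_id = "mistralai/Mistral-Nemo-Base-2407"  # illustrative repo id
remote_files = list_repo_files(repo_id)
match = re.search(r"tekken\.json|tokenizer\.model.*", "".join(remote_files))
print(match.group() if match else "no Mistral-style vocab file found")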
2 changes: 2 additions & 0 deletions src/transformers/tokenization_utils_fast.py
@@ -207,6 +207,8 @@ def __init__(self, *args, **kwargs):
if tokens:
self.add_tokens(tokens)

# There was an issue with Mistral models where the pre_tokenizer's regex pattern
# was not correct. Here we try to fix it.
try:
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
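The comment added here refers to the same regex problem patched in tokenization_utils_base.py. To check which Split pattern a loaded fast tokenizer actually carries, the serialized pre_tokenizer can be inspected the same way the try block does; a sketch, assuming a tokenizers-backed tokenizer has already been loaded as `tok`:

import json

# `tok` is assumed to be an already-loaded PreTrainedTokenizerFast instance.
state = json.loads(tok.backend_tokenizer.pre_tokenizer.__getstate__())
steps = state.get("pretokenizers", [state])  # a Sequence pre_tokenizer nests its steps
for step in steps:
    if step.get("type") == "Split":
        print(step.get("pattern"))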