Commit cdd1889

convert : add support for XLMRoberta embedding models (#8658)
* add conversion for bge-m3; small fix in unigram tokenizer
* clean up and simplify XLMRoberta conversion
1 parent c21a896 commit cdd1889

File tree

convert_hf_to_gguf.py
src/llama-vocab.cpp

2 files changed: +110 -1 lines changed
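
The main addition below hangs off the `@Model.register("XLMRobertaModel")` decorator: convert_hf_to_gguf.py dispatches on the architecture name in the checkpoint's config.json, which is how a local bge-m3 download ends up handled by the new class. A minimal sketch of that registration pattern, with hypothetical names rather than the script's actual internals:

# Minimal sketch of an architecture registry like the one the converter uses.
# Class and function names here are hypothetical, not convert_hf_to_gguf.py's own.
import json
from pathlib import Path

_registry: dict[str, type] = {}

def register(*architectures: str):
    """Class decorator mapping HF architecture names to converter classes."""
    def wrap(cls: type) -> type:
        for arch in architectures:
            _registry[arch] = cls
        return cls
    return wrap

@register("XLMRobertaModel")
class XLMRobertaConverter:
    pass

def converter_for(model_dir: Path) -> type:
    # config.json lists the model class, e.g. ["XLMRobertaModel"] for bge-m3
    config = json.loads((model_dir / "config.json").read_text())
    return _registry[config["architectures"][0]]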

convert_hf_to_gguf.py

Lines changed: 106 additions & 0 deletions
@@ -2506,6 +2506,112 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])


+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        # to avoid TypeError: Descriptors cannot be created directly
+        # exception when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        # realign tokens (see HF tokenizer code)
+        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+        toktypes = [
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.UNKNOWN,
+        ] + toktypes[3:-1]
+
+        self.gguf_writer.add_tokenizer_model("t5")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+        self.gguf_writer.add_token_type_count(1)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+        if precompiled_charsmap:
+            self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:,:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA
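
For intuition on the two non-obvious steps in the diff above (the position-embedding chop and the vocab realignment, both artifacts of how HF implements XLM-RoBERTa), here is a small standalone sketch. The pad_token_id, tensor shape, and toy vocab are illustrative assumptions, not values read from a real bge-m3 checkpoint.

# Toy illustration of the position-embedding chop and vocab realignment.
# Values (pad_token_id=1, 514 position rows, hidden=4) are assumptions.
import numpy as np

pad_token_id = 1
position_offset = 1 + pad_token_id          # first rows are never used as positions

# HF stores 514 position rows, but real positions start at pad_token_id + 1,
# so the converter keeps only the trailing 512 rows.
position_embd = np.arange(514 * 4, dtype=np.float32).reshape(514, 4)
chopped = position_embd[position_offset:, :]
assert chopped.shape == (512, 4)            # max_position_embeddings - offset

# The SentencePiece model puts <unk>, <s>, </s> at ids 0..2; the HF XLM-R
# tokenizer instead uses <s>, <pad>, </s>, <unk> at ids 0..3 and shifts the rest,
# which is what the tokens[3:-1] slice above reproduces (toy 5-piece vocab here).
sp_tokens = [b"<unk>", b"<s>", b"</s>", b"\xe2\x96\x81the", b"."]
hf_tokens = [b"<s>", b"<pad>", b"</s>", b"<unk>"] + sp_tokens[3:-1]
print(hf_tokens)    # [b'<s>', b'<pad>', b'</s>', b'<unk>', b'\xe2\x96\x81the']

The net effect is that the written GGUF keeps only the usable position rows and its token ids line up with what the HF XLM-R tokenizer produces.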

src/llama-vocab.cpp

Lines changed: 4 additions & 1 deletion
@@ -816,6 +816,9 @@ struct llm_tokenizer_ugm {
      * the best tokenization.
      */
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        // get current size of output (for reversal later)
+        size_t output_size = output.size();
+
         // normalize the input first
         std::string normalized;
         normalize(text, &normalized);
@@ -895,7 +898,7 @@ struct llm_tokenizer_ugm {
         }

         // reverse the output since we added tokens starting from the end of the input
-        std::reverse(output.begin(), output.end());
+        std::reverse(output.begin() + output_size, output.end());
     }

 private:
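
The one-line change matters when tokenize() is called with a non-empty output vector (e.g. when the caller has already pushed a BOS id or tokens from an earlier text fragment): reversing the whole vector would also flip those earlier tokens. A toy Python illustration of the indexing fix, not the C++ implementation itself, with made-up token ids:

# Toy illustration of why only the newly appended suffix is reversed.
def tokenize_into(output: list[int], new_tokens_back_to_front: list[int]) -> None:
    output_size = len(output)                 # remember where the caller's tokens end
    output.extend(new_tokens_back_to_front)   # unigram search emits tokens end-to-start
    # old (buggy) behavior: output.reverse() would also flip the existing prefix
    output[output_size:] = reversed(output[output_size:])  # reverse only our suffix

out = [1]                          # e.g. a BOS token pushed by the caller
tokenize_into(out, [30, 20, 10])   # tokens discovered back-to-front
print(out)                         # [1, 10, 20, 30] -- prefix left untouched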
