@@ -2506,6 +2506,112 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])


+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
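+        # (fairseq-style models number positions from pad_token_id + 1; for
+        # XLM-R's usual pad_token_id = 1 the offset is 2, e.g. turning a
+        # 514-row position_embd matrix into 512 usable positions)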
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        # to avoid the "TypeError: Descriptors cannot be created directly"
+        # exception when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
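+        # (these normalizer settings are exported to GGUF below so the runtime
+        # tokenizer can reproduce SentencePiece's text normalization)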
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
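+        # the HF vocab may be larger than the spm model (e.g. an added <mask>
+        # token), so the [PAD{n}] placeholders created below fill any extra slots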
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        # realign tokens (see HF tokenizer code)
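+        # the spm model stores <unk>/<s>/</s> at ids 0-2, while the HF
+        # tokenizer expects <s>/<pad>/</s>/<unk> at ids 0-3 with the remaining
+        # pieces shifted up by one; slicing [3:-1] keeps the list at vocab_size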
+        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+        toktypes = [
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.UNKNOWN,
+        ] + toktypes[3:-1]
+
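+        # "t5" selects llama.cpp's Unigram tokenizer implementation, matching
+        # the Unigram model_type asserted above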
+        self.gguf_writer.add_tokenizer_model("t5")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+        self.gguf_writer.add_token_type_count(1)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+        if precompiled_charsmap:
+            self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:, :]
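+                # (mirrors the max_position_embeddings reduction in __init__)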
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA