@@ -3690,14 +3690,20 @@ def __init__(self, *args, **kwargs):
3690
3690
super ().__init__ (* args , ** kwargs )
3691
3691
self .vocab_size = None
3692
3692
3693
+ if cls_out_labels := self .hparams .get ("id2label" ):
3694
+ if len (cls_out_labels ) == 2 and cls_out_labels [0 ] == "LABEL_0" :
3695
+ # Remove dummy labels added by AutoConfig
3696
+ cls_out_labels = None
3697
+ self .cls_out_labels = cls_out_labels
3698
+
3693
3699
def set_gguf_parameters(self):
    """Add the encoder-specific GGUF metadata on top of the parent's parameters.

    After delegating to the parent implementation, this disables causal
    attention, calls ``_try_set_pooling_type()`` and, when
    ``self.cls_out_labels`` is non-empty, adds the classifier output labels
    as an array under the architecture-specific key.
    """
    super().set_gguf_parameters()
    self.gguf_writer.add_causal_attention(False)
    self._try_set_pooling_type()

    if not self.cls_out_labels:
        # No classifier labels were kept for this model; nothing more to add.
        return

    arch_name = gguf.MODEL_ARCH_NAMES[self.model_arch]
    labels_key = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch=arch_name)
    # Emit the label values ordered by their id2label key.
    ordered_items = sorted(self.cls_out_labels.items())
    self.gguf_writer.add_array(labels_key, [label for _, label in ordered_items])
3701
3707
3702
3708
def set_vocab (self ):
3703
3709
tokens , toktypes , tokpre = self .get_vocab_base ()
@@ -3749,7 +3755,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
3749
3755
if name .startswith ("cls.seq_relationship" ):
3750
3756
return []
3751
3757
3752
- if self .hparams . get ( "id2label" ) :
3758
+ if self .cls_out_labels :
3753
3759
# For BertForSequenceClassification (direct projection layer)
3754
3760
if name == "classifier.weight" :
3755
3761
name = "classifier.out_proj.weight"
0 commit comments