Skip to content

Commit 5ca82fc

Browse files
authored
convert : workaround for AutoConfig dummy labels (#13881)
1 parent 6385b84, commit 5ca82fc

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

convert_hf_to_gguf.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3690,14 +3690,20 @@ def __init__(self, *args, **kwargs):
36903690
super().__init__(*args, **kwargs)
36913691
self.vocab_size = None
36923692

3693+
if cls_out_labels := self.hparams.get("id2label"):
3694+
if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0":
3695+
# Remove dummy labels added by AutoConfig
3696+
cls_out_labels = None
3697+
self.cls_out_labels = cls_out_labels
3698+
36933699
def set_gguf_parameters(self):
36943700
super().set_gguf_parameters()
36953701
self.gguf_writer.add_causal_attention(False)
36963702
self._try_set_pooling_type()
36973703

3698-
if cls_out_labels := self.hparams.get("id2label"):
3704+
if self.cls_out_labels:
36993705
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
3700-
self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
3706+
self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
37013707

37023708
def set_vocab(self):
37033709
tokens, toktypes, tokpre = self.get_vocab_base()
@@ -3749,7 +3755,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
37493755
if name.startswith("cls.seq_relationship"):
37503756
return []
37513757

3752-
if self.hparams.get("id2label"):
3758+
if self.cls_out_labels:
37533759
# For BertForSequenceClassification (direct projection layer)
37543760
if name == "classifier.weight":
37553761
name = "classifier.out_proj.weight"

0 commit comments

Comments
 (0)