1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v
- LLaMa
- Mistral
- Qwen2
- T5

## Example usage

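With T5 added to the supported-architecture list, a GGUF-quantized T5 checkpoint can be loaded end to end. A minimal sketch of the usage this PR enables, using the flan-t5-small GGUF repo and file exercised in the tests further down (the repo/file names come from the test; the rest is an assumed but typical calling pattern, not text from the diff):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "repetitio/flan-t5-small"
gguf_file = "flan-t5-small-q8_0.gguf"

# The tokenizer is rebuilt from the GGUF metadata via GGUFT5Converter,
# the weights are dequantized and renamed via the new "t5" tensor mapping.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```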
128 changes: 125 additions & 3 deletions src/transformers/integrations/ggml.py
@@ -21,11 +21,11 @@
from array import array

import numpy as np
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram

from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter, T5Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -79,6 +79,51 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"t5": {
"token_embd": "shared",
"dec.blk.{bid}.attn_q": "decoder.block.{bid}.layer.0.SelfAttention.q",
"dec.blk.{bid}.attn_k": "decoder.block.{bid}.layer.0.SelfAttention.k",
"dec.blk.{bid}.attn_v": "decoder.block.{bid}.layer.0.SelfAttention.v",
"dec.blk.{bid}.attn_o": "decoder.block.{bid}.layer.0.SelfAttention.o",
"dec.blk.{bid}.attn_rel_b": "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"dec.blk.{bid}.attn_norm": "decoder.block.{bid}.layer.0.layer_norm",
"dec.blk.{bid}.cross_attn_q": "decoder.block.{bid}.layer.1.EncDecAttention.q",
"dec.blk.{bid}.cross_attn_k": "decoder.block.{bid}.layer.1.EncDecAttention.k",
"dec.blk.{bid}.cross_attn_v": "decoder.block.{bid}.layer.1.EncDecAttention.v",
"dec.blk.{bid}.cross_attn_o": "decoder.block.{bid}.layer.1.EncDecAttention.o",
"dec.blk.{bid}.cross_attn_norm": "decoder.block.{bid}.layer.1.layer_norm",
"dec.blk.{bid}.ffn_gate": "decoder.block.{bid}.layer.2.DenseReluDense.wi_0",
"dec.blk.{bid}.ffn_up": "decoder.block.{bid}.layer.2.DenseReluDense.wi_1",
"dec.blk.{bid}.ffn_down": "decoder.block.{bid}.layer.2.DenseReluDense.wo",
"dec.blk.{bid}.ffn_norm": "decoder.block.{bid}.layer.2.layer_norm",
"dec.output_norm": "decoder.final_layer_norm",
"enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q",
"enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k",
"enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v",
"enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o",
"enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm",
"enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0",
"enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
"enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo",
"enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm",
"enc.output_norm": "encoder.final_layer_norm",
"output.weight": "lm_head.weight",
},
"t5encoder": {
"token_embd": "shared",
"enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q",
"enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k",
"enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v",
"enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o",
"enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
"enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm",
"enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0",
"enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
"enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo",
"enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm",
"enc.output_norm": "encoder.final_layer_norm",
},
}


@@ -123,6 +168,19 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"t5": {
"context_length": "n_positions",
"block_count": "num_layers",
"feed_forward_length": "d_ff",
"embedding_length": "d_model",
"attention.key_length": "d_kv",
"attention.head_count": "num_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
"attention.relative_buckets_count": "relative_attention_num_buckets",
"decoder_start_token_id": "decoder_start_token_id",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
@@ -355,9 +413,73 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFT5Converter(T5Converter):
def __init__(self, tokenizer_dict):
# set dummy data to avoid unnecessary merges calculation
tokenizer_dict["merges"] = ["dummy text"]

self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.token2id = {k: v for v, k in enumerate(self.proto.tokens)}
self.original_tokenizer = self.proto
self.additional_kwargs = {}

def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))

def normalizer(self, proto):
if getattr(self.original_tokenizer, "legacy", True):
sequence = []
if getattr(self.original_tokenizer, "add_prefix_space", True):
sequence += [normalizers.Prepend(prepend="▁")]
sequence += [normalizers.Replace(pattern=" ", content="▁")]
return normalizers.Sequence(sequence)
return None # non-legacy, no normalizer

def post_processor(self):
return processors.TemplateProcessing(
single=["$A", "</s>"],
pair=["$A", "</s>", "$B", "</s>"],
special_tokens=[
("</s>", self.token2id["</s>"]),
],
)

def converted(self) -> Tokenizer:
vocab_scores = self.vocab(self.proto)
tokenizer = Tokenizer(
Unigram(
vocab_scores,
unk_id=self.proto.unk_token_id,
byte_fallback=False,
)
)

# Tokenizer assemble
normalizer = self.normalizer(self.proto)
if normalizer is not None:
tokenizer.normalizer = normalizer

replacement = "▁"
add_prefix_space = True
if hasattr(self.original_tokenizer, "add_prefix_space"):
add_prefix_space = self.original_tokenizer.add_prefix_space

pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
if pre_tokenizer is not None:
tokenizer.pre_tokenizer = pre_tokenizer

tokenizer.decoder = self.decoder(replacement, add_prefix_space)
post_processor = self.post_processor()
if post_processor:
tokenizer.post_processor = post_processor

return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"t5": GGUFT5Converter,
}


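For orientation, this is roughly how the converter registry above is consumed when a fast tokenizer is rebuilt from a GGUF file; the dispatch shape shown here is an assumption about the surrounding loading code, which is not part of this diff (only GGUFT5Converter and the GGUF_TO_FAST_CONVERTERS entry are):

```python
from transformers.integrations.ggml import GGUF_TO_FAST_CONVERTERS

# `tokenizer_dict` stands for the tokenizer fields parsed out of the GGUF metadata
# (tokens, scores, unk_token_id, ...); obtaining it is outside this sketch.
architecture = "t5"
converter_class = GGUF_TO_FAST_CONVERTERS[architecture]
fast_tokenizer = converter_class(tokenizer_dict).converted()  # Unigram-backed tokenizers.Tokenizer
```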
14 changes: 12 additions & 2 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -93,6 +93,10 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
# to add this patch to ensure things work correctly on our side.
if "llama" in architecture and "mistral" in model_name:
updated_architecture = "mistral"
elif "t5" in architecture or "t5encoder" in architecture:
parsed_parameters["config"]["tie_word_embeddings"] = False
parsed_parameters["config"]["is_gated_act"] = True
updated_architecture = "t5"
else:
updated_architecture = architecture
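The two forced overrides above exist because the flan-style T5 checkpoints this targets keep a separate LM head (output.weight, mapped to lm_head.weight in the tensor table) and use gated feed-forward blocks (ffn_gate/ffn_up mapped to wi_0/wi_1). A rough sketch of the config dict this produces for flan-t5-small; the numeric values are illustrative defaults, not read from the diff:

```python
# Illustrative only: GGUF metadata keys renamed through GGUF_CONFIG_MAPPING["t5"],
# plus the two forced overrides from the patch above.
parsed_parameters["config"] = {
    "d_model": 512,                # from "embedding_length"
    "num_layers": 8,               # from "block_count"
    "num_heads": 6,                # from "attention.head_count"
    "tie_word_embeddings": False,  # lm_head is stored as its own tensor in the GGUF file
    "is_gated_act": True,          # gated feed-forward (wi_0 / wi_1)
}
```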

@@ -166,9 +170,15 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
elif ".attn_k." in name:
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

bid = None
if architecture in ("t5", "t5encoder"):
for chunk in name.split("."):
if chunk.isdigit():
bid = int(chunk)
break
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
if tensor_name.format(bid=bid) in name:
name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid))

# Use copy to avoid errors with numpy and pytorch
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
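The change to the mapping loop is that template keys containing {bid} are instantiated with the block index parsed from the GGUF tensor name before the substring match; keys without a placeholder are unaffected by .format(). A standalone re-illustration of that logic for one decoder tensor (not additional code from the PR):

```python
name = "dec.blk.3.attn_q.weight"
tensor_key_mapping = {"dec.blk.{bid}.attn_q": "decoder.block.{bid}.layer.0.SelfAttention.q"}

# The first numeric chunk of the dotted GGUF name is the block index.
bid = next((int(chunk) for chunk in name.split(".") if chunk.isdigit()), None)

for tensor_name in tensor_key_mapping:
    if tensor_name.format(bid=bid) in name:
        name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid))

print(name)  # decoder.block.3.layer.0.SelfAttention.q.weight
```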
2 changes: 1 addition & 1 deletion src/transformers/models/t5/tokenization_t5_fast.py
@@ -117,7 +117,7 @@ def __init__(
kwargs["from_slow"] = True

super().__init__(
vocab_file,
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
eos_token=eos_token,
unk_token=unk_token,
18 changes: 17 additions & 1 deletion tests/quantization/ggml/test_ggml.py
@@ -15,7 +15,7 @@
import tempfile
import unittest

from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available

@@ -35,6 +35,7 @@ class GgufIntegrationTests(unittest.TestCase):
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
t5_model_id = "repetitio/flan-t5-small"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -61,6 +62,7 @@ class GgufIntegrationTests(unittest.TestCase):
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
q8_0_t5_model_id = "flan-t5-small-q8_0.gguf"

example_text = "Hello"

@@ -340,6 +342,20 @@ def test_llama3_q4_0(self):
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_t5_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.q8_0_t5_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
self.t5_model_id, gguf_file=self.q8_0_t5_model_id, device_map="auto", torch_dtype=torch.float16
)

T5_EXAMPLE_TEXT = "translate English to German: How old are you?"

text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Wie ich er?"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset