Commit 3ef8896

Encoder-Decoder Gemma (#38332)
* Initial submit
* Fix bugs:
  1. add __init__ file
  2. tied word embedding
  3. support flash/flex attention
  4. model saving and loading
* Code refactor:
* Rename encdecgemma to t5gemma.
* Split attention into self- and cross-attention
* Split stack into encoder and decoder
* Add test cases
* Add auto configuration
* Update configurations.
* Fix bugs related to copy and attribute checks
* Fix type union
* Fix merge errors
* run ruff format
* Run make style and update tests.
* Add t5gemma model doc.
* ruff and style formatting.
* Add missed module config.
* Add dummy checkpoint link to pass tests (needs updating when real checkpoints are uploaded).
* Update model doc.
* Minor updates following Arthur's comments:
* replace docstrings with auto_docstrings
* remove checkpoint layers
* remove deprecate_kwargs
* fix rebase errors
* Fix docstring issues.
* fix t5gemma doc issue.
* run ruff format
* Updates:
* split encoder-only model out
* make t5gemmamodel encoder-decoder only
* update token and sequence classification
* update tests
1 parent af98702 commit 3ef8896

12 files changed: +5148 -0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -655,6 +655,8 @@
   title: SwitchTransformers
 - local: model_doc/t5
   title: T5
+- local: model_doc/t5gemma
+  title: T5Gemma
 - local: model_doc/t5v1.1
   title: T5v1.1
 - local: model_doc/tapex
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# T5Gemma

T5Gemma (aka encoder-decoder Gemma) was proposed in a [research paper](https://arxiv.org/abs/2504.06225) by Google. It is a family of encoder-decoder large language models, developed by adapting pretrained decoder-only models into encoder-decoder models. T5Gemma includes pretrained and instruction-tuned variants. The architecture follows the transformer encoder-decoder design of T5, with improvements from Gemma 2: GQA, RoPE, GeGLU activation, RMSNorm, and interleaved local/global attention.

T5Gemma has two groups of model sizes: 1) [Gemma 2](https://ai.google.dev/gemma/docs/core/model_card_2) sizes (2B-2B, 9B-2B, and 9B-9B), which are based on the official Gemma 2 models (2B and 9B); and 2) [T5](https://arxiv.org/abs/1910.10683) sizes (Small, Base, Large, and XL), which are pretrained under the Gemma 2 framework following the T5 configuration. In addition, we also provide a model at ML size (medium large, ~2B in total), which is in between T5 Large and T5 XL.

The pretrained variants are trained with two objectives: prefix language modeling with knowledge distillation (PrefixLM) and UL2, separately. We release both variants for each model size. The instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning.

The examples below demonstrate how to prompt the model with [`Pipeline`], the [`AutoModel`] class, and the command line.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text2text-generation",
    model="google/t5gemma-placeholder",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("Question: Why is the sky blue?\nAnswer:", max_new_tokens=50)
```

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-placeholder")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-placeholder",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

input_text = "Question: Why is the sky blue?\nAnswer:"
# Move the inputs to wherever device_map="auto" placed the model instead of hard-coding "cuda".
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```
echo -e "Question: Why is the sky blue? Answer:" | transformers run --task text2text-generation --model google/t5gemma-placeholder --device 0
```

</hfoption>
</hfoptions>

## T5GemmaConfig

[[autodoc]] T5GemmaConfig

## T5GemmaModuleConfig

[[autodoc]] T5GemmaModuleConfig

## T5GemmaModel

[[autodoc]] T5GemmaModel
    - forward

## T5GemmaEncoderModel

[[autodoc]] T5GemmaEncoderModel
    - forward

## T5GemmaForConditionalGeneration

[[autodoc]] T5GemmaForConditionalGeneration
    - forward

## T5GemmaForSequenceClassification

[[autodoc]] T5GemmaForSequenceClassification
    - forward

## T5GemmaForTokenClassification

[[autodoc]] T5GemmaForTokenClassification
    - forward

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -294,6 +294,7 @@
 from .swinv2 import *
 from .switch_transformers import *
 from .t5 import *
+from .t5gemma import *
 from .table_transformer import *
 from .tapas import *
 from .textnet import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -333,6 +333,7 @@
         ("swinv2", "Swinv2Config"),
         ("switch_transformers", "SwitchTransformersConfig"),
         ("t5", "T5Config"),
+        ("t5gemma", "T5GemmaConfig"),
         ("table-transformer", "TableTransformerConfig"),
         ("tapas", "TapasConfig"),
         ("textnet", "TextNetConfig"),
@@ -721,6 +722,7 @@
         ("swinv2", "Swin Transformer V2"),
         ("switch_transformers", "SwitchTransformers"),
         ("t5", "T5"),
+        ("t5gemma", "T5Gemma"),
         ("t5v1.1", "T5v1.1"),
         ("table-transformer", "Table Transformer"),
         ("tapas", "TAPAS"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 7 additions & 0 deletions
@@ -310,6 +310,7 @@
         ("swinv2", "Swinv2Model"),
         ("switch_transformers", "SwitchTransformersModel"),
         ("t5", "T5Model"),
+        ("t5gemma", "T5GemmaModel"),
         ("table-transformer", "TableTransformerModel"),
         ("tapas", "TapasModel"),
         ("textnet", "TextNetModel"),
@@ -430,6 +431,7 @@
         ("squeezebert", "SqueezeBertForMaskedLM"),
         ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
         ("t5", "T5ForConditionalGeneration"),
+        ("t5gemma", "T5GemmaForConditionalGeneration"),
         ("tapas", "TapasForMaskedLM"),
         ("transfo-xl", "TransfoXLLMHeadModel"),
         ("tvlt", "TvltForPreTraining"),
@@ -524,6 +526,7 @@
         ("squeezebert", "SqueezeBertForMaskedLM"),
         ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
         ("t5", "T5ForConditionalGeneration"),
+        ("t5gemma", "T5GemmaForConditionalGeneration"),
         ("tapas", "TapasForMaskedLM"),
         ("transfo-xl", "TransfoXLLMHeadModel"),
         ("wav2vec2", "Wav2Vec2ForMaskedLM"),
@@ -1044,6 +1047,7 @@
         ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
         ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
         ("t5", "T5ForConditionalGeneration"),
+        ("t5gemma", "T5GemmaForConditionalGeneration"),
         ("umt5", "UMT5ForConditionalGeneration"),
         ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
     ]
@@ -1156,6 +1160,7 @@
         ("stablelm", "StableLmForSequenceClassification"),
         ("starcoder2", "Starcoder2ForSequenceClassification"),
         ("t5", "T5ForSequenceClassification"),
+        ("t5gemma", "T5GemmaForSequenceClassification"),
         ("tapas", "TapasForSequenceClassification"),
         ("transfo-xl", "TransfoXLForSequenceClassification"),
         ("umt5", "UMT5ForSequenceClassification"),
@@ -1349,6 +1354,7 @@
         ("stablelm", "StableLmForTokenClassification"),
         ("starcoder2", "Starcoder2ForTokenClassification"),
         ("t5", "T5ForTokenClassification"),
+        ("t5gemma", "T5GemmaForTokenClassification"),
         ("umt5", "UMT5ForTokenClassification"),
         ("xlm", "XLMForTokenClassification"),
         ("xlm-roberta", "XLMRobertaForTokenClassification"),
@@ -1582,6 +1588,7 @@
         ("roformer", "RoFormerModel"),
         ("squeezebert", "SqueezeBertModel"),
         ("t5", "T5EncoderModel"),
+        ("t5gemma", "T5GemmaEncoderModel"),
         ("umt5", "UMT5EncoderModel"),
         ("xlm", "XLMModel"),
         ("xlm-roberta", "XLMRobertaModel"),
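
Each hunk in this file adds a `("t5gemma", ...)` pair to a different name table, so that a model type string can be resolved to a class name per task. A minimal sketch of that lookup, assuming plain ordered dictionaries: the table contents are copied from the hunks, while `resolve_class_name` and the shortened table names are hypothetical and are not the library's actual resolution code.

```python
from collections import OrderedDict

# Simplified stand-ins for three of the tables touched in this file.
# Only the entries visible in the diff are reproduced here.
BASE_MODEL_NAMES = OrderedDict([("t5", "T5Model"), ("t5gemma", "T5GemmaModel")])
SEQ2SEQ_LM_NAMES = OrderedDict(
    [("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration")]
)
ENCODER_ONLY_NAMES = OrderedDict([("t5", "T5EncoderModel"), ("t5gemma", "T5GemmaEncoderModel")])


def resolve_class_name(model_type: str, table: "OrderedDict[str, str]") -> str:
    """Hypothetical helper: map a config's model type to a class name for one task."""
    try:
        return table[model_type]
    except KeyError:
        raise ValueError(f"Unrecognized model type {model_type!r} for this task") from None


print(resolve_class_name("t5gemma", SEQ2SEQ_LM_NAMES))
```

Without the new entries, the same call would raise for `"t5gemma"` in this sketch, which is why the commit registers the model in every task table it supports.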

src/transformers/models/auto/tokenization_auto.py

Lines changed: 7 additions & 0 deletions
@@ -582,6 +582,13 @@
                     "T5TokenizerFast" if is_tokenizers_available() else None,
                 ),
             ),
+            (
+                "t5gemma",
+                (
+                    "GemmaTokenizer" if is_sentencepiece_available() else None,
+                    "GemmaTokenizerFast" if is_tokenizers_available() else None,
+                ),
+            ),
             ("tapas", ("TapasTokenizer", None)),
             ("tapex", ("TapexTokenizer", None)),
             ("transfo-xl", ("TransfoXLTokenizer", None)),
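
The new entry reuses the Gemma tokenizers and records a (slow, fast) pair in which either slot becomes `None` when its backend is not installed. A minimal sketch of that pattern, assuming the availability checks are passed in as booleans: `tokenizer_entry` and its parameters are hypothetical, standing in for `is_sentencepiece_available()` and `is_tokenizers_available()`.

```python
from typing import Optional, Tuple


def tokenizer_entry(
    model_type: str,
    slow_name: str,
    fast_name: str,
    has_sentencepiece: bool,
    has_tokenizers: bool,
) -> Tuple[str, Tuple[Optional[str], Optional[str]]]:
    """Hypothetical helper: build one (model_type, (slow, fast)) table entry."""
    slow = slow_name if has_sentencepiece else None  # slow tokenizer needs sentencepiece
    fast = fast_name if has_tokenizers else None  # fast tokenizer needs tokenizers
    return (model_type, (slow, fast))


# Both backends installed, then only the fast backend installed.
print(tokenizer_entry("t5gemma", "GemmaTokenizer", "GemmaTokenizerFast", True, True))
print(tokenizer_entry("t5gemma", "GemmaTokenizer", "GemmaTokenizerFast", False, True))
```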
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_t5gemma import *
    from .modeling_t5gemma import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
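
At runtime the package swaps itself in `sys.modules` for a lazy module, so the configuration and modeling code is only imported when one of its names is first accessed. A minimal sketch of that idea, assuming a plain mapping from module name to exported names: `LazyModule` below is a hypothetical, much simplified stand-in and does not reproduce the behavior of the real `_LazyModule`. Standard-library modules are used as targets so the example runs anywhere.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Hypothetical stand-in: import the defining module on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the module that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }
        self.__all__ = list(self._attr_to_module)

    def __getattr__(self, attr):
        # Called only when normal lookup fails, i.e. the name is not loaded yet.
        attr_to_module = self.__dict__.get("_attr_to_module", {})
        if attr not in attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(attr_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


lazy = LazyModule("demo_lazy", {"math": ["sqrt"], "json": ["dumps"]})
print(lazy.sqrt(9.0))  # resolves math.sqrt on first access, then caches it
```

The cached `setattr` means each name pays the import cost at most once in this sketch.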
