Encoder-Decoder Gemma #38332
Conversation
Hey! Thanks for the PR! It would be nice to split encoder and decoder. The decoder pretty much already exists as Gemma3, no? So it's simpler to write already!
Can you also add the tests! 🤗
super().__init__(config, device)

class EncdecGemma2Attention(nn.Module):
let's split cross and self please!
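For illustration, a minimal sketch of that split (class names assumed, heavily simplified: no multi-head projections, masking, RoPE or caching). The two modules differ only in where the keys and values come from:

```python
import torch
import torch.nn as nn


class SelfAttentionSketch(nn.Module):
    """Queries, keys and values all come from the same sequence."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        return torch.softmax(q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5, dim=-1) @ v


class CrossAttentionSketch(nn.Module):
    """Queries come from the decoder, keys/values from the encoder output."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.q_proj(hidden_states)
        k, v = self.k_proj(encoder_hidden_states), self.v_proj(encoder_hidden_states)
        return torch.softmax(q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5, dim=-1) @ v
```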
return attn_output, attn_weights

def make_sliding_mask(
Can you use create_sliding_window_mask, potentially with an or_mask for bidirectional? 🤗
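To make the suggestion concrete, here is a tiny standalone illustration in plain PyTorch (not the library's masking API): a bidirectional sliding window is just the OR of the causal sliding-window predicate with its mirror image, which is exactly the kind of thing an `or_mask` hook can express.

```python
import torch


def causal_sliding_window(q_idx, kv_idx, window):
    # Causal sliding window: a query may attend to keys in (q - window, q].
    return (kv_idx <= q_idx) & (q_idx - kv_idx < window)


def bidirectional_sliding_window(q_idx, kv_idx, window):
    # OR the causal predicate with its mirror to get a symmetric window,
    # i.e. |q - kv| < window, which is what bidirectional encoder
    # self-attention with a sliding window needs.
    return causal_sliding_window(q_idx, kv_idx, window) | causal_sliding_window(kv_idx, q_idx, window)


q = torch.arange(6).unsqueeze(-1)   # query positions as a column
kv = torch.arange(6).unsqueeze(0)   # key positions as a row
print(bidirectional_sliding_window(q, kv, window=3).int())
```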
if self.is_decoder:
    # cross attention
    self.cross_attn = EncdecGemma2Attention(
        config=config,
        layer_idx=layer_idx,
        is_cross_attention=True,
    )
    self.pre_cross_attn_layernorm = EncdecGemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_cross_attn_layernorm = EncdecGemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
we should split encoder and decoder layers!
# setup sliding window for self-attention mask
if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
    # In prefill, we may be larger than sliding window
    effective_seq_len = max(cache_position.shape[0], self.sliding_window)
    # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
    # thus we must slice from the right (at most `effective_seq_len` elements)
    if self.config._attn_implementation == "flash_attention_2":
        attention_mask = attention_mask[:, -effective_seq_len:]
    # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
    # from the left, with an offset if we are beyond the sliding window
    else:
        attention_mask = make_sliding_mask(
            attention_mask,
            self.sliding_window,
            # Decoder self-attention: causal attention
            # Encoder self-attention: bidirectional attention
            bidirectional=not self.is_decoder,
        )
        # In case we are beyond the sliding window, we need to correctly offset the mask slicing
        offset = cache_position[-1] - effective_seq_len + 1
        # Should only be used when beyond the sliding window (i.e. offset > 0)
        offset = torch.clamp(offset, min=0)
        # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
        # but without data-dependent slicing (i.e. torch.compile friendly)
        mask_indexes = torch.arange(
            min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
        )
        mask_indexes += offset
        attention_mask = attention_mask[:, :, :, mask_indexes]
not needed if you leverage the new causal mask API
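For reference, a hedged sketch of what "leverage the new causal mask API" could look like, assuming the helpers in `transformers.masking_utils` and reusing the kwarg pattern quoted further down in this thread (toy config, not the PR's actual code): build the masks once at model level instead of re-slicing inside every layer.

```python
import torch
from transformers import Gemma2Config
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

# Toy config just to make the sketch self-contained.
config = Gemma2Config(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    num_key_value_heads=1, intermediate_size=64, sliding_window=4,
)
config._attn_implementation = "eager"

batch, seq_len = 1, 6
mask_kwargs = {
    "config": config,
    "input_embeds": torch.zeros(batch, seq_len, config.hidden_size),
    "attention_mask": torch.ones(batch, seq_len, dtype=torch.long),
    "cache_position": torch.arange(seq_len),
    "past_key_values": None,
}

# One mask per attention type; layers then just pick the one they need.
masks = {
    "full_attention": create_causal_mask(**mask_kwargs),
    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
print({name: None if m is None else tuple(m.shape) for name, m in masks.items()})
```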
return shifted_input_ids

class EncdecGemma2Stack(EncdecGemma2PreTrainedModel):
not a fan of stacks, let's split encoder and decoder! I know t5 has that but it's not a good precedent!
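A rough sketch of the shape this could take (assumed names, heavily simplified; embeddings, norms, masks, caching and heads omitted): a dedicated encoder whose layers use bidirectional self-attention only, and a dedicated decoder whose layers own the cross-attention, rather than one Stack class switching on `is_decoder`.

```python
import torch.nn as nn


class EncoderSketch(nn.Module):
    def __init__(self, layers):
        super().__init__()
        # Layers with bidirectional self-attention only.
        self.layers = nn.ModuleList(layers)

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


class DecoderSketch(nn.Module):
    def __init__(self, layers):
        super().__init__()
        # Layers with causal self-attention plus cross-attention over the encoder output.
        self.layers = nn.ModuleList(layers)

    def forward(self, hidden_states, encoder_hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states, encoder_hidden_states)
        return hidden_states
```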
if output_hidden_states:
    all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
for convenience use GradientCheckpointingLayer!
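For context, a minimal sketch of the suggestion (the layer body is a placeholder): subclassing `GradientCheckpointingLayer` lets the base class reroute the layer call through `torch.utils.checkpoint` during training once gradient checkpointing is enabled on the parent model, so the explicit `if self.gradient_checkpointing and self.training:` branch in the stack's forward can go away.

```python
import torch
import torch.nn as nn
from transformers.modeling_layers import GradientCheckpointingLayer


class ToyLayer(GradientCheckpointingLayer):
    # Placeholder body; a real layer would hold attention, MLP and norms.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.mlp(hidden_states)


layer = ToyLayer(8)
# Without checkpointing enabled this behaves like a plain nn.Module call; when the
# parent model calls gradient_checkpointing_enable(), the call is checkpointed in training.
print(layer(torch.randn(2, 8)).shape)
```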
@ArthurZucker thanks for your suggestions! I made several updates to this PR. Please take another look!
@ArthurZucker I just marked this PR as ready for review. Could you please take another look?
Kudos it is very very nice! 🤗
Last 2 nits!
mask_kwargs = {
    "config": self.config,
    "input_embeds": encoder_hidden_states,
    "attention_mask": encoder_attention_mask,
    "cache_position": cache_position,
    "past_key_values": None,
}
The mask kwargs are almost the same, we can probably just re-use them?
Sorry, you mean the same as the self-attention mask kwargs? Perhaps better not to reuse, because out of the 5 kwargs, only 2 are the same; the other 3 are different.
ah okay okay no worries
@auto_docstring
class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
this class and the next I would rather add upon request!
Since this model is T5Gemma and has the encoder-decoder architecture, researchers would naturally expect it to follow the T5 usage. It would be weird if classification tasks were not supported.
Ok makes sense
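For what it's worth, the expected usage would then mirror T5-style classification. A hedged sketch (the checkpoint name is taken from the training command further down in this thread; the classification head would be freshly initialized):

```python
import torch
from transformers import AutoTokenizer, T5GemmaForSequenceClassification

checkpoint = "google/t5gemma-s-s-ul2"  # assumption: same checkpoint as used later in this thread
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = T5GemmaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

inputs = tokenizer("This encoder-decoder model works nicely.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.argmax(dim=-1))
```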
* Rename encdecgemma to t5gemma.
* Split attention into self- and cross-attention
* Split stack into encoder and decoder
* Add test cases
* Add auto configuration
…points are uploaded.).
* replace docstrings with auto_docstrings
* remove checkpoint layers
* remove deprecate_kwargs
Thanks a lot for bearing with me!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The Token Classification implementation was not tested end-to-end:

[INFO|trainer.py:4352] 2025-07-12 12:27:59,071 >> ***** Running Evaluation *****
[INFO|trainer.py:4354] 2025-07-12 12:27:59,071 >> Num examples = 3250
[INFO|trainer.py:4357] 2025-07-12 12:27:59,071 >> Batch size = 8
Traceback (most recent call last):
File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 653, in <module>
main()
File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 584, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 2206, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 2656, in _inner_training_loop
self._maybe_log_save_evaluate(
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 3095, in _maybe_log_save_evaluate
metrics = self._evaluate(trial, ignore_keys_for_eval)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 3044, in _evaluate
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 4198, in evaluate
output = eval_loop(
^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 4488, in evaluation_loop
metrics = self.compute_metrics(
^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 535, in compute_metrics
predictions = np.argmax(predictions, axis=2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 1359, in argmax
return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 54, in _wrapfunc
return _wrapit(obj, method, *args, **kwds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 42, in _wrapit
conv = _array_converter(obj)
^^^^^^^^^^^^^^^^^^^^^
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.
Tested with latest:

BATCH_SIZE=16
LR=5e-05
EPOCHS=20
SEED=1
python3 run_ner.py \
--model_name_or_path google/t5gemma-s-s-ul2 \
--dataset_name conll2003 \
--output_dir ./t5gemma-bs${BATCH_SIZE}-lr${LR}-e${EPOCHS}-${SEED} \
--eval_strategy epoch \
--save_strategy epoch \
--per_device_train_batch_size ${BATCH_SIZE} \
--learning_rate ${LR} \
--num_train_epochs ${EPOCHS} \
--load_best_model_at_end=True \
--bf16 \
--do_train \
--do_eval
Hey, thanks for raising this! There seems to be some problem with eval, as the model returns hidden states as well. A temporary fix is to disable them in modeling_t5gemma.py.
We'll make an update later.
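As an illustration only (an assumption on my side, not the exact patch referred to above), an equivalent user-side workaround is to make sure `compute_metrics` only ever sees the logits, via the Trainer's `preprocess_logits_for_metrics` hook:

```python
def preprocess_logits_for_metrics(logits, labels):
    # Drop extra outputs such as hidden states so stacking predictions stays homogeneous.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits

# trainer = Trainer(..., preprocess_logits_for_metrics=preprocess_logits_for_metrics)
```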
What does this PR do?
Add support for encoder-decoder Gemma (https://arxiv.org/abs/2504.06225)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.