Encoder-Decoder Gemma #38332

Merged: 20 commits into huggingface:main, Jun 25, 2025

Conversation

bzhangGo
Contributor

What does this PR do?

Add support for encoder-decoder Gemma (https://arxiv.org/abs/2504.06225)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Collaborator

@ArthurZucker left a comment

Hey! Thanks for the PR! It would be nice to split encoder and decoder. The decoder pretty much already exists as Gemma3, no? So it should be simpler to write!
Can you also add the tests! 🤗

super().__init__(config, device)


class EncdecGemma2Attention(nn.Module):
Collaborator

let's split cross- and self-attention please!

return attn_output, attn_weights


def make_sliding_mask(
Collaborator

can you use create_sliding_window_mask, potentially with an or_mask for bidirectional? 🤗

Comment on lines 556 to 564
if self.is_decoder:
    # cross attention
    self.cross_attn = EncdecGemma2Attention(
        config=config,
        layer_idx=layer_idx,
        is_cross_attention=True,
    )
    self.pre_cross_attn_layernorm = EncdecGemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_cross_attn_layernorm = EncdecGemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Collaborator

we should split encoder and decoder layers!
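For orientation, a minimal structural sketch of the split being requested, built from plain torch.nn modules (all class names and shapes below are illustrative, not the merged T5Gemma code): an encoder layer that only does bidirectional self-attention, and a decoder layer that owns a dedicated cross-attention module instead of one attention class switched by an is_cross_attention flag.

import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    """Encoder layer: bidirectional self-attention only (padding masks omitted for brevity)."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        return self.norm(hidden_states + attn_out)


class ToyDecoderLayer(nn.Module):
    """Decoder layer: causal self-attention plus a separate cross-attention block."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.pre_cross_attn_norm = nn.LayerNorm(hidden_size)
        self.post_cross_attn_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # Causal additive mask: -inf above the diagonal, 0 elsewhere.
        seq_len = hidden_states.shape[1]
        causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
        attn_out, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=causal)
        hidden_states = hidden_states + attn_out
        # Queries come from the decoder, keys/values from the encoder output.
        cross_out, _ = self.cross_attn(
            self.pre_cross_attn_norm(hidden_states), encoder_hidden_states, encoder_hidden_states
        )
        return hidden_states + self.post_cross_attn_norm(cross_out)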

Comment on lines 591 to 619
# setup sliding window for self-attention mask
if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
    # In prefill, we may be larger than sliding window
    effective_seq_len = max(cache_position.shape[0], self.sliding_window)
    # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
    # thus we must slice from the right (at most `effective_seq_len` elements)
    if self.config._attn_implementation == "flash_attention_2":
        attention_mask = attention_mask[:, -effective_seq_len:]
    # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
    # from the left, with an offset if we are beyond the sliding window
    else:
        attention_mask = make_sliding_mask(
            attention_mask,
            self.sliding_window,
            # Decoder self-attention: causal attention
            # Encoder self-attention: bidirectional attention
            bidirectional=not self.is_decoder,
        )
        # In case we are beyond the sliding window, we need to correctly offset the mask slicing
        offset = cache_position[-1] - effective_seq_len + 1
        # Should only be used when beyond the sliding window (i.e. offset > 0)
        offset = torch.clamp(offset, min=0)
        # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
        # but without data-dependent slicing (i.e. torch.compile friendly)
        mask_indexes = torch.arange(
            min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
        )
        mask_indexes += offset
        attention_mask = attention_mask[:, :, :, mask_indexes]
Collaborator

not needed if you leverage the new causal mask API
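As an aside, a toy illustration of the masking behaviour being discussed, written in plain PyTorch rather than the transformers causal-mask API the reviewer refers to: the same sliding window is applied causally for decoder self-attention and symmetrically (bidirectionally) for encoder self-attention.

import torch


def toy_sliding_window_mask(seq_len: int, window: int, bidirectional: bool) -> torch.Tensor:
    # True means "may attend", False means "masked out".
    positions = torch.arange(seq_len)
    distance = positions[None, :] - positions[:, None]  # key index minus query index
    if bidirectional:
        # Encoder self-attention: attend to any key within `window` positions on either side.
        return distance.abs() < window
    # Decoder self-attention: causal, attend only to the last `window` keys (including itself).
    return (distance <= 0) & (distance > -window)


print(toy_sliding_window_mask(5, 2, bidirectional=False).int())  # banded lower-triangular
print(toy_sliding_window_mask(5, 2, bidirectional=True).int())   # symmetric band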

return shifted_input_ids


class EncdecGemma2Stack(EncdecGemma2PreTrainedModel):
Collaborator

not a fan of stacks, let's split encoder and decoder! I know t5 has that but it's not a good precedent!

if output_hidden_states:
    all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
Collaborator

for convenience use GradientCheckpointingLayer!
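For readers unfamiliar with the helper, a minimal sketch of what this suggestion buys, assuming the GradientCheckpointingLayer base class from recent transformers releases (the import path below is my assumption, and the layer name is hypothetical): a layer that subclasses it checkpoints its own forward during training once gradient_checkpointing_enable() is called, so the stack can drop its manual gradient_checkpointing branch.

import torch.nn as nn

# Assumed import path; adjust to wherever the helper lives in your transformers version.
from transformers.modeling_layers import GradientCheckpointingLayer


class ToyLayer(GradientCheckpointingLayer):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.mlp(hidden_states)


# The encoder/decoder forward then just calls the layer, with no explicit
# torch.utils.checkpoint plumbing:
#     for layer in self.layers:
#         hidden_states = layer(hidden_states)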

@bzhangGo
Contributor Author

bzhangGo commented Jun 6, 2025

Hey! Thanks for the PR! It would be nice to split encoder and decoder. The decoder pretty much already exists as Gemma3, no? So it should be simpler to write! Can you also add the tests! 🤗

@ArthurZucker thanks for your suggestions! I made several updates to this PR. Please take another look!

@bzhangGo marked this pull request as ready for review June 18, 2025 00:17
@bzhangGo
Contributor Author

@ArthurZucker I just marked this PR as ready for review. Could you please take another look?

Collaborator

@ArthurZucker left a comment

Kudos it is very very nice! 🤗

@bzhangGo force-pushed the encdecgemma2 branch 2 times, most recently from f5bfafe to 6605345 on June 23, 2025 17:28
Collaborator

@ArthurZucker left a comment

Last 2 nits!

Comment on lines +833 to +839
mask_kwargs = {
    "config": self.config,
    "input_embeds": encoder_hidden_states,
    "attention_mask": encoder_attention_mask,
    "cache_position": cache_position,
    "past_key_values": None,
}
Collaborator

the mask kwargs are almost the same, we can probably just re-use?

Contributor Author

Sorry, you mean the same as the self-attention mask kwargs? Perhaps better not to reuse, because out of the 5 kwargs only 2 are the same; the other 3 are different.

Collaborator

ah okay okay no worries



@auto_docstring
class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
Collaborator

this class and the next I would rather add upon request!

Contributor Author

Since this model is T5Gemma and has an encoder-decoder architecture, researchers would naturally expect it to follow the T5 usage. It would be odd if classification tasks were not supported.

Collaborator

Ok makes sense
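For context, a hedged usage sketch of what that expectation looks like in practice. Everything below is an assumption on my part rather than something this PR documents: the top-level export of T5GemmaForSequenceClassification, num_labels=2, whether decoder inputs are derived from input_ids as in T5ForSequenceClassification, and the checkpoint id (taken from the token-classification report later in this thread).

from transformers import AutoTokenizer, T5GemmaForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-s-s-ul2")
model = T5GemmaForSequenceClassification.from_pretrained("google/t5gemma-s-s-ul2", num_labels=2)

inputs = tokenizer("Encoder-decoder Gemma is now in transformers.", return_tensors="pt")
logits = model(**inputs).logits  # expected shape: (batch_size, num_labels)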

bzhangGo added 18 commits June 24, 2025 13:04
1. add __init__ file
2. tied word embedding
3. support flash/flex attention
4. model saving and loading
* Rename encdecgemma to t5gemma.
* Split attention into self- and cross-attention
* Split stack into encoder and decoder
* Add test cases
* Add auto configuration
* replace docstrings with auto_docstrings
* remove checkpoint layers
* remove deprecate_kwargs
bzhangGo added 2 commits June 24, 2025 13:04
* split encoder-only model out
* make t5gemmamodel encoder-decoder only
* update token and sequence classification
* update tests
Collaborator

@ArthurZucker left a comment

Thanks a lot for bearing with me!

@ArthurZucker enabled auto-merge (squash) June 25, 2025 08:52
@ArthurZucker merged commit 3ef8896 into huggingface:main Jun 25, 2025
20 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bzhangGo deleted the encdecgemma2 branch June 30, 2025 13:04
@stefan-it
Collaborator

The Token Classification implementation was not tested End-to-End:

[INFO|trainer.py:4352] 2025-07-12 12:27:59,071 >> 
***** Running Evaluation *****
[INFO|trainer.py:4354] 2025-07-12 12:27:59,071 >>   Num examples = 3250
[INFO|trainer.py:4357] 2025-07-12 12:27:59,071 >>   Batch size = 8
Traceback (most recent call last):
  File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 653, in <module>
    main()
  File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 584, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 2206, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 2656, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 3095, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 3044, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 4198, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/src/transformers/trainer.py", line 4488, in evaluation_loop
    metrics = self.compute_metrics(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/Repositories/transformers-t5gemma/examples/pytorch/token-classification/run_ner.py", line 535, in compute_metrics
    predictions = np.argmax(predictions, axis=2)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 1359, in argmax
    return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 54, in _wrapfunc
    return _wrapit(obj, method, *args, **kwds)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stefan/venvs/t5gemma/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 42, in _wrapit
    conv = _array_converter(obj)
           ^^^^^^^^^^^^^^^^^^^^^
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

Tested with the latest 4.54.0.dev0 using the PyTorch token classification example via:

BATCH_SIZE=16
LR=5e-05
EPOCHS=20
SEED=1

python3 run_ner.py \
  --model_name_or_path google/t5gemma-s-s-ul2 \
  --dataset_name conll2003 \
  --output_dir ./t5gemma-bs${BATCH_SIZE}-lr${LR}-e${EPOCHS}-${SEED} \
  --eval_strategy epoch \
  --save_strategy epoch \
  --per_device_train_batch_size ${BATCH_SIZE} \
  --learning_rate ${LR} \
  --num_train_epochs ${EPOCHS} \
  --load_best_model_at_end=True \
  --bf16 \
  --do_train \
  --do_eval

@bzhangGo
Contributor Author

Hey, thanks for raising this! There seems to be a problem with eval, as the model returns hidden states as well.

A temporary fix is to disable them in modeling_t5gemma.py, like:

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None, #hidden_states,
            attentions=None, #attentions,
        )

We'll make an update later.
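For anyone hitting this before a fix lands, a user-side workaround sketch (my suggestion, not the authors' fix above): the Trainer accepts a preprocess_logits_for_metrics callable, so the extra tensors can be dropped before predictions are gathered for compute_metrics.

def keep_only_logits(logits, labels):
    # When the model output arrives here as a tuple, the first element is the
    # token-classification logits; hidden states and attentions are dropped.
    return logits[0] if isinstance(logits, tuple) else logits

# Passed where run_ner.py constructs its Trainer (other arguments unchanged):
# trainer = Trainer(..., compute_metrics=compute_metrics,
#                   preprocess_logits_for_metrics=keep_only_logits)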
