4 changes: 3 additions & 1 deletion src/transformers/models/gemma2/configuration_gemma2.py
@@ -85,6 +85,7 @@ class Gemma2Config(PretrainedConfig):
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.

```python
>>> from transformers import Gemma2Model, Gemma2Config
@@ -98,7 +99,6 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
cache_implementation = "hybrid"
@gante (Contributor, Author), Sep 24, 2024:
TIL: attributes defined at the class level are not present in __dict__.

Not present in __dict__ -> config.to_dict() doesn't contain cache_implementation -> the generation_config is not initialized with cache_implementation -> tests fail because they expect a HybridCache
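A minimal sketch of the point above (plain Python, hypothetical class names): class-level attributes live on the class object, not in an instance's __dict__, so any serialization that walks __dict__ silently drops them.

```python
# Hypothetical classes for illustration only.
class ConfigWithClassAttr:
    cache_implementation = "hybrid"  # defined on the class


class ConfigWithInstanceAttr:
    def __init__(self, cache_implementation="hybrid"):
        self.cache_implementation = cache_implementation  # defined on the instance


print(ConfigWithClassAttr().__dict__)     # {} -> a __dict__-based to_dict() misses the attribute
print(ConfigWithInstanceAttr().__dict__)  # {'cache_implementation': 'hybrid'}
```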


def __init__(
self,
@@ -125,6 +125,7 @@ def __init__(
sliding_window=4096,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
cache_implementation="hybrid",
**kwargs,
):
super().__init__(
@@ -153,3 +154,4 @@ def __init__(
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
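With the attribute now set in __init__, it should appear in the serialized config and therefore in the generation config that generate derives from it. A quick sanity check, assuming a transformers version that includes this patch:

```python
from transformers import Gemma2Config

config = Gemma2Config()
# cache_implementation is an instance attribute now, so to_dict() keeps it
assert config.to_dict()["cache_implementation"] == "hybrid"
```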
4 changes: 3 additions & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -117,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.

```python
>>> from transformers import Gemma2Model, Gemma2Config
@@ -130,7 +131,6 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
cache_implementation = "hybrid"

def __init__(
self,
@@ -157,6 +157,7 @@ def __init__(
sliding_window=4096,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
cache_implementation="hybrid",
**kwargs,
):
super().__init__(
@@ -185,6 +186,7 @@ def __init__(
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation


class Gemma2RMSNorm(GemmaRMSNorm):
4 changes: 3 additions & 1 deletion utils/check_config_attributes.py
@@ -44,7 +44,9 @@
"Qwen2Config": ["use_sliding_window"],
"Qwen2MoeConfig": ["use_sliding_window"],
"Qwen2VLConfig": ["use_sliding_window"],
"Gemma2Config": ["tie_word_embeddings"],
# `cache_implementation` should be in the default generation config, but we don't yet support per-model
# generation configs (TODO joao)
"Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`