Skip to content

Commit 5c68832

Browse files
authored
[GPT OSS] Fix false flag (#43120)
fix
1 parent 68dcd13 commit 5c68832

File tree

2 files changed

+1
-2
lines changed

2 files changed

+1
-2
lines changed

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ class GptOssPreTrainedModel(PreTrainedModel):
434 434       _skip_keys_device_placement = ["past_key_values"]
435 435       _supports_flash_attn = True
436 436       _supports_sdpa = False
437     -     _supports_flex_attn = False
    437 +     _supports_flex_attn = True
438 438
439 439       _can_compile_fullgraph = True
440 440       _supports_attention_backend = True

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,6 @@ def forward(
354 354   class GptOssPreTrainedModel(LlamaPreTrainedModel):
355 355       _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
356 356       _supports_sdpa = False
357     -     _supports_flex_attn = False
358 357       _can_record_outputs = {
359 358           "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
360 359           "hidden_states": GptOssDecoderLayer,

0 commit comments

Comments
 (0)