fix(testing): Fix MoonshineEncoder UnboundLocalError and Florence2VisionBackbone dtype mismatch#44503

Merged
Rocketknight1 merged 1 commit into huggingface:main from harshaljanjani:fix/moonshine-florence2-runtime-errors
Mar 9, 2026
Conversation

@harshaljanjani
Contributor

@harshaljanjani harshaljanjani commented Mar 6, 2026

What does this PR do?

The following failing tests were identified and fixed in this PR:

Moonshine: In MoonshineEncoder.forward, output_attention_mask is only assigned inside the if attention_mask is not None branch, but create_bidirectional_mask, called immediately afterwards with attention_mask=None, can still return a non-None tensor. Under torch.compile(dynamic=True), is_tracing() returns True inside _ignore_bidirectional_mask_sdpa, so not is_tracing(padding_mask) is False and _ignore_bidirectional_mask_sdpa returns False; the allow_is_bidirectional_skip early exit in sdpa_mask therefore never fires and a full batched mask tensor is materialized (shown in the trace). At return time, the guard attention_mask is not None then evaluates to True even though no padding mask was provided, while output_attention_mask was never set. test_sdpa_can_compile_dynamic hits exactly this case because it drops all attention masks before calling the model, causing an UnboundLocalError; this change fixes that.

Here are the trace outputs that led to the explanation above:

→ MoonshineEncoder.forward:   input attention_mask=None
→ MoonshineEncoder.forward:   attention_mask is None; output_attention_mask NOT assigned
→ MoonshineEncoder.forward:   calling create_bidirectional_mask with attention_mask=None
→ _preprocess_mask_arguments: early_exit=False, returning attention_mask=None, q_length=1, kv_length=1
→ create_bidirectional_mask:  mask_interface=sdpa_mask
→ create_bidirectional_mask:  calling mask_interface with allow_is_bidirectional_skip=True, attention_mask=None
→ create_bidirectional_mask:  mask_interface returned type=Tensor, shape=torch.Size([3, 1, 1, 1])
→ MoonshineEncoder.forward:   guard 'attention_mask is not None' = True
→ MoonshineEncoder.forward:   about to return, evaluating guard now
>> **UnboundLocalError**
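The failure mode above can be distilled into a small stand-in (a hypothetical simplification for illustration; forward_buggy, forward_fixed, and create_bidirectional_mask_stub are not the actual transformers code):

```python
import torch


def create_bidirectional_mask_stub(attention_mask, batch=3):
    # Stand-in for create_bidirectional_mask: under torch.compile(dynamic=True)
    # the is-bidirectional skip path is disabled, so a real tensor comes back
    # even when the incoming attention_mask is None.
    return torch.ones(batch, 1, 1, 1)


def forward_buggy(attention_mask=None):
    if attention_mask is not None:
        output_attention_mask = attention_mask  # only assigned on this branch
    attention_mask = create_bidirectional_mask_stub(attention_mask)
    if attention_mask is not None:  # now always truthy: a tensor was returned
        return output_attention_mask  # UnboundLocalError when input was None


def forward_fixed(attention_mask=None):
    output_attention_mask = attention_mask  # assigned unconditionally up front
    attention_mask = create_bidirectional_mask_stub(attention_mask)
    return output_attention_mask  # None when no padding mask was provided


try:
    forward_buggy(None)
except UnboundLocalError as e:
    print("buggy path:", e)

print("fixed path returns:", forward_fixed(None))
```

The fix is simply to make the assignment unconditional (or guard the return on the original input mask rather than the rebuilt one), so the compiled dynamic path no longer depends on which branch ran.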

Florence2Vision: test_sdpa_can_compile_dynamic loads the model in bfloat16, but Florence2VisionBackbone.forward received float32 activations from upstream processing while its conv weights (self.convs[0].conv.weight) were bfloat16. Casting the activations to the weight dtype fixes this :)
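A minimal sketch of the dtype mismatch and the cast, using a standalone Conv2d as a hypothetical stand-in for the backbone's conv stem (not the actual Florence2VisionBackbone code):

```python
import torch
import torch.nn as nn

# Model weights loaded in bfloat16, as the test does.
conv = nn.Conv2d(3, 8, kernel_size=3).to(torch.bfloat16)

# Upstream processing hands the backbone float32 activations.
x = torch.randn(1, 3, 16, 16)  # dtype: torch.float32

# Without a cast, conv(x) raises a RuntimeError because input and weight
# dtypes differ. Casting to the module's weight dtype resolves it:
x = x.to(conv.weight.dtype)
out = conv(x)
print(out.dtype)  # torch.bfloat16
```

Casting to `conv.weight.dtype` (rather than hard-coding bfloat16) keeps the fix correct regardless of which dtype the model was loaded in.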

CI Failures: (screenshot attached)

Before the fix (feel free to cross-check; these errors are reproducible): (screenshot attached)

After the fix (feel free to cross-check): (screenshot attached)
cc: @Rocketknight1

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you fix any necessary existing tests?

@github-actions
Contributor

github-actions bot commented Mar 6, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: florence2, moonshine

@harshaljanjani harshaljanjani marked this pull request as ready for review March 6, 2026 17:17
@harshaljanjani harshaljanjani changed the title fix(models): Fix MoonshineEncoder UnboundLocalError and Florence2VisionBackbone dtype mismatch fix(testing): Fix MoonshineEncoder UnboundLocalError and Florence2VisionBackbone dtype mismatch Mar 6, 2026
Member

@Rocketknight1 Rocketknight1 left a comment


Yes, LGTM!

@Rocketknight1 Rocketknight1 enabled auto-merge (squash) March 9, 2026 16:22
@Rocketknight1 Rocketknight1 merged commit 1a50a3b into huggingface:main Mar 9, 2026
21 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@harshaljanjani harshaljanjani deleted the fix/moonshine-florence2-runtime-errors branch March 9, 2026 18:06
