[tests] Parameterized test_eager_matches_sdpa_inference
#36650
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the "Ready for review" button.
```python
self.skipTest(reason="Model does not support output_attentions")
```

```python
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
```
most of the diff in the loop is indentation :)
there are a few changes, to avoid overwriting this large test and, more importantly, to pressure us into standardizing model interfaces (see musicgen notes in this loop)
```python
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
# models have a custom mixin, which we detect to skip this test.
if not any(".ModelTesterMixin" in str(base) for base in self.__class__.__bases__):
    self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
```
the differences here already existed -- this test was being overwritten using the same names.
With the parameterization we get new test names, hence this extra skip
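A minimal, self-contained sketch of how that `__bases__` check distinguishes the two mixins (the CLIP-style mixin name and the module paths below are invented for illustration; only `ModelTesterMixin` matches the real one):

```python
# Sketch of the skip condition: `str(base)` renders as
# "<class 'module.ClassName'>", so the leading "." in the needle ensures a
# custom mixin like "CLIPModelTesterMixin" does NOT match.
class ModelTesterMixin:
    pass

class CLIPModelTesterMixin:  # hypothetical CLIP-style custom mixin
    pass

# set module paths explicitly so the repr is deterministic in this sketch
ModelTesterMixin.__module__ = "tests.test_modeling_common"
CLIPModelTesterMixin.__module__ = "tests.models.clip.test_modeling_clip"

def uses_common_mixin(test_cls):
    # mirrors: any(".ModelTesterMixin" in str(base) for base in self.__class__.__bases__)
    return any(".ModelTesterMixin" in str(base) for base in test_cls.__bases__)

class BertLikeTests(ModelTesterMixin):
    pass

class CLIPLikeTests(CLIPModelTesterMixin):
    pass

print(uses_common_mixin(BertLikeTests))   # → True
print(uses_common_mixin(CLIPLikeTests))   # → False
```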
Ah, I didn't know changing this would make the new test names run ... I trust you on this being needed here.
```diff
-if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
+# convert shorthand name to torch.dtype
+if torch_dtype == "fp16":
```
float16 -> fp16 for a shorter test name :)
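For context, a tiny sketch of the shorthand mapping this refers to (the dict and function name below are illustrative; the real test resolves the string to an actual `torch.dtype`):

```python
# Shorthand dtype names keep the parameterized test names short; inside the
# test they are mapped back to real dtypes, e.g. getattr(torch, "float16").
SHORTHAND_TO_DTYPE = {"fp16": "float16", "fp32": "float32", "bf16": "bfloat16"}

def resolve_dtype_name(shorthand):
    return SHORTHAND_TO_DTYPE[shorthand]

print(resolve_dtype_name("fp16"))  # → float16
```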
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 1bea90c to db015c4 (compare)
```python
# TODO: standardize the interfaces for musicgen models, see other todo in this test
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
```
do you mean we need this because the new generate tests have new names?
maybe I misunderstand: the change here is to avoid this test being overwritten for musicgen models?
No :) For this specific model, which behaves differently, we have the two usual options:
- overwrite the test as usual
- add this exception
I went with the second to pressure me into having a second look at musicgen, which has many interface issues (same argument names, different meaning and expected shapes) 👀 These interface issues, in turn, are causing me issues whenever we fix generate, tests, ...
I can do a normal test overwrite if you think it's preferable.
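To illustrate the trade-off being discussed (the function and field names below are invented for illustration), the in-test exception pattern looks roughly like this, as opposed to a full overwrite in the model's own test file:

```python
# Hypothetical sketch: one shared test keeps a small, explicit branch for the
# divergent model instead of a full test overwrite in the model's test file.
def prepare_shared_test_inputs(model_class_name, batch_size=7):
    inputs = {"batch_size": batch_size, "interface_workaround": False}
    # TODO-style exception: musicgen melody models use the same argument
    # names with different meanings/expected shapes (per the discussion above)
    if model_class_name == "MusicgenMelodyForConditionalGeneration":
        inputs["interface_workaround"] = True
    return inputs
```

The branch keeps the exception visible in the shared test, which is what "pressures" a later interface standardization, whereas an overwrite would hide it in the model's test file.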
It's ok, I have no strong opinion, just wondering: with this if ... else ..., do musicgen models still run this test successfully without the overwrite? If so, that is nice.
Yes, they pass the test 🙌
ydshieh
left a comment
I only have a nit question
https://github.com/huggingface/transformers/pull/36650/files#r1993288313
I love the idea to make the test name clear in the summary.
I trust you without taking a look myself at the changes in the different test files.
One more nit: for the name in the test, we can make it even shorter by …
What does this PR do?
Problem
`test_eager_matches_sdpa_inference` on `main` is running many nested test cases (48!). This is not only a bad practice but also slows down our workflow: when it breaks, we need to parse the test outputs to see which configuration(s) broke. An example:

Which expands into:

[screenshot of the expanded test output]
Did it crash because of a specific parameterization? Or because we are running many subtests in one test? Answering those questions is not straightforward atm.
Fix
This PR replaces the test's `for` loops with `@parameterized.expand`, making sure the test name immediately identifies the test case. In the process, I've noticed many skips/overwrites are no longer needed. The test is still super ugly, and I've left a few TODOs for whenever we decide to touch the test again.
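The naming mechanics can be sketched with the standard library only (the real test uses `@parameterized.expand` from the `parameterized` package; the parameter values below are illustrative, not the actual grid):

```python
from itertools import product

# Instead of nesting `for` loops inside one opaque test, the parameter grid is
# expanded up front and each combination gets its own test name, so a CI
# failure points directly at the failing configuration.
DTYPES = ["fp16", "fp32", "bf16"]   # shorthand names keep test names short
PADDING_SIDES = ["left", "right"]

def expanded_test_names(base="test_eager_matches_sdpa_inference"):
    return [
        f"{base}_{idx}_{dtype}_padding_{side}"
        for idx, (dtype, side) in enumerate(product(DTYPES, PADDING_SIDES))
    ]

print(expanded_test_names()[0])   # → test_eager_matches_sdpa_inference_0_fp16_padding_left
print(len(expanded_test_names()))  # → 6
```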
An example -- what's being tested is now clear, minimal difference in run time ✨
