Add Op(aten::_scaled_dot_product_efficient_attention) | feat(torchlib) #1197
Conversation
Codecov Report

@@            Coverage Diff             @@
##             main    #1197      +/-   ##
==========================================
- Coverage   78.66%   78.60%   -0.07%
==========================================
  Files         118      118
  Lines       15445    15473      +28
  Branches     2428     2431       +3
==========================================
+ Hits        12150    12162      +12
- Misses       2897     2915      +18
+ Partials      398      396       -2
==========================================

View full report in Codecov by Sentry.
@justinchuby I tried to run the tests on dev10, but they appear to be pinned to CPU. Could you point me in the right direction for running them on GPU?
Maybe try changing
del kwargs

make = opinfo_core.partial(
    opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
You may also control device here
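For context, the `device` argument threaded into `make_tensor` is what places the sample tensors. A minimal sketch of the idea, using `functools.partial` and `torch.testing.make_tensor` directly (the `opinfo_core` names in the diff alias these); hard-coding `"cuda"` here would pin the samples to GPU for CUDA-only ops, while `"cpu"` is used below only so the sketch runs anywhere:

```python
import functools

import torch
from torch.testing import make_tensor

# The `device` forwarded here decides where the sample inputs live.
# Replacing it with a hard-coded "cuda" would force GPU samples.
make = functools.partial(
    make_tensor, device="cpu", dtype=torch.float32, requires_grad=False
)
sample = make((2, 3, 4))
print(sample.shape, sample.device)
```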
I enabled the full set of CUDA tests wherever a test needs them.
nn_ops.aten__scaled_dot_product_efficient_attention,
trace_only=True,
tolerance={torch.float32: (3e-4, 1.5e-5)},
# Output[0] is OK, but other outputs just have the same shape with zero values
Use the other compare option instead
compare_shape_only_for_output
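A minimal sketch of the idea behind a shape-only comparison (the helper name and signature here are hypothetical, not the actual torchlib test option): outputs whose values are meaningless placeholders, such as RNG seed/offset tensors filled with zeros, are checked for shape only, while the primary output is still compared by value.

```python
import torch

def assert_outputs_close(actual, expected, shape_only=()):
    # Hypothetical helper: for output indices listed in `shape_only`,
    # compare only the tensor shapes; compare values for everything else.
    for i, (a, e) in enumerate(zip(actual, expected)):
        if i in shape_only:
            assert a.shape == e.shape, f"output {i}: shape mismatch"
        else:
            torch.testing.assert_close(a, e)

attn = torch.ones(2, 4)
actual = (attn, torch.zeros(2))          # output 1 carries dummy zero values
expected = (attn.clone(), torch.randn(2))
assert_outputs_close(actual, expected, shape_only={1})
print("ok")
```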
Done
@justinchuby You might want to review again. I changed the function implementation and enabled CUDA tests.
LGTM. Thanks!
Fixes #1160
Follow-up to #1043: this adds another scaled_dot_product_XXX_attention variant.
It only supports CUDA.
https://github.com/pytorch/pytorch/blob/38ae17d166a001ef6837553d1ddffa111624df27/torch/_meta_registrations.py#L5195-L5236
NOTE: This PR also enables CUDA tests.
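For reference, the math the op implements is ordinary scaled dot-product attention, softmax(QK^T / sqrt(d)) @ V; the efficient-attention variant computes the same primary output with a memory-efficient CUDA kernel and additionally returns logsumexp and RNG state (see the meta registration linked above). A minimal sketch, checked against `torch.nn.functional.scaled_dot_product_attention` on CPU:

```python
import math

import torch

def sdpa_reference(query, key, value):
    # Reference math only: softmax(Q @ K^T / sqrt(d)) @ V.
    # The efficient-attention op produces the same first output via a
    # memory-efficient kernel, plus extra outputs (logsumexp, RNG state).
    d = query.shape[-1]
    scores = query @ key.transpose(-2, -1) / math.sqrt(d)
    return torch.softmax(scores, dim=-1) @ value

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
out = sdpa_reference(q, k, v)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(out, expected, rtol=1e-4, atol=1e-4)
```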