Implement aten::_softmax | feat(torchlib) #1024
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main    #1024      +/-   ##
==========================================
- Coverage   77.51%   77.50%   -0.01%
==========================================
  Files         114      114
  Lines       14309    14335      +26
  Branches     1521     1525       +4
==========================================
+ Hits        11091    11110      +19
- Misses       2854     2859       +5
- Partials      364      366       +2
@torch_op("aten::_softmax", trace_only=True) | ||
def aten__softmax( | ||
self: TFloatHighPrecision, dim: int, half_to_float: bool |
Why do these two _softmax() functions have different dtypes for self?
This is because we want to cast float16 to float32 when half_to_float is true. Since the function body has no way to know what the input dtype will be, we rely on the dispatcher to pick the overload whose input dtype constraint already tells us we are dealing with float16 inputs.
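A minimal sketch of the dispatch-by-dtype idea is below. The second registration name (aten__softmax_half), the Union[FLOAT16, BFLOAT16] constraint, and the import paths are assumptions for illustration, not the PR's verbatim code.

```python
# Sketch only: two registrations for aten::_softmax that differ solely in the
# dtype constraint on `self`, so the dispatcher selects the right overload.
from typing import Union

from onnxscript import BFLOAT16, FLOAT, FLOAT16
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op  # assumed path
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatHighPrecision  # assumed path


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool):
    # This overload is only dispatched for half-precision inputs, so the cast
    # requested by half_to_float can be emitted unconditionally here.
    if half_to_float:
        self = op.Cast(self, to=FLOAT.dtype)
    return op.Softmax(self, axis=dim)


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax(
    self: TFloatHighPrecision, dim: int, half_to_float: bool
) -> TFloatHighPrecision:
    # float32/float64 inputs: half_to_float is a no-op, so no cast is needed.
    return op.Softmax(self, axis=dim)
```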
Test Results
18 files ±0  18 suites ±0  1h 18m 35s ⏱️ -9m 21s
For more details on these failures and errors, see this check.
Results for commit a1368e4. ± Comparison against base commit 74c068c.
This pull request removes 266 and adds 277 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results.
Move aten_special_softmax -> aten_softmax, support the alias, and handle the case when the dtype input is None. Implement _softmax using softmax. An alternative would be to create a decomp in PyTorch from _softmax to softmax. But since PyTorch optimizations tend to create decomps to private ops (and not the other way around) at different optimization steps, as @BowenBao pointed out, we still need to support _softmax here.
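A hedged sketch of that consolidation follows. The tuple-of-names registration for the alias, the dtype: int = -1 sentinel standing in for "dtype is None", and the import paths are assumptions for illustration, not the PR's exact code.

```python
# Sketch only: a single aten_softmax that also serves the special_softmax
# alias and skips the Cast when no dtype is requested; _softmax reuses it
# instead of relying on a PyTorch decomp from _softmax to softmax.
from onnxscript import FLOAT
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op  # assumed path


@torch_op(("aten::softmax", "aten::special_softmax"), trace_only=True)
def aten_softmax(self, dim: int, dtype: int = -1):
    # -1 stands in for "dtype is None": emit Softmax alone; otherwise cast
    # the result to the requested ONNX dtype.
    result = op.Softmax(self, axis=dim)
    if dtype != -1:
        result = op.Cast(result, to=dtype)
    return result


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax(self, dim: int, half_to_float: bool):
    # _softmax is implemented on top of softmax; cast half inputs up first
    # when half_to_float is requested.
    if half_to_float:
        self = op.Cast(self, to=FLOAT.dtype)
    return aten_softmax(self, dim)
```

Keeping this delegation inside torchlib avoids depending on PyTorch adding a public-to-private decomp, which, as noted above, is not the direction its optimization passes usually take.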