Decompose addmm with Gemm | feat(torchlib) #1111
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1111      +/-   ##
==========================================
+ Coverage   78.13%   78.16%   +0.02%
==========================================
  Files         117      117
  Lines       14954    14966      +12
  Branches     1585     1586       +1
==========================================
+ Hits        11685    11698      +13
  Misses       2900     2900
+ Partials      369      368       -1
```
```python
@@ -232,6 +249,29 @@ def aten_addmm(
    return op.Add(scaled_self, scaled_mat1_mat2)


@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
```
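For context, a minimal NumPy sketch of the semantics being matched (the function names here are illustrative, not the torchlib implementation): `aten::addmm` computes `beta * self + alpha * (mat1 @ mat2)`, which coincides with ONNX `Gemm`'s `alpha * (A @ B) + beta * C` when the matrix inputs are rank 2, and that is what makes the decomposition valid.

```python
import numpy as np


def addmm_reference(self_t, mat1, mat2, beta=1.0, alpha=1.0):
    """Reference semantics of aten::addmm: beta*self + alpha*(mat1 @ mat2)."""
    return beta * self_t + alpha * (mat1 @ mat2)


def gemm_reference(a, b, c, alpha=1.0, beta=1.0):
    """Reference semantics of ONNX Gemm (transA=transB=0): alpha*(A @ B) + beta*C."""
    return alpha * (a @ b) + beta * c
```

For rank-2 inputs the two expressions are term-for-term identical, so routing `addmm` through `Gemm` changes the operator used, not the result.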
Why do the 'gemm' variant only on float inputs?
ORT implements Gemm for float inputs only, so I set it accordingly for practicality.
Changed to TFloat because it supports float16 as well.
Test Results: 18 files ±0, 18 suites ±0, 1h 45m 59s ⏱️ +21m 35s. For more details on these failures and errors, see this check. Results for commit b448bc0. ± Comparison against base commit 9fb0a7d. This pull request removes 520 and adds 539 tests. Note that renamed tests count towards both.
This pull request skips 8 and un-skips 20 tests.
♻️ This comment has been updated with latest results.
Yeah, Gemm seems to be more specialized for that case. Backends should just do the right thing for matmul, imo.
LGTM, thanks!
Decompose addmm with Gemm by creating a special variant for FLOAT, and conditionally check the ranks of the input tensors. The if branch is expected to be folded away by constant-folding passes. I have not found other instances where Gemm is used in the torch.onnx exporter.
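The rank-conditional dispatch described above can be sketched in plain NumPy (a hypothetical helper, not the torchlib code, which builds the branch as ONNX graph ops so that constant folding can eliminate it once ranks are known statically):

```python
import numpy as np


def addmm_with_gemm(self_t, mat1, mat2, beta=1.0, alpha=1.0):
    """Illustrative sketch: take the Gemm-shaped path only for rank-2 matrices."""
    if mat1.ndim == 2 and mat2.ndim == 2:
        # Gemm path: alpha*(A @ B) + beta*C in one fused operator.
        return alpha * (mat1 @ mat2) + beta * self_t
    # Generic fallback: the MatMul/Mul/Add decomposition.
    return beta * self_t + alpha * (mat1 @ mat2)
```

Both branches compute the same value; the check only selects which operator shape the backend sees, which is why a constant folder that knows the input ranks can drop the untaken branch entirely.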
Fixes #1089
cc @baijumeswani