-
-
Notifications
You must be signed in to change notification settings - Fork 152
Broadcast matrix inputs to Gemm #986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Converting to draft because of the performance concerns that this fix entails |
The old I made a remark about this (i.e. the "old") interpretation of The type constraints you mention are viable, but also really do require a much more clearly defined and implemented type system, and, ultimately, some basic constraint logic. My push for the broad use of miniKanren is—in part—motivated by the availability (and compartmentalization) of such features. Regardless, why can't we broadcast all the inputs to the GEMM |
I'll explore that. I didn't plan to mess with blas related Ops, but here we are ^^ In any case I feel that supporting (and enforcing) non size1 type shape might come in handy in a couple of places. |
Broadcasting the matrix inputs was not so hard in the end. Doing that now |
Codecov Report
@@ Coverage Diff @@
## main #986 +/- ##
=======================================
Coverage 79.26% 79.26%
=======================================
Files 152 152
Lines 47927 47932 +5
Branches 10912 10913 +1
=======================================
+ Hits 37990 37995 +5
Misses 7429 7429
Partials 2508 2508
|
This test was not particularly slow compared to the rest of the module.
Tests are passing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great.
I did some sanity checks and I am more confident that I didn't screw up anything :) |
The
Gemm
Op
is only applicable when there is no broadcasting betweenZ
anddot(A, B)
in the expressionZ + dot(A, B)
. This PR broadcasts the matrix inputs when the Op was inserted for a mix of matrices that weren't known to be row / column matrices until runtime.Closes #984