Use Blockwise for matmul #452
Conversation
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #452      +/-   ##
==========================================
- Coverage   80.76%   80.75%   -0.02%
==========================================
  Files         159      159
  Lines       45869    45849      -20
  Branches    11238    11234       -4
==========================================
- Hits        37048    37026      -22
- Misses       6593     6595       +2
  Partials     2228     2228
Looks great, I played around with it and everything seems to work as expected. Currently you can't compile graphs with blockwise matmul into jax/numba; is that beyond the scope of this PR? I thought there was a `jax.vectorize`-type function that would make that trivial (for the jax case at least)?
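For reference, a minimal sketch of that idea (not this PR's implementation, and assuming nothing about PyTensor's JAX dispatch): `jax.numpy.vectorize` with a core signature broadcasts a plain 2D matmul over arbitrary leading batch dimensions, which is roughly what a JAX lowering of a Blockwise'd Dot would need to do.

```python
import jax.numpy as jnp

# Hedged sketch: broadcast a 2D core matmul over leading batch dimensions,
# roughly what a JAX lowering of Blockwise(Dot) would have to do.
core_matmul = jnp.vectorize(
    lambda a, b: a @ b,
    signature="(m,k),(k,n)->(m,n)",
)

a = jnp.ones((5, 3, 2, 4))  # batch dims (5, 3), core shape (2, 4)
b = jnp.ones((5, 3, 4, 6))  # batch dims (5, 3), core shape (4, 6)
print(core_matmul(a, b).shape)  # (5, 3, 2, 6)
```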
elif x1.type.ndim == 1:
    out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
elif x2.type.ndim == 1:
    out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
Is all this better than a separate `_matrix_vector_matmul` function? I only ask because BLAS makes the distinction.
This should be fine. Once we go into optimizing this further in the jax/numba backends, we should be able to tell which case is which by inspecting the static types of the inputs.

Yes, it should be pretty simple. It's in the todo list: #430
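For illustration only, here is a NumPy analogue of the dispatch shown in the diff above (the PR itself inspects the static `x1.type.ndim` / `x2.type.ndim` of the symbolic inputs, not runtime arrays): 1D operands are promoted to matrices so a single matrix-matrix core handles every case, and the dummy axis is squeezed away afterwards.

```python
import numpy as np

def matmul_via_matrix_core(x1, x2):
    # Hedged NumPy analogue of the dispatch above, not PyTensor code:
    # promote vectors to matrices, apply the matrix-matrix core, then squeeze.
    if x1.ndim == 1 and x2.ndim == 1:
        return (x1[None] @ x2[:, None]).squeeze()  # vector @ vector -> scalar
    elif x1.ndim == 1:
        return (x1[None] @ x2).squeeze(-2)         # prepend dummy row axis, drop it
    elif x2.ndim == 1:
        return (x1 @ x2[:, None]).squeeze(-1)      # append dummy column axis, drop it
    return x1 @ x2                                 # already matrix @ matrix

A = np.arange(6.0).reshape(2, 3)
v = np.arange(3.0)
np.testing.assert_allclose(matmul_via_matrix_core(A, v), A @ v)
np.testing.assert_allclose(matmul_via_matrix_core(v, A.T), v @ A.T)
```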
We use a Blockwised Dot for matmul so that we get gradients for free (a short usage sketch follows at the end of this description).
C performance won't be great for the Blockwised Dot, since Blockwise doesn't have a C implementation.
We could Blockwise the more specialized Dot22 / GEMM Ops, but that code is a bit of a mess at the moment and not useful long term as we deprecate the C backend.
Alternatively we probably could use `tensor_dot` / `batched_dot`? They are fundamentally different though (or it would be rather inefficient to convert one to the other).

Closes #451
Needed for pymc-devs/pymc#6897
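A minimal usage sketch of what the description above is getting at (the `pt.matmul` entry point is assumed from the PR title and not verified against the final API): matmul built on a Blockwise'd Dot broadcasts over leading batch dimensions, and its gradient falls out of Dot's existing gradient.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")  # batch of matrices, shape (batch, m, k)
w = pt.matrix("w")   # weight matrix, shape (k, n)

y = pt.matmul(x, w)                # assumed entry point: Blockwise'd Dot under the hood
g = pytensor.grad(y.sum(), wrt=x)  # gradients "for free" via Dot's grad

f = pytensor.function([x, w], [y, g])
out, grad_x = f(np.ones((5, 2, 3)), np.ones((3, 4)))
print(out.shape, grad_x.shape)  # (5, 2, 4) (5, 2, 3)
```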