Add numba overload for solve_triangular #423
Conversation
Remove test_SolveTriangular from numba\test_nlinalg.py
Codecov Report
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #423      +/-   ##
==========================================
- Coverage   80.75%   80.64%   -0.12%
==========================================
  Files         159      160       +1
  Lines       45849    46016     +167
  Branches    11234    11263      +29
==========================================
+ Hits        37026    37108     +82
- Misses       6595     6671     +76
- Partials    2228     2237      +9
```
Add informative message to error raised by check_finite=True
Is this ready for review, or is something important still missing?
I'm still not 100% sold on how it's all implemented. I wanted someone to take a closer look at
I think there are maybe a few ways to make this a bit faster, but it looks good to me as it is. I'm not really sure why it would feel hackish? The only downside I can think of compared to compiling a separate extension module is that numba can't cache this due to the dynamic pointer.
```python
# Need to expand B here; I tried everywhere else and it doesn't work
if B_is_1d:
    B_copy = _copy_to_fortran_order(np.expand_dims(B, -1))
```
If the original B was 1d, I don't think we need the copy?
In my testing, trtrs expects everything to be at least 2d. The docs say LDB >= 0, but when I was giving it 1d arrays I was getting back numerically incorrect results.
After testing, you're right. I wasn't able to avoid the copy in the 2d case, though. If I don't copy 2d B, numba flags this line:

```python
B_NDIM = 1 if B_is_1d else int(B.shape[1])
```

saying that it's considering a case where B is 3d. Not sure why it thinks that's possible. Does numba evaluate all if/else branches on all possible inputs?
```python
if A.shape[0] != B.shape[0]:
    raise linalg.LinAlgError("Dimensions of A and B do not conform")

A_copy = _copy_to_fortran_order(A)
```
Can we avoid the copy if A is C-order by flipping the trans value? I think we could also have a special overload for when trans, lower, and unit_diag are literals and we statically know that A and B are C- or Fortran-contiguous.
I think that would really only be an optimization of the current code, though; this here should be fine as well.
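To make the suggestion concrete, here is a minimal SciPy sketch of the trans-flipping idea (my illustration, not the PR's code; it assumes a real-valued, C-contiguous, upper-triangular A and uses scipy.linalg.lapack.dtrtrs):

```python
import numpy as np
from scipy.linalg import lapack

# Well-conditioned upper-triangular A in C order, plus a right-hand side.
A = np.triu(np.random.rand(4, 4)) + 4.0 * np.eye(4)
b = np.random.rand(4, 1)

# Baseline: copy A into Fortran order and solve A @ x = b directly.
x_copy, info = lapack.dtrtrs(np.asfortranarray(A), b, lower=0, trans=0)

# Copy-free: A.T is an F-contiguous *view* of the C-contiguous A, and
# op(A.T) with trans=1 is (A.T)^T = A, so flip both `trans` and `lower`.
x_view, info = lapack.dtrtrs(A.T, b, lower=1, trans=1)

assert np.allclose(x_copy, x_view)
```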
Does setting an array to Fortran-contiguous actually transpose the matrix, or does it just re-order the pointers into the internal flat representation?
After testing, we can avoid copying A in all cases.
Re: the other point, do you mean checking the values of `trans`, `lower`, and `unit_diag` inside the wrapper function, then returning a specialized `impl` function based on their values? Similar to how I'm dispatching to real/complex versions here?

It feels hackish because 1) we can't cache the functions (relevant for compile times, which you've pointed out are extremely long with numba), 2) we can't support complex inputs due to a weird technical reason rather than some principled/fundamental one, and 3) it's nowhere close to working within the "official" numba API, so I have no idea how future-proof it is. Complex inputs definitely worked last year, so something changed in the numba codebase to break that.
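For reference, here is a minimal sketch of the overload-time dispatch pattern being discussed (the stub name and trivial bodies are illustrative placeholders, not the PR's code): the branch on the input's numba type runs at compile time, so each dtype gets its own specialized impl.

```python
import numpy as np
from numba import njit
from numba.core import types
from numba.extending import overload


def solve_stub(A):
    # Pure-Python placeholder; under njit, the overload below is used.
    raise NotImplementedError


@overload(solve_stub)
def _solve_stub_overload(A):
    # `A` here is a numba *type*, so this if/else resolves at compile
    # time and only the selected impl is ever typed and compiled.
    if isinstance(A.dtype, types.Complex):
        def impl(A):  # stand-in for the complex (ctrtrs/ztrtrs) path
            return np.abs(A).sum()
    else:
        def impl(A):  # stand-in for the real (strtrs/dtrtrs) path
            return A.sum()
    return impl


@njit
def f(A):
    return solve_stub(A)


f(np.ones(3))        # compiles the real path
f(np.ones(3) + 0j)   # compiles the complex path
```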
Rename addr to lapack_ptr
Don't copy B matrix when B is array in overload func
Some conflicts have cropped up.
What do you mean?
Still says no conflicts for me. I'll update my fork and double-check everything.
Swap the "Squash and merge" button to "Rebase and merge", and you should see it.
Marked as a draft, given the suggestion to upper-pin numba. Feel free to convert back to ready-for-review if you change your mind about it, or just do it!
@maresb is there a way to set an upper version limit on an optional dependency (that will also be respected by conda)?
Sure, just add it under
Following a conversation with @ricardoV94, I'm merging this. We'll cross the bridge of breaking code when/if we get there.
Motivation for these changes
The `pytensor.tensor.slinalg` module is not currently compatible with `mode="NUMBA"`. This PR is a first step in an effort to fix that. It's marked as a draft because it's 1) not done, and 2) in need of discussion/work.

Functions in `slinalg` don't have overloads in `numba.np.linalg`, so to implement them there needs to be an overload that calls the relevant C LAPACK functions. This involves some acrobatics with C pointers and typing, which I am absolutely not an expert at. Currently, I use dynamic pointers from `ctypes`, essentially just following numba/numba#5301. This works, but it means the resulting functions can't be cached, which will be a huge slowdown on complex graphs (I think).
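For concreteness, here is a rough sketch of that dynamic-pointer recipe (my paraphrase of the numba/numba#5301 pattern, not necessarily the PR's exact code):

```python
from ctypes import CFUNCTYPE, POINTER, c_char, c_double, c_int

from numba.extending import get_cython_function_address

# Resolve scipy's Cython-exported dtrtrs symbol. Because this address is
# only known at runtime, numba can't cache functions that call through it.
addr = get_cython_function_address("scipy.linalg.cython_lapack", "dtrtrs")

# void dtrtrs(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO)
functype = CFUNCTYPE(
    None,
    POINTER(c_char),    # UPLO
    POINTER(c_char),    # TRANS
    POINTER(c_char),    # DIAG
    POINTER(c_int),     # N
    POINTER(c_int),     # NRHS
    POINTER(c_double),  # A
    POINTER(c_int),     # LDA
    POINTER(c_double),  # B
    POINTER(c_int),     # LDB
    POINTER(c_int),     # INFO
)
dtrtrs = functype(addr)  # callable through the dynamic pointer
```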
A more complete approach would try to directly extend numba/numba/_lapack.c with some new pointers to the relevant scipy code. I'm not sure if it would be possible to have our own e.g. `_lapack_extensions.c` that could have `#include "_lapack.c"` on top? The pattern in that file looks straightforward enough to copy, but it's been a long time since I did anything in C, and I'm not sure how importing across modules would work.
Also, to answer "why `solve_triangular`?": because it's a function we don't have now that depends on only a single LAPACK call. Once the pattern is ironed out, I'll do the same for all the functions we currently have in `slinalg`, most importantly `solve` (yes, we have the `np.linalg.solve` overload, but it doesn't allow access to the specialized solvers for e.g. symmetric positive definite matrices, which matters a lot for PyMC).
Implementation details
I followed the implementation of LAPACK overloads established in numba/np/linalg.py: there's a class called `_LAPACK` that holds signatures for all the LAPACK functions to be implemented, plus an overload function for each user-facing routine.
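A condensed, hypothetical sketch of that structure (the method name and the exact signature-building style are my guesses, not a copy of the PR):

```python
from numba.core import types


class _LAPACK:
    """One classmethod per wrapped LAPACK routine, returning its signature."""

    @classmethod
    def numba_xtrtrs(cls, dtype):
        # void xtrtrs(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO)
        return types.void(
            types.CPointer(types.int8),   # UPLO
            types.CPointer(types.int8),   # TRANS
            types.CPointer(types.int8),   # DIAG
            types.CPointer(types.int32),  # N
            types.CPointer(types.int32),  # NRHS
            types.CPointer(dtype),        # A
            types.CPointer(types.int32),  # LDA
            types.CPointer(dtype),        # B
            types.CPointer(types.int32),  # LDB
            types.CPointer(types.int32),  # INFO
        )
```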
that holds signatures for all the LAPACK functions that will be implemented, then an overload function.Checklist
Major / Breaking Changes
None
New features
Solve triangular matrices with numba!
Bugfixes
None
Documentation
Not yet
Maintenance
None