Add slogdet for Numba and JAX #172
Conversation
Force-pushed from 2a23d4c to fdaaafb
This looks great, thank you!
Hi @aseyboldt, thanks! We can leave it for another PR. I'm looking at this gradient right now and I have a question (for example, see pytensor/tests/tensor/test_nlinalg.py, line 353 at 9b2cb97).
Internally I see there is an EighGrad(Op) class that defines a custom gradient computation.
Would you have a hint how I can compute it here?
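For context, here is a minimal, hypothetical sketch of the pattern this comment refers to: a forward Op whose grad() method builds a separate gradient Op, the way Eigh delegates to EighGrad. The names ForwardOp/ForwardOpGrad and the toy x ** 2 computation are placeholders, not anything from this PR or from PyTensor itself.

```python
import numpy as np

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class ForwardOpGrad(Op):
    """Hypothetical Op computing the gradient of ForwardOp (like EighGrad for Eigh)."""

    def make_node(self, x, gz):
        x = pt.as_tensor_variable(x)
        gz = pt.as_tensor_variable(gz)
        return Apply(self, [x, gz], [x.type()])

    def perform(self, node, inputs, outputs):
        x, gz = inputs
        # Gradient of the toy forward computation x ** 2.
        outputs[0][0] = np.asarray(2.0 * x * gz)


class ForwardOp(Op):
    """Hypothetical forward Op (toy computation: x ** 2)."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        outputs[0][0] = np.asarray(x) ** 2

    def grad(self, inputs, output_grads):
        (x,) = inputs
        (gz,) = output_grads
        # Delegate the symbolic gradient to the dedicated gradient Op,
        # the same way Eigh.grad builds an EighGrad node.
        return [ForwardOpGrad()(x, gz)]
```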
Yes, that's correct. I think in this case this shouldn't be a problem though: the derivative of the sign should just be 0, I think. It isn't actually zero everywhere, I guess, but it is zero everywhere where the derivative is defined in the first place... Oh, and about doing this in this PR or a separate one: whatever you prefer.
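To make that reasoning concrete, here is a hedged sketch (not the code in this PR) of how the slogdet gradient could be assembled in PyTensor: zero for the sign output, and g * inv(A).T for the log-determinant output, using the identity d log|det(A)| / dA = A^{-T}. The function name is illustrative only.

```python
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import matrix_inverse


def slogdet_grad_sketch(A, g_sign, g_logabsdet):
    # The sign output is piecewise constant, so its gradient contribution
    # is taken as zero (wherever the derivative is defined at all).
    grad_from_sign = pt.zeros_like(A)
    # Standard identity: d log|det(A)| / dA = inv(A).T
    grad_from_logabsdet = g_logabsdet * matrix_inverse(A).T
    return grad_from_sign + grad_from_logabsdet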
Does this Op have a JAX implementation by any chance?
Thanks @aseyboldt! Then this PR is done from my side! (Will prepare a separate one)
Hi @ricardoV94, do you mean whether JAX supports slogdet?
In that case it would be good to also add the JAX dispatch, so that it works with both backends.
Sure! Added!
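For reference, a JAX dispatch along these lines would typically use the jax_funcify registration pattern. The sketch below assumes the new Op is named SLogDet and lives in pytensor.tensor.nlinalg, which may not match the PR exactly.

```python
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.nlinalg import SLogDet  # assumed name/location of the new Op


@jax_funcify.register(SLogDet)
def jax_funcify_SLogDet(op, **kwargs):
    def slogdet(x):
        # jnp.linalg.slogdet returns (sign, logabsdet), matching the Op's two outputs.
        return jnp.linalg.slogdet(x)

    return slogdet
```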
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #172      +/-   ##
==========================================
+ Coverage   79.97%   80.03%   +0.05%
==========================================
  Files         169      170       +1
  Lines       44621    45082     +461
  Branches     9433     9602     +169
==========================================
+ Hits        35686    36081     +395
- Misses       6740     6789      +49
- Partials     2195     2212      +17
Thanks @mtsokol -- this is a great contribution!
This PR relates to #161 and adds the slogdet numpy method for Numba. Please share your feedback! 🙂
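Below is a hedged sketch of what the Numba side could look like, using the numba_funcify registration pattern PyTensor provides for backend dispatch. The Op name SLogDet and its location are assumptions, and the real implementation in this PR may handle details such as memory layout differently.

```python
import numba
import numpy as np

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.tensor.nlinalg import SLogDet  # assumed name/location of the new Op


@numba_funcify.register(SLogDet)
def numba_funcify_SLogDet(op, node=None, **kwargs):
    @numba.njit
    def slogdet(x):
        # np.linalg.slogdet is supported by Numba in nopython mode.
        sign, logabsdet = np.linalg.slogdet(x)
        return sign, logabsdet

    return slogdet
```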