Add slogdet for Numba and JAX #172

Merged: 3 commits into pymc-devs:main from api/numba_slogdet on Jan 8, 2023

Conversation

@mtsokol (Contributor) commented Jan 4, 2023

This PR relates to #161 and adds the NumPy slogdet method for Numba. Please share your feedback! 🙂
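
For readers unfamiliar with the backend dispatch mechanism, here is a minimal sketch of what such a Numba dispatch can look like. It assumes the `numba_funcify` single-dispatch registry used in `pytensor/link/numba/dispatch/` and the existing `SLogDet` Op in `pytensor.tensor.nlinalg`; the actual diff may differ in details such as dtype handling.

```python
# Hypothetical sketch (not the PR's exact diff): register a Numba
# implementation for the SLogDet Op via PyTensor's numba_funcify dispatch.
# Numba's nopython mode supports np.linalg.slogdet directly.
import numpy as np
from numba import njit

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.tensor.nlinalg import SLogDet


@numba_funcify.register(SLogDet)
def numba_funcify_SLogDet(op, node=None, **kwargs):
    @njit
    def slogdet(x):
        # np.linalg.slogdet returns (sign, log|det|)
        return np.linalg.slogdet(x)

    return slogdet
```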

@mtsokol changed the title to "Add slogdet for Numba" on Jan 4, 2023
@mtsokol mentioned this pull request on Jan 4, 2023
@mtsokol force-pushed the api/numba_slogdet branch from 2a23d4c to fdaaafb on January 5, 2023 14:37
@mtsokol marked this pull request as ready for review on January 5, 2023 14:39
@aseyboldt (Member) commented:

This looks great, thank you!
Given that the derivative has such a nice form (it's just $X^{-T}$), it would be great to add that too, using a solve op in the grad method. But I think it could be merged as it is already, and we can add the gradients later as well.
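
For context, a rough sketch of what such a grad method could look like, assuming the usual `Op.grad(self, inputs, output_gradients)` interface and forming $X^{-T}$ with a solve; it simply ignores the gradient of the sign output (see the discussion below):

```python
# Hypothetical sketch of SLogDet.grad, not the library's actual code.
# d(log|det X|)/dX = X^{-T}; the sign output contributes no gradient
# wherever the gradient is defined.
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve


def grad(self, inputs, output_gradients):
    (x,) = inputs
    _, g_logdet = output_gradients  # gradients w.r.t. (sign, log|det|)
    # X^{-T} via a solve: solve(X^T, I) == X^{-T}
    x_inv_t = solve(x.T, pt.eye(x.shape[0]))
    return [g_logdet * x_inv_t]
```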

@mtsokol (Contributor, Author) commented Jan 5, 2023

> This looks great, thank you! Given that the derivative has such a nice form (it's just $X^{-T}$), it would be great to add that too, using a solve op in the grad method. But I think it could be merged as it is already, and we can add the gradients later as well.

Hi @aseyboldt, thanks! We can leave it for another PR.

I'm looking at this gradient right now and I have a question: since slogdet returns (sign_of_det, abs_log_det) instead of just logdet, I need to provide gradient definitions for $\frac{\partial\, \operatorname{sign}(\det X)}{\partial x_{ij}}$ and $\frac{\partial \log\lvert\det X\rvert}{\partial x_{ij}}$ respectively, is that correct?

For example, eigh also returns a tuple, and the test uses verify_grad with respect to each output (eigenvalues, eigenvectors):

utt.verify_grad(lambda x: self.op(x.dot(x.T))[0], [X], rng=self.rng)

Internally, I see there is an EighGrad(Op) class that defines the custom gradient computation.

Would you have a hint on how I can compute it here?

@aseyboldt (Member) commented Jan 6, 2023

> I'm looking at this gradient right now and I have a question: since slogdet returns (sign_of_det, abs_log_det) instead of just logdet, I need to provide gradient definitions for $\frac{\partial\, \operatorname{sign}(\det X)}{\partial x_{ij}}$ and $\frac{\partial \log\lvert\det X\rvert}{\partial x_{ij}}$ respectively, is that correct?

Yes, that's correct. I don't think this should be a problem in this case, though: the derivative of the sign should just be 0. It isn't actually zero everywhere, I guess, but it is zero everywhere the derivative is defined in the first place...

Oh, and about doing this in this PR or a separate one: Whatever you prefer.
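
Following the eigh pattern quoted above, a gradient test could then check only the log|det| output. A sketch, assuming the same test context (`utt`, `X`, `self.rng`) as the eigh test and an `slogdet` entry point in `pytensor.tensor.nlinalg`:

```python
# Hypothetical test sketch: verify the gradient of the log|det| output only,
# since the sign output has zero gradient wherever the gradient is defined.
from pytensor.tensor.nlinalg import slogdet

utt.verify_grad(lambda x: slogdet(x.dot(x.T))[1], [X], rng=self.rng)
```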

@ricardoV94 (Member) commented:

Does this Op have a JAX implementation by any chance?

@mtsokol (Contributor, Author) commented Jan 7, 2023

> Yes, that's correct. I don't think this should be a problem in this case, though: the derivative of the sign should just be 0. It isn't actually zero everywhere, I guess, but it is zero everywhere the derivative is defined in the first place...
>
> Oh, and about doing this in this PR or a separate one: Whatever you prefer.

Thanks @aseyboldt! Then this PR is done from my side! (Will prepare a separate one)

> Does this Op have a JAX implementation by any chance?

Hi @ricardoV94, do you mean whether JAX supports slogdet, or whether pytensor has a JAX port for slogdet? The former: yes! And I'm checking how they compute _slogdet_jvp. The latter: not yet.
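
As a quick illustration of the JAX side (not part of the PR), jnp.linalg.slogdet returns (sign, log|det|), and the gradient of the log|det| output is $X^{-T}$, consistent with the discussion above:

```python
# Standalone check of JAX's slogdet and its gradient (illustration only).
import jax
import jax.numpy as jnp

x = jnp.array([[2.0, 1.0], [1.0, 3.0]])
sign, logabsdet = jnp.linalg.slogdet(x)

# d(log|det x|)/dx should equal inv(x).T
g = jax.grad(lambda m: jnp.linalg.slogdet(m)[1])(x)
print(bool(jnp.allclose(g, jnp.linalg.inv(x).T)))  # True
```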

@ricardoV94 (Member) commented:

In that case it would be good to also add the JAX dispatch, so that it works with both backends.

@mtsokol (Contributor, Author) commented Jan 7, 2023

> In that case it would be good to also add the JAX dispatch, so that it works with both backends.

Sure! Added!
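
For reference, a minimal sketch of what such a JAX dispatch can look like, assuming the `jax_funcify` single-dispatch registry in `pytensor/link/jax/dispatch/`; the actual diff may differ in details:

```python
# Hypothetical sketch of the JAX dispatch for SLogDet, mirroring the pattern
# used for other nlinalg Ops in pytensor/link/jax/dispatch/nlinalg.py.
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.nlinalg import SLogDet


@jax_funcify.register(SLogDet)
def jax_funcify_SLogDet(op, **kwargs):
    def slogdet(x):
        # jnp.linalg.slogdet returns (sign, log|det|)
        return jnp.linalg.slogdet(x)

    return slogdet
```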

@codecov-commenter commented Jan 7, 2023

Codecov Report

Merging #172 (79917e0) into main (25236cf) will increase coverage by 0.05%.
The diff coverage is 91.89%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #172      +/-   ##
==========================================
+ Coverage   79.97%   80.03%   +0.05%     
==========================================
  Files         169      170       +1     
  Lines       44621    45082     +461     
  Branches     9433     9602     +169     
==========================================
+ Hits        35686    36081     +395     
- Misses       6740     6789      +49     
- Partials     2195     2212      +17     
| Impacted Files | Coverage Δ |
|---|---|
| pytensor/tensor/nlinalg.py | 97.84% <85.71%> (-0.73%) ⬇️ |
| pytensor/link/jax/dispatch/nlinalg.py | 89.33% <100.00%> (+0.76%) ⬆️ |
| pytensor/link/numba/dispatch/nlinalg.py | 100.00% <100.00%> (ø) |
| pytensor/link/numba/dispatch/elemwise.py | 89.01% <0.00%> (-8.04%) ⬇️ |
| pytensor/link/numba/dispatch/basic.py | 89.33% <0.00%> (-0.43%) ⬇️ |
| pytensor/link/c/cmodule.py | 51.54% <0.00%> (-0.34%) ⬇️ |
| pytensor/scalar/math.py | 85.00% <0.00%> (-0.30%) ⬇️ |
| pytensor/graph/rewriting/basic.py | 64.36% <0.00%> (-0.14%) ⬇️ |
| pytensor/tensor/inplace.py | 100.00% <0.00%> (ø) |
| pytensor/compile/compiledir.py | 0.00% <0.00%> (ø) |

... and 36 more

@twiecki changed the title from "Add slogdet for Numba" to "Add slogdet for Numba and JAX" on Jan 7, 2023
@twiecki merged commit b8831aa into pymc-devs:main on Jan 8, 2023
@twiecki (Member) commented Jan 8, 2023

Thanks @mtsokol -- this is a great contribution!

@mtsokol deleted the api/numba_slogdet branch on January 8, 2023 13:35