-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix rejection-based truncation of scalar variables #6923
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix rejection-based truncation of scalar variables #6923
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6923 +/- ##
==========================================
+ Coverage 91.78% 92.17% +0.38%
==========================================
Files 100 100
Lines 16845 16847 +2
==========================================
+ Hits 15462 15528 +66
+ Misses 1383 1319 -64
|
629fdef
to
ff30b9d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very nice @ricardoV94. I might need you to guide me a bit through the logic here because I haven't grasped the subtleties of trying to truncate SymbolicDistribution
s yet.
truncated_rv = pt.set_subtensor( | ||
truncated_rv[reject_draws], | ||
new_truncated_rv[reject_draws], | ||
) | ||
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this also have some kind of pt.and_(not(reject_draws), ...
so that the draws that were already accepted don't get resampled? I fail to see where that is happening.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The draws that were already accepted will always be within upper and lower. The set_subtensor only changes the indexes that were not already valid.
pymc/distributions/truncated.py
Outdated
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper)) | ||
|
||
return ( | ||
(truncated_rv, reject_draws), | ||
[(rng, next_rng)], | ||
collect_default_updates([new_truncated_rv]), | ||
until(~pt.any(reject_draws)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this until
cut the scan
short of the max_n_steps
if the condition is met sooner?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's the whole point of the until
(Also you can never define one without max_n_steps
if that helps)
d583fd6
to
b39ceae
Compare
0e0b8ef
to
3a42050
Compare
@lucianopaz, I realized I needed a bigger refactor, mostly because pymc-devs/pytensor#473 makes it hard to box other SymbolicRVs safely. This PR is now just fixing the bug with the scalar case, and I'll open another one later with then new functionality |
3a42050
to
b574db7
Compare
b574db7
to
5fa5721
Compare
Reported in pymc-devs/pytensor#442
📚 Documentation preview 📚: https://pymc--6923.org.readthedocs.build/en/6923/