-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Rename and refactor start dict in sampling #5027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## main #5027 +/- ##
==========================================
+ Coverage 77.82% 77.83% +0.01%
==========================================
Files 128 128
Lines 24380 24384 +4
==========================================
+ Hits 18973 18980 +7
+ Misses 5407 5404 -3
|
6156cbb
to
81183a8
Compare
@aseyboldt the only remaining test failure is limited to float32. The test was written by you 4 years ago and I can't tell why it failed: I was able to establish that the initial point that failes the |
81183a8
to
4e4f1c8
Compare
I XFAILed the test, since the error appears in a branch (start shape checking) that was not triggered before. |
4e4f1c8
to
10b71d9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs a release note
I'm collecting all release notes in the Hackmd document. Also there will be another update on this API with the next PR. If that's fine with you, I will update the release notes in the next PR? |
try: | ||
step = CompoundStep(step) | ||
except TypeError: | ||
pass | ||
|
||
point = Point(start, model=model, filter_model_vars=True) | ||
point = start |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to simply use "start" instead of renaming the variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I decided against that because the "point" variable is overwritten all over again while iterating the draws. So it's no longer the "start" after the first iteration and I wanted to avoid confusion because of lines like strace.record(start, stats)
.
Take out leftover start-from-trace support. And rearrange some code blocks for easier refactoring later.
Co-authored-by: Osvaldo Martin <[email protected]>
10b71d9
to
62f6408
Compare
The initial point is now determined exactly once in the control flow: + By `init_nuts` (initvals replace init results). + In `sample`, if the above does not apply or fails. Lower-level sampling functions now require the `start` kwarg to be a complete dictionary of numeric initial values for all free variables. The initial points for _each_ chain is checked for shape and logp inf/nan once in `sample`, even if they may be identical for all chains. Co-authored-by: Osvaldo Martin <[email protected]>
62f6408
to
75f4c46
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @aloctavodia ! With the exception of one item I implemented your suggestions :)
Please let me know if we can move forward with this PR.
try: | ||
step = CompoundStep(step) | ||
except TypeError: | ||
pass | ||
|
||
point = Point(start, model=model, filter_model_vars=True) | ||
point = start |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I decided against that because the "point" variable is overwritten all over again while iterating the draws. So it's no longer the "start" after the first iteration and I wanted to avoid confusion because of lines like strace.record(start, stats)
.
This PR is in preparation of switching to the new initval framework.
To make debugging and reviewing a little easier, I added type hints and moved the
start
-related code closer together.API/Behavior changes:
pm.sample
inputs were removed.init_nuts
is now always used as the initial/starting point, whereas before it was only used unless astart
dict was manually specified.init_nuts
itself combines the automatically determined initial point with the user-providedinitvals
such thatinitvals
take priority.pm.sample(start=...)
kwarg was renamed toinitvals
, to reflect that it takes the same keys/values/signature asmodel.initial_values
or the correspondingDistribution.__new__(initval=...)
kwarg.start
kwargs of lower-level sampling functions are now required to be numeric & complete.model.update_start_vals
is no longer applied by lower-level functions.pm.sample(initvals=...)
andinit_nuts(initvals=...)
take the fully-flexiblyinitval
-style dictionary of potentially incomplete (and soon also symbolic, "prior", or "moment" valued) initval strategies.