Rename and refactor start dict in sampling #5027

Merged
merged 3 commits into from
Sep 28, 2021

@michaelosthege michaelosthege commented Sep 26, 2021

This PR prepares for the switch to the new initval framework.

To make debugging and reviewing a little easier, I added type hints and moved the start-related code closer together.

API/Behavior changes:

  • A few lines related to the now unsupported use of length-zero traces as pm.sample inputs were removed.
  • The result from init_nuts is now always used as the initial/starting point, whereas before it was used only if no start dict was manually specified. init_nuts itself combines the automatically determined initial point with the user-provided initvals such that initvals take priority.
  • The pm.sample(start=...) kwarg was renamed to initvals, to reflect that it takes the same keys/values/signature as model.initial_values or the corresponding Distribution.__new__(initval=...) kwarg.
  • start kwargs of lower-level sampling functions are now required to be numeric & complete. model.update_start_vals is no longer applied by lower-level functions.
  • Moving forward, only pm.sample(initvals=...) and init_nuts(initvals=...) take the fully flexible initval-style dictionary of potentially incomplete (and soon also symbolic, "prior", or "moment" valued) initval strategies.
  • Checks of non-inf/nan initial points and corresponding shapes now run under all circumstances.
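
The priority rule from the second bullet can be sketched in plain Python. This is only an illustration of the described behavior, not PyMC code: `merge_initial_point` is a hypothetical helper, and plain lists stand in for NumPy arrays.

```python
from typing import Dict, List, Optional

Point = Dict[str, List[float]]


def merge_initial_point(auto_point: Point, initvals: Optional[Point] = None) -> Point:
    """Combine an automatically determined point with user-provided values.

    Hypothetical helper: entries given by the user take priority, and
    variables the user did not mention keep their automatic value.
    """
    # Copy so that neither input dictionary is mutated.
    merged = {name: list(value) for name, value in auto_point.items()}
    for name, value in (initvals or {}).items():
        merged[name] = list(value)
    return merged


# The user only specifies "a"; "c" falls back to the automatic value.
auto = {"a": [0.0, 0.0], "c": [0.0, 0.0]}
start = merge_initial_point(auto, {"a": [1.0, -1.0]})
print(start)
```

The merged result is complete and numeric, which is the form the lower-level sampling functions are described as requiring for their start kwarg.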

codecov bot commented Sep 26, 2021

Codecov Report

Merging #5027 (6156cbb) into main (641b278) will increase coverage by 0.01%.
The diff coverage is 83.33%.

❗ Current head 6156cbb differs from pull request most recent head 75f4c46. Consider uploading reports for the commit 75f4c46 to get more accurate results

@@            Coverage Diff             @@
##             main    #5027      +/-   ##
==========================================
+ Coverage   77.82%   77.83%   +0.01%     
==========================================
  Files         128      128              
  Lines       24380    24384       +4     
==========================================
+ Hits        18973    18980       +7     
+ Misses       5407     5404       -3     
Impacted Files              Coverage Δ
pymc/sampling.py            87.06% <81.81%> (+0.04%) ⬆️
pymc/parallel_sampling.py   87.33% <100.00%> (+1.04%) ⬆️

@michaelosthege

@aseyboldt the only remaining test failure is limited to float32. The test was written by you 4 years ago and I can't tell why it failed:
https://github.com/pymc-devs/pymc3/pull/5027/checks?check_run_id=3720168022#step:7:2486

I was able to establish that the initial point that fails the _check_start_shape() check was {'a': array([0., 0.], dtype=float32), 'c': array([0., 0.], dtype=float32)}.

@michaelosthege

I XFAILed the test, since the error appears in a branch (start shape checking) that was not triggered before.

@michaelosthege michaelosthege marked this pull request as ready for review September 27, 2021 17:49

@ricardoV94 ricardoV94 left a comment

Needs a release note

@michaelosthege

I'm collecting all release notes in the Hackmd document. Also there will be another update on this API with the next PR.

If that's fine with you, I will update the release notes in the next PR?

     try:
         step = CompoundStep(step)
     except TypeError:
         pass

-    point = Point(start, model=model, filter_model_vars=True)
+    point = start
better to simply use "start" instead of renaming the variable.

I decided against that because the "point" variable is overwritten again and again while iterating the draws. So it's no longer the "start" after the first iteration, and I wanted to avoid confusion caused by lines like strace.record(start, stats).
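
The naming concern can be illustrated with a toy loop. This is a hedged sketch rather than the actual sampling code: `iter_draws` and `toy_step` are hypothetical stand-ins for the sampling loop and a step method.

```python
from typing import Callable, Dict, List

Point = Dict[str, float]


def iter_draws(draws: int, step: Callable[[Point], Point], start: Point) -> List[Point]:
    """Toy sampling loop: `start` is never rebound, `point` changes every draw."""
    point = start
    history = []
    for _ in range(draws):
        # Rebinding `point` here; `start` still refers to the initial point.
        point = step(point)
        history.append(point)
    return history


def toy_step(point: Point) -> Point:
    """Hypothetical step function that returns a new point."""
    return {name: value + 1.0 for name, value in point.items()}


start = {"x": 0.0}
history = iter_draws(3, toy_step, start)
print(start, history[-1])
```

Keeping a separate `point` name means that anything recorded inside the loop is visibly the current draw and not the initial value.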

michaelosthege and others added 2 commits September 28, 2021 09:51
Take out leftover start-from-trace support.
And rearrange some code blocks for easier refactoring later.
Co-authored-by: Osvaldo Martin <[email protected]>
The initial point is now determined exactly once in the control flow:
+ By `init_nuts` (initvals replace init results).
+ In `sample`, if the above does not apply or fails.

Lower-level sampling functions now require the `start` kwarg to be a
complete dictionary of numeric initial values for all free variables.

The initial points for _each_ chain are checked for shape and logp inf/nan
once in `sample`, even if they may be identical for all chains.

Co-authored-by: Osvaldo Martin <[email protected]>
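
The per-chain checks described in this commit message could look roughly like the following. This is a sketch under stated assumptions, not the PyMC implementation: `check_start_points`, `expected_lengths`, and `toy_logp` are hypothetical names, and lists of floats stand in for arrays.

```python
import math
from typing import Callable, Dict, List, Sequence

Point = Dict[str, List[float]]


def check_start_points(
    points: Sequence[Point],
    expected_lengths: Dict[str, int],
    logp: Callable[[Point], float],
) -> None:
    """Validate one start point per chain (hypothetical helper).

    Each point must be complete, have the expected length per variable,
    and evaluate to a finite log-probability.
    """
    for chain, point in enumerate(points):
        missing = set(expected_lengths) - set(point)
        if missing:
            raise ValueError(f"Chain {chain}: no start value for {sorted(missing)}")
        for name, length in expected_lengths.items():
            if len(point[name]) != length:
                raise ValueError(f"Chain {chain}: {name!r} has the wrong length")
        if not math.isfinite(logp(point)):
            raise ValueError(f"Chain {chain}: logp at the start point is inf/nan")


def toy_logp(point: Point) -> float:
    """Stand-in log-density: independent standard normals, up to a constant."""
    return -0.5 * sum(v * v for values in point.values() for v in values)


lengths = {"a": 2, "c": 2}
good = {"a": [0.0, 0.0], "c": [0.0, 0.0]}
# Every chain is checked, even though the points are identical here.
check_start_points([good, good], lengths, toy_logp)
```

A point containing `float("inf")` or a variable of the wrong length makes the helper raise `ValueError` for the offending chain.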
@michaelosthege michaelosthege left a comment

Thanks @aloctavodia ! With the exception of one item I implemented your suggestions :)

Please let me know if we can move forward with this PR.

@michaelosthege michaelosthege merged commit 00e6eb9 into pymc-devs:main Sep 28, 2021
@michaelosthege michaelosthege deleted the rename-start branch September 28, 2021 14:15