Remove tune stat from steps and fix non-discarding of tuning draws from trace #8015
Conversation
@OriolAbril / @aloctavodia does any part of ArviZ require the step samples to have a tune flag? Is it enough that we have the warmup / posterior distinction, each with their number of draws?

Even if that's the case, I think it still makes sense to remove this currently useless stat and reintroduce it in a separate PR (provided nobody finds a reason why it is actually useful/needed).
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #8015      +/-   ##
==========================================
+ Coverage   84.55%   84.56%   +0.01%
==========================================
  Files         124      124
  Lines       19872    19866       -6
==========================================
- Hits        16802    16799       -3
+ Misses       3070     3067       -3
```
Automatically stopping the warmup early would be nice. I think we should agree on cleanly separated definitions of warmup, burn-in, and tuning. Samplers not needing to tune parameters doesn't mean that there's no need for a warmup phase of burn-in iterations (however one might call it). Our current implementation is bad because it doesn't separate the concepts.

ArviZ does not require or use a "tune" stat anywhere.
michaelosthege left a comment

I like where this is going! Using slightly different naming, I think we can simplify a bit more.
@michaelosthege does this make sense?

@michaelosthege check this out
michaelosthege left a comment

I'm not familiar with how the progress bar gets updated. Possibly my two comments on that matter are invalid, but please check them. I'll also trigger the CI tests.
```python
tune = mtrace._straces[0].get_sampler_stats("tune")
assert isinstance(tune, np.ndarray)
# warmup is tracked by the sampling driver
if discard_warmup:
    assert tune.shape == (7, 3)
    assert len(mtrace) == 7
else:
    assert tune.shape == (12, 3)
```
can this test remain as before, but using the in_warmup stat instead?
@eclipse1605 this comment still sounds relevant though
Hey, sorry for the delay, but I think they're valid because warmup bookkeeping is now explicitly driver-owned.
@michaelosthege I've made the tests consistent with the changes; running the CI tests again, they should mostly pass now.
michaelosthege left a comment

Looks good to me! Thanks @eclipse1605 for your endurance with this!

Thanks a ton for the reviews and guidance @michaelosthege and @ricardoV94, really appreciate the patience since I'm still getting my bearings here :)
```diff
 test_dict = {
     "posterior": ["u1", "n1"],
-    "sample_stats": ["~tune", "accept"],
+    "sample_stats": ["~in_warmup", "accept"],
```
I'm not sure about changing the output variable name, this seems like a breaking change for users?
The specific line I pointed to may not be relevant. The general question is whether we changed anything in MultiTrace/InferenceData output with this PR other than the tune flag not existing per step.
It now writes the warmup flag once as in_warmup, but nothing new shows up for users. When we persist sampler stats (e.g. in mcbackend), we store that boolean and keep trace.get_sampler_stats("tune") working by aliasing it to the new field. The default NDArray backend still omits both names, just like before, and to_inference_data continues to drop whichever warmup marker exists, so the resulting InferenceData matches main; the test only switches the "absent" check to the new internal name. No other MultiTrace/InferenceData variables changed.
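To illustrate the aliasing idea, here is a minimal sketch (made-up names, not the actual PyMC implementation): a deprecated stat name is resolved through an alias table at lookup time, so old call sites keep working while only the new field is stored.

```python
# Illustrative sketch only -- hypothetical names, not actual PyMC internals.
# A deprecated stat name ("tune") stays usable by resolving it to the new
# field ("in_warmup") before the lookup happens.

STAT_ALIASES = {"tune": "in_warmup"}  # hypothetical alias table


def get_sampler_stats(stats: dict, name: str):
    """Resolve deprecated stat names, then look up the stored stat."""
    return stats[STAT_ALIASES.get(name, name)]


stored = {"in_warmup": [True, True, False, False]}
# old name and new name both resolve to the same stored values
assert get_sampler_stats(stored, "tune") == [True, True, False, False]
assert get_sampler_stats(stored, "in_warmup") == [True, True, False, False]
```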
ricardoV94 left a comment
This looks sleek, I just want to do a manual integration test locally before merging
Sounds good!

Hey @ricardoV94, I tried to understand the failed test but didn't really get very far with it. Is it failing because JAX spits out NaNs when the Dirichlet concentration is super skewed, so the multinomial never sees a clean probability vector?

That one fails now and then, don't worry about it.

Do you mean this?
I saw that, but I wanted clarification as to whether we want to add an explicit tune alias assertion to preserve that compatibility.
```python
assert all(len(s) == 7 for s in in_warmup)
assert all(not np.any(s) for s in in_warmup)
```
What is this in_warmup object we're seeing here? From the test alone I have a hard time figuring out. Is it a numpy array?
It's unclear to me why this changed, it seemed like we just moved the source of tune/warmup, not the final stored contents?
@ricardoV94, any changes required in this?

I have a question about why the test changed; I thought the output would still be the same. Also, we merged another PR, so this one now has conflicts that need to be solved. Let me know if you need help.
As I said above, the test changed because the tune sampler stat has been removed. Warmup tracking is now handled by the sampling driver, and the backend no longer stores tune. Earlier, the test retrieved tune via get_sampler_stats("tune"), asserted that it was a NumPy array, and checked its shape to verify the number of warmup and posterior samples. Because warmup tracking is now managed directly by the sampling process, tune doesn't need to be stored in the backend, so the test cannot retrieve it. Instead of checking the shape of tune, the test checks the length of the MultiTrace object (len(mtrace)) to determine the number of posterior samples. Let me know if that helps clarify things.
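As a toy sketch of what "driver-owned" means here (made-up names, not the actual PyMC internals): the driver knows the warmup length up front and can discard the warmup draws itself, so no per-step flag ever needs to be stored or retrieved, and the trace length alone tells you the number of posterior samples.

```python
# Toy sketch -- hypothetical names, not actual PyMC internals.
# The driver knows n_tune up front, so it discards warmup draws itself
# instead of storing a per-step "tune" flag in the backend.

def run_chain(step, n_tune, n_draws, discard_tuning=True):
    """Collect n_tune + n_draws samples; optionally drop the warmup part."""
    samples = [step(i) for i in range(n_tune + n_draws)]
    return samples[n_tune:] if discard_tuning else samples


trace = run_chain(step=lambda i: i, n_tune=5, n_draws=7)
assert len(trace) == 7  # only posterior draws remain, mirroring len(mtrace) == 7

trace_all = run_chain(step=lambda i: i, n_tune=5, n_draws=7, discard_tuning=False)
assert len(trace_all) == 12  # warmup + posterior when nothing is discarded
```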
Also, if we move on to merging this, I'll likely need help fixing the merge conflicts.

I'm happy to help but I'd want to let #8047 go in first because it will change things around again.

Sure, makes sense. Let me know when we want to merge this, given there are no more changes required :)
Thanks @michaelosthege. The failing tests are a known issue. I'll try to fix them after, but they need not block this PR.
…from trace (pymc-devs#8015) Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
😅 I just fixed them. Will do another PR then ;)
Description
I tried to make "tuning vs draws" a driver-owned concept again. Right now, parts of sampling/postprocessing infer the warmup length from a per-step "tune" sampler stat, which can get out of sync (e.g. a step method returning "tune": False everywhere makes PyMC think n_tune == 0, so warmup isn't discarded and the logs look wrong).

Related Issues
Fixes: #7997
Context: #7776 (progressbar/stat refactor that exposed the mismatch)
Related discussion/attempts: #7730, #7721, #7724, #8014
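The failure mode the description mentions can be sketched in a few lines (a toy illustration with made-up names, not PyMC code): when the warmup length is inferred by counting a per-step flag, a step method that always reports tune=False makes the inferred length zero, so nothing gets discarded.

```python
# Toy sketch of the bug -- hypothetical names, not actual PyMC code.
# If n_tune is inferred by counting per-step "tune" flags, a step method
# that never sets the flag makes the trace look like it had no warmup.

def discard_warmup_by_stat(draws, tune_flags):
    n_tune = sum(tune_flags)  # warmup length inferred from the per-step stat
    return draws[n_tune:]


draws = list(range(12))                  # 5 warmup + 7 posterior draws
buggy_flags = [False] * 12               # step method always reports tune=False
assert discard_warmup_by_stat(draws, buggy_flags) == draws  # nothing discarded!

correct_flags = [True] * 5 + [False] * 7
assert discard_warmup_by_stat(draws, correct_flags) == list(range(5, 12))
```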