-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Restore support for default initvals #4867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restore support for default initvals #4867
Conversation
Codecov Report
@@ Coverage Diff @@
## main #4867 +/- ##
==========================================
+ Coverage 73.12% 73.17% +0.05%
==========================================
Files 86 86
Lines 13856 13892 +36
==========================================
+ Hits 10132 10166 +34
- Misses 3724 3726 +2
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand that this is a WIP, but what's the basic idea (e.g. what is the mechanism for setting default initial values, how will it work, etc.)?
I'm close to pushing another two commits. I'm primarily adding a lot of testing to nail down what even is the expected behavior.
My intention there is to have the implementation of choosing a default initval in the distributions (where we had it before), while allowing users to disable that too. |
b5ba6fd
to
79e1911
Compare
79e1911
to
8d15f10
Compare
0ccd7be
to
32dd28f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems reasonable but I am not sure we need a separate helper function just to choose between None
, Unset
, and something else. Also I don't really see the need for type checking, it seems things would fail pretty quickly without it (e.g. with the tag.test_value assignment), but that's just a personal opinion.
What about having a new distrobution classmethod that takes the RV inputs and returns an initval? Then during the super The default would obviously be class Distribution:
...
@classmethod
def initval(cls, *args):
return None class Normal(Continuous):
...
@classmethod
def initval(cls, mu, sigma):
return mu Edit: Or a dispatched method like |
Even though it's in the git history like that, I didn't write the helper function first: I ran into a few subtle traps when implementing it without the helper function. For example |
In |
I think that was more out of habit than due to a conscious decision to do it that way. One advantage of having a separate method is for instance that the inputs will already be symbolic. This avoids bugs when the initval logic was based on pure numpy operations that could fail with symbolic inputs. Also many times it's just a question of returning one of the parameters. I don't think it makes the code more complicated, if anything it seems like a more clean separation of responsibility Edit: it's not intuitive at all that |
Okay fair enough. One thing though: Right now we don't support symbolic initvals, because the A switch to a classmethod doesn't make a difference for the |
Hmm. How did it work before? For example in the case where the initval is just the
No, but then you don't really need the helper function I think. You just call it in |
b03aca8
to
a6c6162
Compare
59910ab
to
4c09cb6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only two minor questions / suggestions:
- I would rename
make_initval
to[choose|pick]_initval
or simplyinitval
. - Perhaps it would make more sense to move the new calls inside
model.set_initval
? Since that already handles the prior-sampling case (It seemed to me they aren't needed before). To achieve that we can simply dispatch themake_initval
to the underlying RV, just like logp / logcdf does.
@ricardoV94 what's the matter with this test? https://github.com/pymc-devs/pymc3/pull/4867/checks?check_run_id=3150614327#step:7:448 I restarted it already, but it's still flaky?? |
Yeah, it was not tweaked, but let's just remove it. |
How are we feeling about this? |
I think it would be more clean to call I am not sure what the new Otherwise LGTM |
Oh sorry, I forgot to mention: I tried this, but
The intent is to assist migration by pointing out distributions that still follow the old patterns. |
We can get the class |
7099181
to
c9e460f
Compare
I rebased this and reconsidered the dispatching approach, but concluded that it's not worth the additional complexity. While logp and logcdf are related to the RV, the Please take a final look so we can (squash?) merge and get this over with. |
How does your implementation cope with hierarchical models? If it's not based on the with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1) Assuming the |
By the way @kc611 is exploring the alternative solution with dispatching. So he might have some ideas as well. |
I'm just forwarding them to |
Looking at And even when transforming initvals, it seems to expect them to be a Constant |
Improving the Please open a new Issue if it misbehaves. |
I don't see how that can be out-of-scope for a PR that is refactoring how initvals work.
I am pretty sure it will misbehave in the example I just gave. If the new implementation is not designed to handle a simple realistic case, how can it be evaluated? |
import pymc3 as pm
import aesara.tensor as at
import aesara.tensor.random as atr
class NormalWithInitval(pm.distributions.Continuous):
"""
A distribution that defaults the initial value.
"""
rv_op = atr.normal
@classmethod
def dist(cls, mu=0, sigma=1, **kwargs):
mu = at.as_tensor_variable(pm.floatX(mu))
sigma = at.as_tensor_variable(pm.floatX(sigma))
return super().dist([mu, sigma], **kwargs)
@classmethod
def pick_initval(cls, mu, sigma, **kwargs):
return mu
with pm.Model() as m:
x = NormalWithInitval('x', 0, 1)
y = NormalWithInitval('y', x, 1) # raises TypeError |
Your example also raises the TypeError on It was also not too difficult to fix, even though that was not the goal of this PR. |
b7adcff
to
db5a804
Compare
db5a804
to
335bc94
Compare
062edf2
to
1b1c7b6
Compare
This changes the initval default on Distribution.__new__ and Distribution.dist to UNSET. It allows for implementing distribution-specific initial values similar to how it was done in pymc3 <4. To simplify code that actually picks the initvals, a helper function was added. Closes pymc-devs#4911 by enabling Model.set_initval to take symbolic initial values.
1b1c7b6
to
2743be6
Compare
superseded by #4983 |
This builds on top of #4913 to implement an API for setting default initial values.
Changes
initval=UNSET
in the signature ofDistribution.__new__
to distinguish between the user not passing the kwarg vs. asking for random initial values withinitval=None
.Distribution.pick_initval
classmethod that distributions can override to implement their own initval creation logic (as suggested by @ricardoV94 ).