Restore support for default initvals #4867

Closed

Conversation

michaelosthege
Member

@michaelosthege michaelosthege commented Jul 17, 2021

This builds on top of #4913 to implement an API for setting default initial values.

Changes

  • initval=UNSET in the signature of Distribution.__new__ to distinguish between the user not passing the kwarg vs. asking for random initial values with initval=None.
  • Distribution.pick_initval classmethod that distributions can override to implement their own initval creation logic (as suggested by @ricardoV94 ).
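The distinction between "kwarg not passed" and "explicitly `None`" can be sketched with a sentinel object. This is only an illustration: the `_Unset` class and the `describe_initval` helper are hypothetical names, and the actual `UNSET` object in the PR may be defined differently.

```python
class _Unset:
    """Hypothetical sentinel type; not the PR's actual implementation."""

    def __repr__(self):
        return "UNSET"


UNSET = _Unset()


def describe_initval(initval=UNSET):
    """Report which of the three initval behaviors a call would select."""
    if initval is UNSET:
        # The kwarg was not passed: the distribution may supply a default.
        return "default"
    if initval is None:
        # The user explicitly asked for a random draw.
        return "random"
    # Anything else is treated as a user-provided initial value.
    return "user"
```

A plain `initval=None` default could not express this, because "not passed" and "passed as `None`" would look identical inside the function.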

@codecov

codecov bot commented Jul 17, 2021

Codecov Report

Merging #4867 (1b1c7b6) into main (d2bf35e) will increase coverage by 0.05%.
The diff coverage is 98.36%.

❗ Current head 1b1c7b6 differs from pull request most recent head 2743be6. Consider uploading reports for the commit 2743be6 to get more accurate results

@@            Coverage Diff             @@
##             main    #4867      +/-   ##
==========================================
+ Coverage   73.12%   73.17%   +0.05%     
==========================================
  Files          86       86              
  Lines       13856    13892      +36     
==========================================
+ Hits        10132    10166      +34     
- Misses       3724     3726       +2     
Impacted Files Coverage Δ
pymc3/model.py 83.49% <92.30%> (+0.02%) ⬆️
pymc3/distributions/continuous.py 96.30% <100.00%> (+0.03%) ⬆️
pymc3/distributions/distribution.py 82.63% <100.00%> (+2.20%) ⬆️
pymc3/util.py 75.90% <100.00%> (+2.39%) ⬆️
pymc3/parallel_sampling.py 87.20% <0.00%> (-1.02%) ⬇️

Contributor

@brandonwillard brandonwillard left a comment

I understand that this is a WIP, but what's the basic idea (e.g. what is the mechanism for setting default initial values, how will it work, etc.)?

@michaelosthege
Member Author

I understand that this is a WIP, but what's the basic idea (e.g. what is the mechanism for setting default initial values, how will it work, etc.)?

I'm close to pushing another two commits. I'm primarily adding a lot of testing to nail down what even is the expected behavior.
The mechanism I'm implementing uses UNSET and None to distinguish two initval behaviors:

  • UNSET → The distribution may implement a default initval.
  • None → This means "use a random draw"

My intention there is to have the implementation of choosing a default initval in the distributions (where we had it before), while allowing users to disable that too.

@michaelosthege michaelosthege force-pushed the robustify-start-values branch from b5ba6fd to 79e1911 Compare July 17, 2021 19:59
@michaelosthege michaelosthege force-pushed the robustify-start-values branch from 79e1911 to 8d15f10 Compare July 17, 2021 20:18
@michaelosthege michaelosthege force-pushed the robustify-start-values branch 2 times, most recently from 0ccd7be to 32dd28f Compare July 18, 2021 14:00
Member

@ricardoV94 ricardoV94 left a comment

It seems reasonable, but I am not sure we need a separate helper function just to choose between None, UNSET, and something else. Also, I don't really see the need for type checking; it seems things would fail pretty quickly without it (e.g. with the tag.test_value assignment), but that's just a personal opinion.

@ricardoV94
Member

ricardoV94 commented Jul 18, 2021

What about having a new distribution classmethod that takes the RV inputs and returns an initval?

Then during the super __new__ / dist we just call it whenever the initval is set to UNSET. I think this would be safer than policing future PRs in order to respect the API by calling select_initval

The default would obviously be None

class Distribution:

    ...

    @classmethod
    def initval(cls, *args):
        return None


class Normal(Continuous):

    ...

    @classmethod
    def initval(cls, mu, sigma):
        return mu

Edit: Or a dispatched method like logp and logcdf if the initval is not needed at creation time
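The proposal can be sketched as follows, as a rough illustration rather than the PR's code. The hook is named `pick_initval` here, following the PR description; the `UNSET` object and the dictionary returned by `dist` are hypothetical stand-ins for the real sentinel and random variable.

```python
UNSET = object()  # hypothetical sentinel for "initval kwarg not passed"


class Distribution:
    @classmethod
    def pick_initval(cls, *params):
        # Base default: None, i.e. fall back to a random draw.
        return None

    @classmethod
    def dist(cls, *params, initval=UNSET):
        # The base class calls the hook only when the user passed nothing.
        if initval is UNSET:
            initval = cls.pick_initval(*params)
        # A plain dict stands in for the random variable in this sketch.
        return {"dist": cls.__name__, "params": params, "initval": initval}


class Normal(Distribution):
    @classmethod
    def pick_initval(cls, mu, sigma):
        # Distribution-specific default: start at the mean.
        return mu
```

With this layout, `Normal.dist(1.5, 2.0)` would carry `1.5` as its initval, while `Normal.dist(1.5, 2.0, initval=None)` would keep `None` and bypass the hook entirely.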

@michaelosthege
Member Author

[...] I am not sure we need a separate helper function just to choose between None, UNSET, and something else.

Even though it's in the git history like that, I didn't write the helper function first: I ran into a few subtle traps when implementing it without the helper function. For example, initval == UNSET is a bad idea because of elementwise array comparisons. Also, one often can't use the default, because it's symbolic. Asking future devs to do that correctly every time felt nasty.
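The elementwise-comparison trap can be reproduced without any array library. In the sketch below, `ElementwiseArray` is a hypothetical stand-in that mimics how array types such as NumPy arrays compare element by element and refuse to be reduced to a single truth value; the helper names are also made up for illustration.

```python
class ElementwiseArray:
    """Stand-in for an array type whose == compares element by element."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        # Returns another array of booleans instead of a single bool.
        return ElementwiseArray(v == other for v in self.values)

    def __bool__(self):
        if len(self.values) != 1:
            raise ValueError("truth value of a multi-element array is ambiguous")
        return bool(self.values[0])


UNSET = object()  # hypothetical sentinel


def is_unset_naive(initval):
    # Equality check: breaks as soon as initval is a multi-element array.
    return bool(initval == UNSET)


def is_unset(initval):
    # Identity check: never invokes the array's __eq__.
    return initval is UNSET
```

Calling `is_unset_naive(ElementwiseArray([1.0, 2.0]))` raises `ValueError`, whereas `is_unset` handles the same input without touching its comparison operators.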

@michaelosthege
Member Author

What about having a new classmethod that takes the RV inputs and returns an initval?

Then during super __new__ / dist we just call it whenever the initval is set to UNSET. I think this would be safer than policing future PRs in order to respect the API by calling select_initval

In v3 most of the initvals were computed from some kind of input transformations that were done in .dist() anyway. Moving that into a separate method could lead to quite a bit of code duplication, no?
W.r.t. the policing: That's done in __new__/dist and I wrote quite a few tests for that too.

@ricardoV94
Member

ricardoV94 commented Jul 18, 2021

In v3 most of the initvals were computed from some kind of input transformations that were done in .dist() anyway. Moving that into a separate method could lead to quite a bit of code duplication, no?

I think that was more out of habit than due to a conscious decision to do it that way. One advantage of having a separate method is for instance that the inputs will already be symbolic. This avoids bugs when the initval logic was based on pure numpy operations that could fail with symbolic inputs.

Also many times it's just a question of returning one of the parameters. I don't think it makes the code more complicated, if anything it seems like a more clean separation of responsibility

Edit: it's not intuitive at all that .dist should be responsible for implementing the initval. Having an explicit classmethod seems much more obvious. Don't you think so?

@michaelosthege
Member Author

In v3 most of the initvals were computed from some kind of input transformations that were done in .dist() anyway. Moving that into a separate method could lead to quite a bit of code duplication, no?

I think that was more out of habit than due to a conscious decision to do it that way. One advantage of having a separate method is for instance that the inputs will already be symbolic. This avoids bugs when the initval logic was based on pure numpy operations that could fail with symbolic inputs.

Also many times it's just a question of returning one of the parameters. I don't think it makes the code more complicated, if anything it seems like a more clean separation of responsibility

Okay fair enough. One thing though: Right now we don't support symbolic initvals, because the test_value needs to be numeric.

A switch to a classmethod doesn't make a difference for the UNSET / None logic, right?

@ricardoV94
Member

ricardoV94 commented Jul 18, 2021

One thing though: Right now we don't support symbolic initvals, because the test_value needs to be numeric.

Hmm. How did it work before? For example in the case where the initval is just the mu of the Normal, which itself comes from a parent distribution.

A switch to a classmethod doesn't make a difference for the UNSET / None logic, right?

No, but then you don't really need the helper function I think. You just call it in __new__, dist whenever you get an UNSET (but if you really see an advantage in the extra type checks that's fine, again it's just an aesthetic opinion)

@michaelosthege michaelosthege force-pushed the robustify-start-values branch 2 times, most recently from 59910ab to 4c09cb6 Compare July 19, 2021 21:44
Member

@ricardoV94 ricardoV94 left a comment

Only two minor questions / suggestions:

  1. I would rename make_initval to [choose|pick]_initval or simply initval.
  2. Perhaps it would make more sense to move the new calls inside model.set_initval? Since that already handles the prior-sampling case (It seemed to me they aren't needed before). To achieve that we can simply dispatch the make_initval to the underlying RV, just like logp / logcdf does.

@michaelosthege michaelosthege marked this pull request as ready for review July 23, 2021 20:43
@michaelosthege
Member Author

@ricardoV94 what's the matter with this test? https://github.com/pymc-devs/pymc3/pull/4867/checks?check_run_id=3150614327#step:7:448

I restarted it already, but it's still flaky??

@ricardoV94
Member

I restarted it already, but it's still flaky??

Yeah, it was not tweaked, but let's just remove it.

@michaelosthege michaelosthege changed the title Implement a mechanism to set default initvals Mechanism for setting default initvals Jul 24, 2021
@michaelosthege
Member Author

How are we feeling about this?

@ricardoV94
Member

ricardoV94 commented Jul 28, 2021

How are we feeling about this?

I think it would be cleaner to call pick_initval inside the model.set_initval function, so that all the initval logic is found in one place.

I am not sure what the new test_val check achieves.

Otherwise LGTM

@michaelosthege
Member Author

michaelosthege commented Jul 28, 2021

I think it would be cleaner to call pick_initval inside the model.set_initval function, so that all the initval logic is found in one place.

Oh sorry, I forgot to mention: I tried this, but Model.register_rv and Model.set_initval are fed a TensorVariable and therefore have no handle on the corresponding Distribution class, and hence none on the pick_initval classmethod.

I am not sure what the new test_val check achieves.

The intent is to assist migration by pointing out distributions that still follow the old patterns.
But we can take those warnings out if you like.

@ricardoV94
Member

Oh sorry, I forgot to mention: I tried this, but Model.register_rv and Model.set_initval are fed a TensorVariable and therefore have no handle on the corresponding Distribution class, and hence none on the pick_initval classmethod.

We can get the class pick_initval via the owner.op of the TensorVariable, similar to how we get the logp / logcdf
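One way to picture the dispatching alternative is with `functools.singledispatch`, keyed on the type of the Op that produced the variable. Everything below is a hypothetical stand-in: `NormalOp`, `UniformOp`, `make_variable` and `default_initval` are illustrative names, and the `SimpleNamespace` objects only imitate the `variable.owner.op` / `variable.owner.inputs` layout; this is not the dispatch code used for logp / logcdf.

```python
from functools import singledispatch
from types import SimpleNamespace


class NormalOp:
    """Hypothetical stand-in for the Op behind a Normal random variable."""


class UniformOp:
    """Hypothetical Op for which no default initval is registered."""


@singledispatch
def pick_initval(op, *params):
    # No implementation registered for this Op type: use a random draw.
    return None


@pick_initval.register(NormalOp)
def _pick_initval_normal(op, mu, sigma):
    # Default for the Normal stand-in: start at the mean.
    return mu


def make_variable(op, *params):
    # Imitates a variable whose owner node records the Op and its inputs.
    return SimpleNamespace(owner=SimpleNamespace(op=op, inputs=params))


def default_initval(variable):
    # Only the variable is needed; the Distribution class is never consulted.
    node = variable.owner
    return pick_initval(node.op, *node.inputs)
```

The point of the sketch is that a function receiving only the variable can still reach a distribution-specific default, at the cost of tying the initval logic to the Op type.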

@michaelosthege michaelosthege force-pushed the robustify-start-values branch 2 times, most recently from 7099181 to c9e460f Compare August 7, 2021 11:48
@michaelosthege
Member Author

I rebased this and reconsidered the dispatching approach, but concluded that it's not worth the additional complexity. While logp and logcdf are related to the RV, the initval just concerns the PyMC3 model, so I don't think we should dispatch it onto the Aesara Op.

Please take a final look so we can (squash?) merge and get this over with.

@michaelosthege michaelosthege requested a review from twiecki August 7, 2021 12:30
@ricardoV94
Member

ricardoV94 commented Aug 7, 2021

How does your implementation cope with hierarchical models? If it's not based on the tag.test_value trick, the initvals will be symbolic expressions.

with pm.Model() as m:
  x = pm.Normal("x", 0, 1)
  y = pm.Normal("y", x, 1)

Assuming the pick_initval of pm.Normal is just mu, how would the y initval be calculated? The only way I see it (without relying on tag.test_value) is to compile the expression and evaluate it, similarly to how it's done in model.set_initval for forward samples.
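A toy illustration of the concern, with no Aesara involved: if the default for `y` is "whatever `x` is", that default is an expression in the parent and only becomes a number once the parent has been evaluated. In the sketch below a string is a hypothetical stand-in for a symbolic reference to another variable, and `resolve_initvals` is an invented helper; a real implementation would compile and evaluate the symbolic graph instead. There is no cycle detection.

```python
def resolve_initvals(symbolic):
    """Turn a mapping of name -> number or parent name into numeric initvals."""
    resolved = {}

    def evaluate(name):
        if name not in resolved:
            expr = symbolic[name]
            if isinstance(expr, str):
                # Symbolic reference: the parent must be evaluated first.
                resolved[name] = evaluate(expr)
            else:
                resolved[name] = float(expr)
        return resolved[name]

    for name in symbolic:
        evaluate(name)
    return resolved


# Mirrors the model above: x defaults to its mu (0), y defaults to its mu (x).
initvals = resolve_initvals({"x": 0.0, "y": "x"})
```

Here `initvals["y"]` ends up equal to `initvals["x"]`, but only because the reference to `x` was evaluated rather than stored as-is.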

@ricardoV94
Member

By the way @kc611 is exploring the alternative solution with dispatching. So he might have some ideas as well.

@michaelosthege
Member Author

How does your implementation cope with hierarchical models? If it's not based on the tag.test_value trick, the initvals will be symbolic expressions.

with pm.Model() as m:
  x = pm.Normal("x", 0, 1)
  y = pm.Normal("y", x, 1)

Assuming the pick_initval of pm.Normal is just mu, how would the y initval be calculated? The only way I see it (without relying on tag.test_value) is to compile the expression and evaluate it, similarly to how it's done in model.set_initval for forward samples.

I'm just forwarding them to model.set_initval, so whatever it did with symbolic initvals before this PR still applies the same way.
I'm not on the laptop right now, but if I remember correctly there already was code to compile and evaluate symbolic initvals.

@ricardoV94
Member

ricardoV94 commented Aug 7, 2021

I'm just forwarding them to model.set_initval, so whatever it did with symbolic initvals before this PR still applies the same way.
I'm not on the laptop right now, but if I remember correctly there already was code to compile and evaluate symbolic initvals.

Looking at model.set_initval it seems that when given a non-None initval it only checks if it should be transformed. It wouldn't do anything for the Normal for example

https://github.com/pymc-devs/pymc3/blob/c9e460f29396130ff625bf84a99943c935ac5be9/pymc3/model.py#L951

And even when transforming initvals, it seems to expect them to be a Constant. Nvm, that's for the previous initvals.

@michaelosthege
Member Author

Looking at model.set_initval it seems that when given a non-None initval it only checks if it should be transformed. It wouldn't do anything for the Normal for example

Improving the model.set_initval method is out of scope of this PR.
I edited just one line in there to stop it from accessing test_values, but that doesn't change its behavior for scenarios where initval is a default OR user-provided TensorVariable.

Please open a new Issue if it misbehaves.

@ricardoV94
Member

Improving the model.set_initval method is out of scope of this PR.

I don't see how that can be out-of-scope for a PR that is refactoring how initvals work.

Please open a new Issue if it misbehaves.

I am pretty sure it will misbehave in the example I just gave.

If the new implementation is not designed to handle a simple realistic case, how can it be evaluated?

@ricardoV94
Member

ricardoV94 commented Aug 7, 2021

import pymc3 as pm
import aesara.tensor as at
import aesara.tensor.random as atr

class NormalWithInitval(pm.distributions.Continuous):
    """
    A distribution that defaults the initial value.
    """

    rv_op = atr.normal

    @classmethod
    def dist(cls, mu=0, sigma=1, **kwargs):
        mu = at.as_tensor_variable(pm.floatX(mu))
        sigma = at.as_tensor_variable(pm.floatX(sigma))
        return super().dist([mu, sigma], **kwargs)

    @classmethod
    def pick_initval(cls, mu, sigma, **kwargs):
        return mu

with pm.Model() as m:
    x = NormalWithInitval('x', 0, 1)
    y = NormalWithInitval('y', x, 1)  # raises TypeError

@michaelosthege
Member Author

Your example also raises the TypeError on main. I have therefore opened #4911.

It was also not too difficult to fix, even though that was not the goal of this PR.

@michaelosthege michaelosthege changed the title Mechanism for setting default initvals Restore support for symbolic and default initvals Aug 7, 2021
@michaelosthege michaelosthege force-pushed the robustify-start-values branch from b7adcff to db5a804 Compare August 7, 2021 19:17
@michaelosthege michaelosthege force-pushed the robustify-start-values branch from db5a804 to 335bc94 Compare August 8, 2021 09:58
@michaelosthege michaelosthege changed the title Restore support for symbolic and default initvals Restore support for default initvals Aug 8, 2021
@michaelosthege michaelosthege force-pushed the robustify-start-values branch 3 times, most recently from 062edf2 to 1b1c7b6 Compare August 8, 2021 16:16
This changes the initval default on Distribution.__new__ and Distribution.dist to UNSET.
It allows for implementing distribution-specific initial values similar to how it was done in pymc3 <4.
To simplify code that actually picks the initvals, a helper function was added.

Closes pymc-devs#4911 by enabling Model.set_initval to take symbolic initial values.
@michaelosthege michaelosthege force-pushed the robustify-start-values branch from 1b1c7b6 to 2743be6 Compare August 10, 2021 15:11
@michaelosthege
Member Author

superseded by #4983

4 participants