Skip to content

Change initial_point default strategy from "prior" to "moment" #5140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.18
- aesara>=2.2.6
- aesara=2.3.2
- arviz>=0.11.2
- cachetools
- cloudpickle
Expand Down
6 changes: 3 additions & 3 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,16 +717,16 @@ def dist(
def get_moment(rv, size, mu, sigma, lower, upper):
mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper)
moment = at.switch(
at.isinf(lower),
at.eq(lower, -np.inf),
at.switch(
at.isinf(upper),
at.eq(upper, np.inf),
# lower = -inf, upper = inf
mu,
# lower = -inf, upper = x
upper - 1,
),
at.switch(
at.isinf(upper),
at.eq(upper, np.inf),
# lower = x, upper = inf
lower + 1,
# lower = x, upper = x
Expand Down
19 changes: 16 additions & 3 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings

from typing import Callable, Dict, List, Optional, Sequence, Set, Union

Expand Down Expand Up @@ -131,7 +132,7 @@ def make_initial_point_fn(
model,
overrides: Optional[StartDict] = None,
jitter_rvs: Optional[Set[TensorVariable]] = None,
default_strategy: str = "prior",
default_strategy: str = "moment",
return_transformed: bool = True,
) -> Callable:
"""Create seeded function that computes initial values for all free model variables.
Expand Down Expand Up @@ -226,7 +227,7 @@ def make_initial_point_expression(
rvs_to_values: Dict[TensorVariable, TensorVariable],
initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
jitter_rvs: Set[TensorVariable] = None,
default_strategy: str = "prior",
default_strategy: str = "moment",
return_transformed: bool = False,
) -> List[TensorVariable]:
"""Creates the tensor variables that need to be evaluated to obtain an initial point.
Expand Down Expand Up @@ -269,7 +270,19 @@ def make_initial_point_expression(

if isinstance(strategy, str):
if strategy == "moment":
value = get_moment(variable)
try:
value = get_moment(variable)
except NotImplementedError:
warnings.warn(
f"Moment not defined for variable {variable} of type "
f"{variable.owner.op.__class__.__name__}, defaulting to "
f"a draw from the prior. This can lead to difficulties "
f"during tuning. You can manually define an initval or "
f"implement a get_moment dispatched function for this "
f"distribution.",
UserWarning,
)
value = variable
elif strategy == "prior":
value = variable
else:
Expand Down
117 changes: 93 additions & 24 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import pymc as pm

from pymc import Simulator
from pymc.distributions import (
AsymmetricLaplace,
Bernoulli,
Expand Down Expand Up @@ -50,6 +49,7 @@
Poisson,
PolyaGamma,
Rice,
Simulator,
SkewNormal,
StudentT,
Triangular,
Expand All @@ -62,13 +62,74 @@
ZeroInflatedNegativeBinomial,
ZeroInflatedPoisson,
)
from pymc.distributions.distribution import get_moment
from pymc.distributions.distribution import _get_moment, get_moment
from pymc.distributions.logprob import logpt
from pymc.distributions.multivariate import MvNormal
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
from pymc.initial_point import make_initial_point_fn
from pymc.model import Model


def test_all_distributions_have_moments():
import pymc.distributions as dist_module

from pymc.distributions.distribution import DistributionMeta

dists = (getattr(dist_module, dist) for dist in dist_module.__all__)
dists = (dist for dist in dists if isinstance(dist, DistributionMeta))
missing_moments = {
dist for dist in dists if type(getattr(dist, "rv_op", None)) not in _get_moment.registry
}

# Ignore super classes
missing_moments -= {
dist_module.Distribution,
dist_module.Discrete,
dist_module.Continuous,
dist_module.NoDistribution,
dist_module.DensityDist,
dist_module.simulator.Simulator,
}

# Distributions that have not been refactored for V4 yet
not_implemented = {
dist_module.multivariate.LKJCorr,
dist_module.mixture.Mixture,
dist_module.mixture.MixtureSameFamily,
dist_module.mixture.NormalMixture,
dist_module.timeseries.AR,
dist_module.timeseries.AR1,
dist_module.timeseries.GARCH11,
dist_module.timeseries.GaussianRandomWalk,
dist_module.timeseries.MvGaussianRandomWalk,
dist_module.timeseries.MvStudentTRandomWalk,
}

# Distributions that have been refactored but don't yet have moments
not_implemented |= {
dist_module.discrete.DiscreteWeibull,
dist_module.multivariate.CAR,
dist_module.multivariate.DirichletMultinomial,
dist_module.multivariate.KroneckerNormal,
dist_module.multivariate.Wishart,
}

unexpected_implemented = not_implemented - missing_moments
if unexpected_implemented:
raise Exception(
f"Distributions {unexpected_implemented} have a `get_moment` implemented. "
"This test must be updated to expect this."
)

unexpected_not_implemented = missing_moments - not_implemented
if unexpected_not_implemented:
raise NotImplementedError(
f"Unexpected by this test, distributions {unexpected_not_implemented} do "
"not have a `get_moment` implementation. Either add a moment or filter "
"these distributions in this test."
)


def test_rv_size_is_none():
rv = Normal.dist(0, 1, size=None)
assert rv_size_is_none(rv.owner.inputs[1])
Expand All @@ -85,20 +146,25 @@ def test_rv_size_is_none():
assert not rv_size_is_none(rv.owner.inputs[1])


def assert_moment_is_expected(model, expected):
def assert_moment_is_expected(model, expected, check_finite_logp=True):
fn = make_initial_point_fn(
model=model,
return_transformed=False,
default_strategy="moment",
)
result = fn(0)["x"]
moment = fn(0)["x"]
expected = np.asarray(expected)
try:
random_draw = model["x"].eval()
except NotImplementedError:
random_draw = result
assert result.shape == expected.shape == random_draw.shape
assert np.allclose(result, expected)
random_draw = moment

assert moment.shape == expected.shape == random_draw.shape
assert np.allclose(moment, expected)

if check_finite_logp:
logp_moment = logpt(model["x"], at.constant(moment), transformed=False).eval()
assert np.isfinite(logp_moment)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -189,14 +255,13 @@ def test_halfstudentt_moment(nu, sigma, size, expected):
assert_moment_is_expected(model, expected)


@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
@pytest.mark.parametrize(
"mu, sigma, lower, upper, size, expected",
[
(0.9, 1, -1, 1, None, 0),
(0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)),
(0.9, 1, -5, 5, None, 0),
(1, np.ones(5), -10, np.inf, None, np.full(5, -9)),
(np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)),
(1, np.ones(5), -10, np.inf, None, np.full((2, 5), -9)),
(1, 1, [-np.inf, -np.inf, -np.inf], 10, None, np.full(3, 9)),
],
)
def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected):
Expand Down Expand Up @@ -371,11 +436,11 @@ def test_lognormal_moment(mu, sigma, size, expected):
[
(1, None, 1),
(1, 5, np.ones(5)),
(np.arange(5), None, np.arange(5)),
(np.arange(1, 5), None, np.arange(1, 5)),
(
np.arange(5),
(2, 5),
np.full((2, 5), np.arange(5)),
np.arange(1, 5),
(2, 4),
np.full((2, 4), np.arange(1, 5)),
),
],
)
Expand Down Expand Up @@ -617,11 +682,11 @@ def test_logistic_moment(mu, s, size, expected):
@pytest.mark.parametrize(
"mu, nu, sigma, size, expected",
[
(1, 1, None, None, 2),
(1, 1, 1, None, 2),
(1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)),
(1, 1, None, 5, np.full(5, 2)),
(1, np.arange(1, 6), None, None, np.arange(2, 7)),
(1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))),
(1, 1, 3, 5, np.full(5, 2)),
(1, np.arange(1, 6), 5, None, np.arange(2, 7)),
(1, np.arange(1, 6), 1, (2, 5), np.full((2, 5), np.arange(2, 7))),
],
)
def test_exgaussian_moment(mu, nu, sigma, size, expected):
Expand Down Expand Up @@ -861,8 +926,10 @@ def test_interpolated_moment(x_points, pdf_points, size, expected):
)
def test_mv_normal_moment(mu, cov, size, expected):
with Model() as model:
MvNormal("x", mu=mu, cov=cov, size=size)
assert_moment_is_expected(model, expected)
x = MvNormal("x", mu=mu, cov=cov, size=size)

# MvNormal logp is only impemented for up to 2D variables
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -898,8 +965,10 @@ def test_moyal_moment(mu, sigma, size, expected):
)
def test_mvstudentt_moment(nu, mu, cov, size, expected):
with Model() as model:
MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size)
assert_moment_is_expected(model, expected)
x = MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size)

# MvStudentT logp is only impemented for up to 2D variables
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)


def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
Expand Down Expand Up @@ -1035,7 +1104,7 @@ def test_density_dist_default_moment_univariate(get_moment, size, expected):
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
with Model() as model:
DensityDist("x", get_moment=get_moment, size=size)
assert_moment_is_expected(model, expected)
assert_moment_is_expected(model, expected, check_finite_logp=False)


@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
Expand Down
26 changes: 26 additions & 0 deletions pymc/tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import pytest

from aesara.tensor.random.op import RandomVariable

import pymc as pm

from pymc.distributions.distribution import get_moment
Expand Down Expand Up @@ -255,6 +257,30 @@ def test_moment_from_dims(self, rv_cls):
assert tuple(get_moment(rv).shape.eval()) == (4, 3)
pass

def test_moment_not_implemented_fallback(self):
class MyNormalRV(RandomVariable):
name = "my_normal"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"

@classmethod
def rng_fn(cls, rng, mu, sigma, size):
return np.pi

class MyNormalDistribution(pm.Normal):
rv_op = MyNormalRV()

with pm.Model() as m:
x = MyNormalDistribution("x", 0, 1, initval="moment")

with pytest.warns(
UserWarning, match="Moment not defined for variable x of type MyNormalRV"
):
res = m.recompute_initial_point()

assert np.isclose(res["x"], np.pi)


def test_pickling_issue_5090():
with pm.Model() as model:
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See that file for comments about the need/usage of each dependency.

aeppl==0.0.18
aesara>=2.2.6
aesara==2.3.2
arviz>=0.11.4
cachetools>=4.2.1
cloudpickle
Expand Down
Loading