
Fix dist resizes #5719

Merged · 5 commits · Apr 19, 2022
20 changes: 10 additions & 10 deletions pymc/distributions/censored.py
@@ -18,6 +18,7 @@
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

from pymc.aesaraf import change_rv_size
from pymc.distributions.distribution import SymbolicDistribution, _moment
from pymc.util import check_dist_not_registered

@@ -74,10 +75,13 @@ def dist(cls, dist, lower, upper, **kwargs):

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
if lower is None:
lower = at.constant(-np.inf)
if upper is None:
upper = at.constant(np.inf)

lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower)
upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper)

# When size is not specified, dist may have to be broadcasted according to lower/upper
dist_shape = size if size is not None else at.broadcast_shape(dist, lower, upper)
dist = change_rv_size(dist, dist_shape)

# Censoring is achieved by clipping the base distribution between lower and upper
rv_out = at.clip(dist, lower, upper)
@@ -88,8 +92,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
rv_out.tag.lower = lower
rv_out.tag.upper = upper

if size is not None:
rv_out = cls.change_size(rv_out, size)
if rngs is not None:
rv_out = cls.change_rngs(rv_out, rngs)

@@ -101,12 +103,10 @@ def ndim_supp(cls, *dist_params):

@classmethod
def change_size(cls, rv, new_size, expand=False):
dist_node = rv.tag.dist.owner
dist = rv.tag.dist
lower = rv.tag.lower
upper = rv.tag.upper
rng, old_size, dtype, *dist_params = dist_node.inputs
new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
new_dist = change_rv_size(dist, new_size, expand=expand)
return cls.rv_op(new_dist, lower, upper)

@classmethod
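To illustrate the change, here is a small usage sketch (not part of the diff) of how the base distribution now broadcasts against `lower`/`upper` before clipping; it mirrors `test_dist_broadcasted_by_lower_upper` added in `pymc/tests/test_distributions.py` below.

```python
import numpy as np
import pymc as pm

# The scalar base Normal is resized to broadcast with lower, so the clipped
# variable has shape (2,); x.owner.inputs[0] is the resized base dist.
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=None)
print(tuple(x.owner.inputs[0].shape.eval()))  # (2,)

# lower (2,) and upper (4, 2) broadcast to (4, 2), and the base dist follows.
y = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=np.zeros((4, 2)))
print(tuple(y.owner.inputs[0].shape.eval()))  # (4, 2)
```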
31 changes: 16 additions & 15 deletions pymc/distributions/multivariate.py
@@ -32,7 +32,7 @@
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.random.utils import broadcast_params, normalize_size_param
from aesara.tensor.slinalg import Cholesky
from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
from aesara.tensor.slinalg import solve_upper_triangular as solve_upper
@@ -1134,6 +1134,19 @@ def make_node(self, rng, size, dtype, n, eta, D):

D = at.as_tensor_variable(D)

# We resize the sd_dist `D` automatically so that it has (size x n) independent
# draws which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the
# random and logp methods equivalent, as the latter also assumes a unique value
# for each diagonal element.
# Since `eta` and `n` are forced to be scalars we don't need to worry about
# implied batched dimensions for the time being.
size = normalize_size_param(size)
if D.owner.op.ndim_supp == 0:
D = change_rv_size(D, at.concatenate((size, (n,))))
else:
# The support shape must be `n` but we have no way of controlling it
D = change_rv_size(D, size)

return super().make_node(rng, size, dtype, n, eta, D)

def _infer_shape(self, size, dist_params, param_shapes=None):
@@ -1179,7 +1192,7 @@ def __new__(cls, name, eta, n, sd_dist, **kwargs):
return super().__new__(cls, name, eta, n, sd_dist, **kwargs)

@classmethod
def dist(cls, eta, n, sd_dist, size=None, **kwargs):
def dist(cls, eta, n, sd_dist, **kwargs):
eta = at.as_tensor_variable(floatX(eta))
n = at.as_tensor_variable(intX(n))

@@ -1191,18 +1204,6 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
):
raise TypeError("sd_dist must be a scalar or vector distribution variable")

# We resize the sd_dist automatically so that it has (size x n) independent draws
# which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random
# and logp methods equivalent, as the latter also assumes a unique value for each
# diagonal element.
# Since `eta` and `n` are forced to be scalars we don't need to worry about
# implied batched dimensions for the time being.
if sd_dist.owner.op.ndim_supp == 0:
sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
else:
# The support shape must be `n` but we have no way of controlling it
sd_dist = change_rv_size(sd_dist, to_tuple(size))

# sd_dist is part of the generative graph, but should be completely ignored
# by the logp graph, since the LKJ logp explicitly includes these terms.
# Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
@@ -1211,7 +1212,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
# sd_dist prior components from the logp expression.
sd_dist.tag.ignore_logprob = True

return super().dist([n, eta, sd_dist], size=size, **kwargs)
return super().dist([n, eta, sd_dist], **kwargs)

def moment(rv, size, n, eta, sd_dists):
diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
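As a usage sketch (not part of the diff): the automatic `sd_dist` resize now happens in `make_node`, so a scalar `sd_dist` ends up with one independent draw per diagonal element regardless of whether `size` or `shape` was passed. The snippet below mirrors `test_sd_dist_automatically_resized`, assuming a scalar `sd_dist` such as `pm.Exponential.dist(1)`.

```python
import pymc as pm

# The scalar sd_dist is resized internally to (size, n) = (10, 3), so rng_fn
# gets one independent standard-deviation draw per diagonal element.
x = pm.LKJCholeskyCov.dist(
    n=3, eta=1, sd_dist=pm.Exponential.dist(1), size=10, compute_corr=False
)
resized_sd_dist = x.owner.inputs[-1]
print(resized_sd_dist.eval().shape)  # (10, 3)
# The packed Cholesky output itself has support shape n * (n + 1) // 2 = 6,
# i.e. x has shape (10, 6), matching the shape=(10, 6) case in the test.
```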
29 changes: 13 additions & 16 deletions pymc/distributions/timeseries.py
@@ -19,12 +19,13 @@

from aesara import scan
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.utils import normalize_size_param

from pymc.aesaraf import change_rv_size, floatX, intX
from pymc.distributions import distribution, logprob, multivariate
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
from pymc.util import check_dist_not_registered

__all__ = [
@@ -54,6 +55,16 @@ def make_node(self, rng, size, dtype, mu, sigma, init, steps):
if not steps.ndim == 0 or not steps.dtype.startswith("int"):
raise ValueError("steps must be an integer scalar (ndim=0).")

mu = at.as_tensor_variable(mu)
sigma = at.as_tensor_variable(sigma)
init = at.as_tensor_variable(init)

# Resize init distribution
size = normalize_size_param(size)
# If not explicit, size is determined by the shapes of mu, sigma, and init
init_size = size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init)
init = change_rv_size(init, init_size)

return super().make_node(rng, size, dtype, mu, sigma, init, steps)

def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
@@ -116,7 +127,6 @@ def rng_fn(

# If size is None then the returned series should be (*size, 1+steps)
else:
init_size = (*size, 1)
dist_shape = (*size, int(steps))

innovations = rng.normal(loc=mu, scale=sigma, size=dist_shape)
@@ -161,15 +171,9 @@ def dist(
raise ValueError("Must specify steps parameter")
steps = at.as_tensor_variable(intX(steps))

shape = kwargs.get("shape", None)
if size is None and shape is None:
init_size = None
else:
init_size = to_tuple(size) if size is not None else to_tuple(shape)[:-1]

# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
if init is None:
init = Normal.dist(mu, sigma, size=init_size)
init = Normal.dist(mu, sigma)
else:
if not (
isinstance(init, at.TensorVariable)
@@ -179,13 +183,6 @@
):
raise TypeError("init must be a univariate distribution variable")

if init_size is not None:
init = change_rv_size(init, init_size)
else:
# If not explicit, size is determined by the shapes of mu, sigma, and init
bcast_shape = at.broadcast_arrays(mu, sigma, init)[0].shape
init = change_rv_size(init, bcast_shape)

# Ignores logprob of init var because that's accounted for in the logp method
init.tag.ignore_logprob = True

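For illustration (not part of the diff), the `init` distribution is now resized inside `make_node`, so it follows either the explicit `size`/`shape` or the broadcast shape of `mu`, `sigma`, and `init`. The sketch below mirrors the tests added in `pymc/tests/test_distributions_timeseries.py`.

```python
import pymc as pm

# No explicit size/shape: the batch shape comes from broadcasting mu, sigma and
# init, and the series gets steps + 1 entries along the last axis.
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
print(tuple(grw.shape.eval()))  # (2, 3, 5)

# Explicit shape with Ellipsis: the leading batch dims are fixed and the init
# dist (owner.inputs[-2]) is resized to match them.
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...))
print(tuple(grw.shape.eval()))                   # (3, 6)
print(tuple(grw.owner.inputs[-2].shape.eval()))  # (3,)
```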
26 changes: 24 additions & 2 deletions pymc/tests/test_distributions.py
@@ -3381,6 +3381,18 @@ def test_change_size(self):
new_dist = pm.Censored.change_size(base_dist, (4,), expand=True)
assert new_dist.eval().shape == (4, 3, 2)

def test_dist_broadcasted_by_lower_upper(self):
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=None)
assert tuple(x.owner.inputs[0].shape.eval()) == (2,)

x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=np.zeros((4, 2)))
assert tuple(x.owner.inputs[0].shape.eval()) == (4, 2)

x = pm.Censored.dist(
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
)
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)


class TestLKJCholeskCov:
def test_dist(self):
@@ -3425,8 +3437,18 @@ def test_no_warning_logp(self):
pm.MvNormal.dist(np.ones(3), np.eye(3)),
],
)
def test_sd_dist_automatically_resized(self, sd_dist):
x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
@pytest.mark.parametrize(
"size, shape",
[
((10,), None),
(None, (10, 6)),
(None, (10, ...)),
],
)
def test_sd_dist_automatically_resized(self, sd_dist, size, shape):
x = pm.LKJCholeskyCov.dist(
n=3, eta=1, sd_dist=sd_dist, size=size, shape=shape, compute_corr=False
)
resized_sd_dist = x.owner.inputs[-1]
assert resized_sd_dist.eval().shape == (10, 3)
# LKJCov has support shape `(n * (n+1)) // 2`
6 changes: 6 additions & 0 deletions pymc/tests/test_distributions_timeseries.py
@@ -103,6 +103,12 @@ def test_gaussian_random_walk_init_dist_shape(init):
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)


def test_shape_ellipsis():
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...))
assert tuple(grw.shape.eval()) == (3, 6)
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)


def test_gaussianrandomwalk_broadcasted_by_init_dist():
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
assert tuple(grw.shape.eval()) == (2, 3, 5)
6 changes: 3 additions & 3 deletions pymc/tests/test_sampling.py
@@ -1406,16 +1406,16 @@ def test_draw_aesara_function_kwargs(self):
assert np.all(draws == np.arange(5))


class test_step_args(SeededTest):
with pm.Model() as model:
def test_step_args():
with pm.Model(rng_seeder=1410) as model:
a = pm.Normal("a")
idata0 = pm.sample(target_accept=0.5)
idata1 = pm.sample(nuts={"target_accept": 0.5})

npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)

with pm.Model() as model:
with pm.Model(rng_seeder=1418) as model:
a = pm.Normal("a")
b = pm.Poisson("b", 1)
idata0 = pm.sample(target_accept=0.5)