Skip to content

Commit 011c68f

Browse files
ricardoV94michaelosthege
authored andcommitted
Adapt to new size kwarg interpretation
Closes #5446
1 parent b09805a commit 011c68f

7 files changed

+59
-62
lines changed

pymc/aesaraf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ def change_rv_size(
173173
tag = rv_var.tag
174174

175175
if expand:
176-
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
177-
size = rv_node.op._infer_shape(size, dist_params)
178-
new_size = tuple(new_size) + tuple(size)
176+
old_shape = tuple(rv_node.op._infer_shape(size, dist_params))
177+
old_size = old_shape[: len(old_shape) - rv_node.op.ndim_supp]
178+
new_size = tuple(new_size) + tuple(old_size)
179179

180180
# Make sure the new size is a tensor. This dtype-aware conversion helps
181181
# to not unnecessarily pick up a `Cast` in some cases (see #4652).

pymc/distributions/multivariate.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
247247
def get_moment(rv, size, mu, cov):
248248
moment = mu
249249
if not rv_size_is_none(size):
250-
moment_size = at.concatenate([size, mu.shape])
250+
moment_size = at.concatenate([size, [mu.shape[-1]]])
251251
moment = at.full(moment_size, mu)
252252
return moment
253253

@@ -301,18 +301,13 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
301301
@classmethod
302302
def rng_fn(cls, rng, nu, mu, cov, size):
303303

304-
# Don't reassign broadcasted cov, since MvNormal expects two dimensional cov only.
305-
mu, _ = broadcast_params([mu, cov], cls.ndims_params[1:])
306-
307-
chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)
308-
# Add distribution shape to chi2 samples
309-
chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(mu.shape))
310-
311304
mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
312305

313-
size = tuple(size or ())
306+
# Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
307+
chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
308+
314309
if size:
315-
mu = np.broadcast_to(mu, size + mu.shape)
310+
mu = np.broadcast_to(mu, size + (mu.shape[-1],))
316311

317312
return (mv_samples / chi2_samples) + mu
318313

@@ -379,7 +374,7 @@ def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True
379374
def get_moment(rv, size, nu, mu, cov):
380375
moment = mu
381376
if not rv_size_is_none(size):
382-
moment_size = at.concatenate([size, moment.shape])
377+
moment_size = at.concatenate([size, [mu.shape[-1]]])
383378
moment = at.full(moment_size, moment)
384379
return moment
385380

@@ -453,7 +448,7 @@ def get_moment(rv, size, a):
453448
norm_constant = at.sum(a, axis=-1)[..., None]
454449
moment = a / norm_constant
455450
if not rv_size_is_none(size):
456-
moment = at.full(at.concatenate([size, a.shape]), moment)
451+
moment = at.full(at.concatenate([size, [a.shape[-1]]]), moment)
457452
return moment
458453

459454
def logp(value, a):
@@ -497,8 +492,8 @@ def rng_fn(cls, rng, n, p, size):
497492
size = tuple(size or ())
498493

499494
if size:
500-
n = np.broadcast_to(n, size + n.shape)
501-
p = np.broadcast_to(p, size + p.shape)
495+
n = np.broadcast_to(n, size)
496+
p = np.broadcast_to(p, size + (p.shape[-1],))
502497

503498
res = np.empty(p.shape)
504499
for idx in np.ndindex(p.shape[:-1]):
@@ -571,7 +566,7 @@ def get_moment(rv, size, n, p):
571566
inc_bool_arr = at.abs_(diff) > 0
572567
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
573568
if not rv_size_is_none(size):
574-
output_size = at.concatenate([size, p.shape])
569+
output_size = at.concatenate([size, [p.shape[-1]]])
575570
mode = at.full(output_size, mode)
576571
return mode
577572

@@ -623,8 +618,8 @@ def rng_fn(cls, rng, n, a, size):
623618
size = tuple(size or ())
624619

625620
if size:
626-
n = np.broadcast_to(n, size + n.shape)
627-
a = np.broadcast_to(a, size + a.shape)
621+
n = np.broadcast_to(n, size)
622+
a = np.broadcast_to(a, size + (a.shape[-1],))
628623

629624
res = np.empty(a.shape)
630625
for idx in np.ndindex(a.shape[:-1]):
@@ -688,10 +683,14 @@ def get_moment(rv, size, n, a):
688683
diff = n - at.sum(mode, axis=-1, keepdims=True)
689684
inc_bool_arr = at.abs_(diff) > 0
690685
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
691-
# Reshape mode according to base shape (ignoring size)
692-
mode = at.reshape(mode, rv.shape[size.size :])
686+
687+
# Reshape mode according to dimensions implied by the parameters
688+
# This can include axes of length 1
689+
_, p_bcast = broadcast_params([n, p], ndims_params=[0, 1])
690+
mode = at.reshape(mode, p_bcast.shape)
691+
693692
if not rv_size_is_none(size):
694-
output_size = at.concatenate([size, mode.shape])
693+
output_size = at.concatenate([size, [p.shape[-1]]])
695694
mode = at.full(output_size, mode)
696695
return mode
697696

@@ -2070,7 +2069,7 @@ def make_node(self, rng, size, dtype, mu, W, alpha, tau):
20702069
return super().make_node(rng, size, dtype, mu, W, alpha, tau)
20712070

20722071
def _infer_shape(self, size, dist_params, param_shapes=None):
2073-
shape = tuple(size) + tuple(dist_params[0].shape)
2072+
shape = tuple(size) + (dist_params[0].shape[-1],)
20742073
return shape
20752074

20762075
@classmethod
@@ -2105,7 +2104,7 @@ def rng_fn(cls, rng: np.random.RandomState, mu, W, alpha, tau, size):
21052104

21062105
size = tuple(size or ())
21072106
if size:
2108-
mu = np.broadcast_to(mu, size + mu.shape)
2107+
mu = np.broadcast_to(mu, size + (mu.shape[-1],))
21092108
z = rng.normal(size=mu.shape)
21102109
samples = np.empty(z.shape)
21112110
for idx in np.ndindex(mu.shape[:-1]):
@@ -2165,11 +2164,7 @@ def dist(cls, mu, W, alpha, tau, *args, **kwargs):
21652164
return super().dist([mu, W, alpha, tau], **kwargs)
21662165

21672166
def get_moment(rv, size, mu, W, alpha, tau):
2168-
moment = mu
2169-
if not rv_size_is_none(size):
2170-
moment_size = at.concatenate([size, moment.shape])
2171-
moment = at.full(moment_size, mu)
2172-
return moment
2167+
return at.full_like(rv, mu)
21732168

21742169
def logp(value, mu, W, alpha, tau):
21752170
"""

pymc/distributions/shape_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727

2828
from aesara.graph.basic import Constant, Variable
29+
from aesara.graph.op import Op
2930
from aesara.tensor.var import TensorVariable
3031
from typing_extensions import TypeAlias
3132

@@ -614,10 +615,10 @@ def find_size(shape=None, size=None, ndim_supp=None):
614615

615616

616617
def maybe_resize(
617-
rv_out,
618-
rv_op,
618+
rv_out: TensorVariable,
619+
rv_op: Op,
619620
dist_params,
620-
ndim_expected,
621+
ndim_expected: int,
621622
ndim_batch,
622623
ndim_supp,
623624
shape,

pymc/tests/test_distributions.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy.random as nr
2222

2323
from aeppl.logprob import ParameterValueError
24+
from aesara.tensor.random.utils import broadcast_params
2425

2526
from pymc.distributions.continuous import get_tau_sigma
2627
from pymc.util import UNSET
@@ -2130,9 +2131,10 @@ def test_dirichlet_invalid(self):
21302131
(np.abs(np.random.randn(2, 2, 4)) + 1),
21312132
],
21322133
)
2133-
@pytest.mark.parametrize("size", [2, (1, 2), (2, 4, 3)])
2134-
def test_dirichlet_vectorized(self, a, size):
2134+
@pytest.mark.parametrize("extra_size", [(2,), (1, 2), (2, 4, 3)])
2135+
def test_dirichlet_vectorized(self, a, extra_size):
21352136
a = floatX(np.array(a))
2137+
size = extra_size + a.shape[:-1]
21362138

21372139
dir = pm.Dirichlet.dist(a=a, size=size)
21382140
vals = dir.eval()
@@ -2200,12 +2202,15 @@ def test_multinomial_p_not_normalized_symbolic(self):
22002202
(np.abs(np.random.randn(2, 2, 4))),
22012203
],
22022204
)
2203-
@pytest.mark.parametrize("size", [1, 2, (2, 3)])
2204-
def test_multinomial_vectorized(self, n, p, size):
2205+
@pytest.mark.parametrize("extra_size", [(1,), (2,), (2, 3)])
2206+
def test_multinomial_vectorized(self, n, p, extra_size):
22052207
n = intX(np.array(n))
22062208
p = floatX(np.array(p))
22072209
p /= p.sum(axis=-1, keepdims=True)
22082210

2211+
_, bcast_p = broadcast_params([n, p], ndims_params=[0, 1])
2212+
size = extra_size + bcast_p.shape[:-1]
2213+
22092214
mn = pm.Multinomial.dist(n=n, p=p, size=size)
22102215
vals = mn.eval()
22112216

@@ -2269,11 +2274,14 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
22692274
(np.abs(np.random.randn(2, 2, 4))),
22702275
],
22712276
)
2272-
@pytest.mark.parametrize("size", [1, 2, (2, 3)])
2273-
def test_dirichlet_multinomial_vectorized(self, n, a, size):
2277+
@pytest.mark.parametrize("extra_size", [(1,), (2,), (2, 3)])
2278+
def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
22742279
n = intX(np.array(n))
22752280
a = floatX(np.array(a))
22762281

2282+
_, bcast_a = broadcast_params([n, a], ndims_params=[0, 1])
2283+
size = extra_size + bcast_a.shape[:-1]
2284+
22772285
dm = pm.DirichletMultinomial.dist(n=n, a=a, size=size)
22782286
vals = dm.eval()
22792287

pymc/tests/test_distributions_moments.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
789789
),
790790
(
791791
np.array([[1, 2, 3], [5, 6, 7]]),
792-
7,
792+
(7, 2),
793793
np.apply_along_axis(
794794
lambda x: np.divide(x, np.array([6, 18])),
795795
1,
@@ -798,7 +798,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
798798
),
799799
(
800800
np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
801-
(11, 5),
801+
(11, 5, 7),
802802
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
803803
),
804804
],
@@ -940,7 +940,7 @@ def test_interpolated_moment(x_points, pdf_points, size, expected):
940940
(
941941
np.array([[3.0, 5], [1, 4]]),
942942
np.identity(2),
943-
(4, 5),
943+
(4, 5, 2),
944944
np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]),
945945
),
946946
],
@@ -965,7 +965,7 @@ def test_mv_normal_moment(mu, cov, size, expected):
965965
(np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])),
966966
(
967967
np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
968-
(4, 5),
968+
(4, 5, 2),
969969
np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
970970
),
971971
],
@@ -1008,8 +1008,8 @@ def test_moyal_moment(mu, sigma, size, expected):
10081008
(2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)),
10091009
(2, rand1d, np.eye(2), (2, 5), np.full((2, 5, 2), rand1d)),
10101010
(2, rand2d, np.eye(3), None, rand2d),
1011-
(2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
1012-
(2, rand2d, np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)),
1011+
(2, rand2d, np.eye(3), (2, 2), np.full((2, 2, 3), rand2d)),
1012+
(2, rand2d, np.eye(3), (2, 5, 2), np.full((2, 5, 2, 3), rand2d)),
10131013
],
10141014
)
10151015
def test_mvstudentt_moment(nu, mu, cov, size, expected):
@@ -1326,7 +1326,7 @@ def test_polyagamma_moment(h, z, size, expected):
13261326
(
13271327
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
13281328
np.array([1, 10]),
1329-
2,
1329+
(2, 2),
13301330
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
13311331
),
13321332
],
@@ -1476,7 +1476,7 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
14761476
(
14771477
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
14781478
np.array([[1], [10]]), # Dim: 2 x 1
1479-
(2, 1),
1479+
(2, 1, 2, 1),
14801480
np.full(
14811481
(2, 1, 2, 1, 4),
14821482
np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4

pymc/tests/test_distributions_random.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,8 @@ def test_issue_3706(self):
11971197

11981198
class TestMvStudentTCov(BaseTestDistributionRandom):
11991199
def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
1200-
chi2_samples = rng.chisquare(nu, size=size)
12011200
mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)
1201+
chi2_samples = rng.chisquare(nu, size=size)
12021202
return (mv_samples / np.sqrt(chi2_samples[:, None] / nu)) + mu
12031203

12041204
pymc_dist = pm.MvStudentT
@@ -1345,7 +1345,7 @@ def check_random_draws(self):
13451345
draws = pm.DirichletMultinomial.dist(
13461346
n=np.array([5, 100]),
13471347
a=np.array([[0.001, 0.001, 0.001, 1000], [1000, 1000, 0.001, 0.001]]),
1348-
size=(2, 3),
1348+
size=(2, 3, 2),
13491349
rng=default_rng,
13501350
).eval()
13511351
assert np.all(draws.sum(-1) == np.array([5, 100]))
@@ -1361,7 +1361,7 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom):
13611361
"n": np.array([23, 29]),
13621362
"a": np.array([[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]),
13631363
}
1364-
sizes_to_check = [None, 1, (4,), (3, 4)]
1364+
sizes_to_check = [None, (1, 2), (4, 2), (3, 4, 2)]
13651365
sizes_expected = [(2, 4), (1, 2, 4), (4, 2, 4), (3, 4, 2, 4)]
13661366
checks_to_run = ["check_rv_size"]
13671367

pymc/tests/test_shape_handling.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
shapes_broadcasting,
3232
to_tuple,
3333
)
34-
from pymc.exceptions import ShapeWarning
3534

3635
test_shapes = [
3736
(tuple(), (1,), (4,), (5, 4)),
@@ -323,7 +322,7 @@ def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
323322
assert "ddata" in pmodel.dim_lengths
324323

325324
# Size does not include support dims, so this test must use a dist with support dims.
326-
kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4))
325+
kwargs = dict(name="y", size=(2, 3), mu=at.ones((3, 4)), cov=at.eye(4))
327326
if with_dims_ellipsis:
328327
y = pm.MvNormal(**kwargs, dims=("dsize", ...))
329328
assert pmodel.RV_dims["y"] == ("dsize", None, None)
@@ -434,17 +433,11 @@ def test_mvnormal_shape_size_difference(self):
434433
assert rv.ndim == 5
435434
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
436435

437-
with pytest.warns(None):
438-
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
439-
assert tuple(rv.shape.eval()) == (5, 4, 3)
436+
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
437+
assert tuple(rv.shape.eval()) == (5, 4, 3)
440438

441-
# When using `size` the API behaves like Aesara/NumPy
442-
with pytest.warns(
443-
ShapeWarning,
444-
match=r"You may have expected a \(2\+1\)-dimensional RV, but the resulting RV will be 5-dimensional",
445-
):
446-
rv = pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))
447-
assert tuple(rv.shape.eval()) == (5, 4, 5, 4, 3)
439+
rv = pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))
440+
assert tuple(rv.shape.eval()) == (5, 4, 3)
448441

449442
def test_convert_dims(self):
450443
assert convert_dims(dims="town") == ("town",)

0 commit comments

Comments
 (0)