Skip to content

Commit bca2b1e

Browse files
Chirag3841ricardoV94
authored andcommitted
Fix failure in change_dist_size when RVs don't take dtype argument
1 parent 0df7824 commit bca2b1e

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

pymc/distributions/shape_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,7 @@ def change_rv_size(op, rv, new_size, expand) -> TensorVariable:
305305
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
306306
new_size = pt.as_tensor(new_size, ndim=1, dtype="int64")
307307

308-
new_rv = rv_node.op(*dist_params, size=new_size, dtype=rv.type.dtype)
309-
308+
new_rv = rv_node.op(*dist_params, size=new_size)
310309
# Replicate "traditional" rng default_update, if that was set for old_rng
311310
default_update = getattr(old_rng, "default_update", None)
312311
if default_update is not None:

tests/sampling/test_forward.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import pymc as pm
3636

3737
from pymc.backends.base import MultiTrace
38+
from pymc.distributions.shape_utils import change_dist_size
3839
from pymc.model.transform.optimization import freeze_dims_and_data
3940
from pymc.pytensorf import compile, rvs_in_graph
4041
from pymc.sampling.forward import (
@@ -1979,3 +1980,12 @@ def test_vectorize_over_posterior_with_intermediate_rvs():
19791980
assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data)
19801981
assert isinstance(a_ancestor2, TensorConstant)
19811982
assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)
1983+
1984+
1985+
def test_change_dist_size_zero_sum_normal():
1986+
with pm.Model():
1987+
intercept = pm.ZeroSumNormal("intercept", sigma=1.0, shape=2)
1988+
1989+
resized = change_dist_size(intercept, new_size=(10,), expand=True)
1990+
1991+
assert resized.type.shape == (10, 2)

0 commit comments

Comments
 (0)