Skip to content

Commit d02707e

Browse files
committed
Move GaussianRandomWalk init resizing to Op.make_node
1 parent 63aa44c commit d02707e

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

pymc/distributions/timeseries.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919

2020
from aesara import scan
2121
from aesara.tensor.random.op import RandomVariable
22+
from aesara.tensor.random.utils import normalize_size_param
2223

2324
from pymc.aesaraf import change_rv_size, floatX, intX
2425
from pymc.distributions import distribution, logprob, multivariate
2526
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
2627
from pymc.distributions.dist_math import check_parameters
27-
from pymc.distributions.shape_utils import to_tuple
28+
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
2829
from pymc.util import check_dist_not_registered
2930

3031
__all__ = [
@@ -54,6 +55,16 @@ def make_node(self, rng, size, dtype, mu, sigma, init, steps):
5455
if not steps.ndim == 0 or not steps.dtype.startswith("int"):
5556
raise ValueError("steps must be an integer scalar (ndim=0).")
5657

58+
mu = at.as_tensor_variable(mu)
59+
sigma = at.as_tensor_variable(sigma)
60+
init = at.as_tensor_variable(init)
61+
62+
# Resize init distribution
63+
size = normalize_size_param(size)
64+
# If not explicit, size is determined by the shapes of mu, sigma, and init
65+
init_size = size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init)
66+
init = change_rv_size(init, init_size)
67+
5768
return super().make_node(rng, size, dtype, mu, sigma, init, steps)
5869

5970
def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
@@ -160,15 +171,9 @@ def dist(
160171
raise ValueError("Must specify steps parameter")
161172
steps = at.as_tensor_variable(intX(steps))
162173

163-
shape = kwargs.get("shape", None)
164-
if size is None and shape is None:
165-
init_size = None
166-
else:
167-
init_size = to_tuple(size) if size is not None else to_tuple(shape)[:-1]
168-
169174
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
170175
if init is None:
171-
init = Normal.dist(mu, sigma, size=init_size)
176+
init = Normal.dist(mu, sigma)
172177
else:
173178
if not (
174179
isinstance(init, at.TensorVariable)
@@ -178,13 +183,6 @@ def dist(
178183
):
179184
raise TypeError("init must be a univariate distribution variable")
180185

181-
if init_size is not None:
182-
init = change_rv_size(init, init_size)
183-
else:
184-
# If not explicit, size is determined by the shapes of mu, sigma, and init
185-
bcast_shape = at.broadcast_arrays(mu, sigma, init)[0].shape
186-
init = change_rv_size(init, bcast_shape)
187-
188186
# Ignores logprob of init var because that's accounted for in the logp method
189187
init.tag.ignore_logprob = True
190188

pymc/tests/test_distributions_timeseries.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ def test_gaussian_random_walk_init_dist_shape(init):
103103
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
104104

105105

106+
def test_shape_ellipsis():
107+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...))
108+
assert tuple(grw.shape.eval()) == (3, 6)
109+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)
110+
111+
106112
def test_gaussianrandomwalk_broadcasted_by_init_dist():
107113
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
108114
assert tuple(grw.shape.eval()) == (2, 3, 5)

0 commit comments

Comments
 (0)