Skip to content

Commit 6ae7628

Browse files
canyon289ricardoV94
authored andcommitted
Infer steps from shape in GaussianRandomWalk
1 parent c80657e commit 6ae7628

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

pymc/distributions/timeseries.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919

2020
from aesara import scan
21+
from aesara.raise_op import Assert
2122
from aesara.tensor.random.op import RandomVariable
2223
from aesara.tensor.random.utils import normalize_size_param
2324

@@ -167,8 +168,26 @@ def dist(
167168

168169
mu = at.as_tensor_variable(floatX(mu))
169170
sigma = at.as_tensor_variable(floatX(sigma))
171+
172+
# Check if shape contains information about number of steps
173+
steps_from_shape = None
174+
shape = kwargs.get("shape", None)
175+
if shape is not None:
176+
shape = to_tuple(shape)
177+
if shape[-1] is not ...:
178+
steps_from_shape = shape[-1] - 1
179+
170180
if steps is None:
171-
raise ValueError("Must specify steps parameter")
181+
if steps_from_shape is not None:
182+
steps = steps_from_shape
183+
else:
184+
raise ValueError("Must specify steps or shape parameter")
185+
elif steps_from_shape is not None:
186+
# Assert that steps and shape are consistent
187+
steps = Assert(msg="Steps do not match last shape dimension")(
188+
steps, at.eq(steps, steps_from_shape)
189+
)
190+
172191
steps = at.as_tensor_variable(intX(steps))
173192

174193
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma

pymc/tests/test_distributions_timeseries.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ def test_gaussian_random_walk_init_dist_shape(self, init):
8686
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, size=(5,))
8787
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
8888

89-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=1)
89+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=2)
9090
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
9191

92-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 1))
92+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 2))
9393
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
9494

9595
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init=init)
@@ -113,6 +113,21 @@ def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
113113
assert tuple(grw.shape.eval()) == (2, 3, 5)
114114
assert grw.eval().shape == (2, 3, 5)
115115

116+
@pytest.mark.parametrize("shape", ((6,), (3, 6)))
117+
def test_inferred_steps_from_shape(self, shape):
118+
x = GaussianRandomWalk.dist(shape=shape)
119+
steps = x.owner.inputs[-1]
120+
assert steps.eval() == 5
121+
122+
@pytest.mark.parametrize("shape", (None, (5, ...)))
123+
def test_missing_steps(self, shape):
124+
with pytest.raises(ValueError, match="Must specify steps or shape parameter"):
125+
GaussianRandomWalk.dist(shape=shape)
126+
127+
def test_inconsistent_steps_and_shape(self):
128+
with pytest.raises(AssertionError, match="Steps do not match last shape dimension"):
129+
x = GaussianRandomWalk.dist(steps=12, shape=45)
130+
116131
@pytest.mark.parametrize(
117132
"init",
118133
[

0 commit comments

Comments
 (0)