Skip to content

Commit 255c8e9

Browse files
ricardoV94 authored and fonnesbeck committed
Make AR steps extend shape beyond initial_dist
This is consistent with the meaning of steps in the GaussianRandomWalk and translates directly to the number of scan steps taken
1 parent be6fa5c commit 255c8e9

File tree

2 files changed

+48
-30
lines changed

2 files changed

+48
-30
lines changed

pymc/distributions/timeseries.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -393,14 +393,19 @@ class AR(SymbolicDistribution):
393393
394394
"""
395395

396-
def __new__(cls, *args, steps=None, **kwargs):
396+
def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs):
397+
rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho)))
398+
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
397399
steps = get_steps(
398400
steps=steps,
399401
shape=None, # Shape will be checked in `cls.dist`
400402
dims=kwargs.get("dims", None),
401403
observed=kwargs.get("observed", None),
404+
step_shape_offset=ar_order,
405+
)
406+
return super().__new__(
407+
cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs
402408
)
403-
return super().__new__(cls, *args, steps=steps, **kwargs)
404409

405410
@classmethod
406411
def dist(
@@ -426,34 +431,12 @@ def dist(
426431
)
427432
init_dist = kwargs["init"]
428433

429-
steps = get_steps(steps=steps, shape=kwargs.get("shape", None))
434+
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
435+
steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order)
430436
if steps is None:
431437
raise ValueError("Must specify steps or shape parameter")
432438
steps = at.as_tensor_variable(intX(steps), ndim=0)
433439

434-
if ar_order is None:
435-
# If ar_order is not specified we do constant folding on the shape of rhos
436-
# to retrieve it. For example, this will detect that
437-
# Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
438-
shape_fg = FunctionGraph(
439-
outputs=[rhos.shape[-1]],
440-
features=[ShapeFeature()],
441-
clone=True,
442-
)
443-
(folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
444-
folded_shape = getattr(folded_shape, "data", None)
445-
if folded_shape is None:
446-
raise ValueError(
447-
"Could not infer ar_order from last dimension of rho. Pass it "
448-
"explictily or make sure rho have a static shape"
449-
)
450-
ar_order = int(folded_shape) - int(constant)
451-
if ar_order < 1:
452-
raise ValueError(
453-
"Inferred ar_order is smaller than 1. Increase the last dimension "
454-
"of rho or remove constant_term"
455-
)
456-
457440
if init_dist is not None:
458441
if not isinstance(init_dist, TensorVariable) or not isinstance(
459442
init_dist.owner.op, RandomVariable
@@ -477,6 +460,41 @@ def dist(
477460

478461
return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)
479462

463+
@classmethod
464+
def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
465+
"""Compute ar_order given inputs
466+
467+
If ar_order is not specified we do constant folding on the shape of rhos
468+
to retrieve it. For example, this will detect that
469+
Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
470+
471+
Raises
472+
------
473+
ValueError
474+
If inferred ar_order cannot be inferred from rhos or if it is less than 1
475+
"""
476+
if ar_order is None:
477+
shape_fg = FunctionGraph(
478+
outputs=[rhos.shape[-1]],
479+
features=[ShapeFeature()],
480+
clone=True,
481+
)
482+
(folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
483+
folded_shape = getattr(folded_shape, "data", None)
484+
if folded_shape is None:
485+
raise ValueError(
486+
"Could not infer ar_order from last dimension of rho. Pass it "
487+
"explictily or make sure rho have a static shape"
488+
)
489+
ar_order = int(folded_shape) - int(constant)
490+
if ar_order < 1:
491+
raise ValueError(
492+
"Inferred ar_order is smaller than 1. Increase the last dimension "
493+
"of rho or remove constant_term"
494+
)
495+
496+
return ar_order
497+
480498
@classmethod
481499
def num_rngs(cls, *args, **kwargs):
482500
return 2
@@ -540,7 +558,7 @@ def step(*args):
540558
fn=step,
541559
outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}],
542560
non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng],
543-
n_steps=at.max((0, steps_ - ar_order)),
561+
n_steps=steps_,
544562
strict=True,
545563
)
546564
(noise_next_rng,) = tuple(innov_updates_.values())

pymc/tests/test_distributions_timeseries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def test_batched_sigma(self):
363363
beta_tp.set_value(np.zeros((ar_order,))) # Should always be close to zero
364364
sigma_tp = np.full(batch_size, [0.01, 0.1, 1, 10, 100])
365365
y_eval = t0["y"].eval({t0["sigma"]: sigma_tp})
366-
assert y_eval.shape == (*batch_size, steps)
366+
assert y_eval.shape == (*batch_size, steps + ar_order)
367367
assert np.allclose(y_eval.std(axis=(0, 2)), [0.01, 0.1, 1, 10, 100], rtol=0.1)
368368

369369
def test_batched_init_dist(self):
@@ -389,7 +389,7 @@ def test_batched_init_dist(self):
389389
init_dist = t0["y"].owner.inputs[2]
390390
init_dist_tp = np.full((batch_size, ar_order), (np.arange(batch_size) * 100)[:, None])
391391
y_eval = t0["y"].eval({init_dist: init_dist_tp})
392-
assert y_eval.shape == (batch_size, steps)
392+
assert y_eval.shape == (batch_size, steps + ar_order)
393393
assert np.allclose(
394394
y_eval[:, -10:].mean(-1), np.arange(batch_size) * 100, rtol=0.1, atol=0.5
395395
)
@@ -429,7 +429,7 @@ def test_multivariate_init_dist(self):
429429
def test_moment(self, size, expected):
430430
with Model() as model:
431431
init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]])
432-
AR("x", rho=[0, 0], init_dist=init_dist, steps=7, size=size)
432+
AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size)
433433
assert_moment_is_expected(model, expected, check_finite_logp=False)
434434

435435

0 commit comments

Comments (0)