Skip to content

Commit 46cb30a

Browse files
committed
Added dispatch function for representative point
1 parent 389f818 commit 46cb30a

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

pymc3/distributions/continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,9 @@ def dist(cls, *, size=None, **kwargs):
366366
res.tag.test_value = np.full(size, floatX(0.0))
367367
return res
368368

369+
def get_moment(rv, size, *rv_inputs) -> np.ndarray:
370+
return at.zeros(size, dtype=aesara.config.floatX)
371+
369372
def logp(value):
370373
"""
371374
Calculate log-probability of Flat distribution at specified value.
@@ -428,6 +431,9 @@ def dist(cls, *, size=None, **kwargs):
428431
res.tag.test_value = np.full(size, floatX(1.0))
429432
return res
430433

434+
def get_moment(value_var, size, *rv_inputs) -> np.ndarray:
435+
return at.ones(size, dtype=aesara.config.floatX)
436+
431437
def logp(value):
432438
"""
433439
Calculate log-probability of HalfFlat distribution at specified value.

pymc3/distributions/distribution.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
import warnings
2020

2121
from abc import ABCMeta
22+
from functools import singledispatch
2223
from typing import Optional
2324

2425
import aesara
2526
import aesara.tensor as at
2627

2728
from aesara.tensor.random.op import RandomVariable
2829
from aesara.tensor.random.var import RandomStateSharedVariable
30+
from aesara.tensor.var import TensorVariable
2931

3032
from pymc3.aesaraf import change_rv_size
3133
from pymc3.distributions import _logcdf, _logp
@@ -107,6 +109,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
107109
value_var = rvs_to_values.get(var, var)
108110
return class_logcdf(value_var, *dist_params, **kwargs)
109111

112+
class_initval = clsdict.get("get_moment")
113+
if class_initval:
114+
115+
@_get_moment.register(rv_type)
116+
def get_moment(op, rv, size, *rv_inputs):
117+
return class_initval(rv, size, *rv_inputs)
118+
110119
# Register the Aesara `RandomVariable` type as a subclass of this
111120
# `Distribution` type.
112121
new_cls.register(rv_type)
@@ -328,6 +337,24 @@ def dist(
328337
return rv_out
329338

330339

340+
@singledispatch
341+
def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable:
342+
"""Fallback method for creating an initial value for a random variable.
343+
344+
Parameters are the same as for the `.dist()` method.
345+
"""
346+
return None
347+
348+
349+
def get_moment(rv: TensorVariable) -> TensorVariable:
350+
"""Fallback method for creating an initial value for a random variable.
351+
352+
Parameters are the same as for the `.dist()` method.
353+
"""
354+
size = rv.owner.inputs[1]
355+
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])
356+
357+
331358
class NoDistribution(Distribution):
332359
def __init__(
333360
self,

pymc3/tests/test_initvals.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import pymc3 as pm
1818

19+
from pymc3.distributions.distribution import get_moment
20+
1921

2022
def transform_fwd(rv, expected_untransformed):
2123
return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval()
@@ -89,3 +91,15 @@ def test_automatically_assigned_test_values(self):
8991
rv = pm.HalfFlat.dist()
9092
assert hasattr(rv.tag, "test_value")
9193
pass
94+
95+
96+
class TestMoment:
97+
def test_basic(self):
98+
rv = pm.Flat.dist()
99+
assert get_moment(rv).eval() == np.zeros(())
100+
rv = pm.HalfFlat.dist()
101+
assert get_moment(rv).eval() == np.ones(())
102+
rv = pm.Flat.dist(size=(2, 4))
103+
assert np.all(get_moment(rv).eval() == np.zeros((2, 4)))
104+
rv = pm.HalfFlat.dist(size=(2, 4))
105+
assert np.all(get_moment(rv).eval() == np.ones((2, 4)))

0 commit comments

Comments
 (0)