From 46cb30a4653b6c07027c57c6f57d83f204d62e69 Mon Sep 17 00:00:00 2001 From: kc611 Date: Thu, 2 Sep 2021 14:52:41 +0530 Subject: [PATCH] Added dispatch function for representative point --- pymc3/distributions/continuous.py | 6 ++++++ pymc3/distributions/distribution.py | 27 +++++++++++++++++++++++++++ pymc3/tests/test_initvals.py | 14 ++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index b9c5200258..2e750784ee 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -366,6 +366,9 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(0.0)) return res + def get_moment(rv, size, *rv_inputs) -> np.ndarray: + return at.zeros(size, dtype=aesara.config.floatX) + def logp(value): """ Calculate log-probability of Flat distribution at specified value. @@ -428,6 +431,9 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(1.0)) return res + def get_moment(value_var, size, *rv_inputs) -> np.ndarray: + return at.ones(size, dtype=aesara.config.floatX) + def logp(value): """ Calculate log-probability of HalfFlat distribution at specified value. diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 1d15c46deb..582fa6b2f8 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -19,6 +19,7 @@ import warnings from abc import ABCMeta +from functools import singledispatch from typing import Optional import aesara @@ -26,6 +27,7 @@ from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.var import RandomStateSharedVariable +from aesara.tensor.var import TensorVariable from pymc3.aesaraf import change_rv_size from pymc3.distributions import _logcdf, _logp @@ -107,6 +109,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs): value_var = rvs_to_values.get(var, var) return class_logcdf(value_var, *dist_params, **kwargs) + class_initval = clsdict.get("get_moment") + if class_initval: + + @_get_moment.register(rv_type) + def get_moment(op, rv, size, *rv_inputs): + return class_initval(rv, size, *rv_inputs) + # Register the Aesara `RandomVariable` type as a subclass of this # `Distribution` type. new_cls.register(rv_type) @@ -328,6 +337,24 @@ def dist( return rv_out +@singledispatch +def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable: + """Fallback method for creating an initial value for a random variable. + + Parameters are the same as for the `.dist()` method. + """ + return None + + +def get_moment(rv: TensorVariable) -> TensorVariable: + """Fallback method for creating an initial value for a random variable. + + Parameters are the same as for the `.dist()` method. + """ + size = rv.owner.inputs[1] + return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]) + + class NoDistribution(Distribution): def __init__( self, diff --git a/pymc3/tests/test_initvals.py b/pymc3/tests/test_initvals.py index 51bd52028e..0e9232681d 100644 --- a/pymc3/tests/test_initvals.py +++ b/pymc3/tests/test_initvals.py @@ -16,6 +16,8 @@ import pymc3 as pm +from pymc3.distributions.distribution import get_moment + def transform_fwd(rv, expected_untransformed): return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() @@ -89,3 +91,15 @@ def test_automatically_assigned_test_values(self): rv = pm.HalfFlat.dist() assert hasattr(rv.tag, "test_value") pass + + +class TestMoment: + def test_basic(self): + rv = pm.Flat.dist() + assert get_moment(rv).eval() == np.zeros(()) + rv = pm.HalfFlat.dist() + assert get_moment(rv).eval() == np.ones(()) + rv = pm.Flat.dist(size=(2, 4)) + assert np.all(get_moment(rv).eval() == np.zeros((2, 4))) + rv = pm.HalfFlat.dist(size=(2, 4)) + assert np.all(get_moment(rv).eval() == np.ones((2, 4)))