Add support for symbolic initval using a singledispatch approach #4912

Merged · 1 commit · Sep 2, 2021
6 changes: 6 additions & 0 deletions pymc3/distributions/continuous.py
@@ -366,6 +366,9 @@ def dist(cls, *, size=None, **kwargs):
res.tag.test_value = np.full(size, floatX(0.0))
return res

def get_moment(rv, size, *rv_inputs) -> np.ndarray:
return at.zeros(size, dtype=aesara.config.floatX)

def logp(value):
"""
Calculate log-probability of Flat distribution at specified value.
@@ -428,6 +431,9 @@ def dist(cls, *, size=None, **kwargs):
res.tag.test_value = np.full(size, floatX(1.0))
return res

def get_moment(value_var, size, *rv_inputs) -> np.ndarray:
return at.ones(size, dtype=aesara.config.floatX)

def logp(value):
"""
Calculate log-probability of HalfFlat distribution at specified value.
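For orientation, here is a minimal sketch (not part of the PR, assumes only aesara and numpy) of what these two `get_moment` methods build: symbolic zero/one tensors of the requested size, which only become concrete arrays when the graph is evaluated. The `size` below is just an illustrative shape.

```python
# Minimal sketch: the symbolic tensors that Flat.get_moment and
# HalfFlat.get_moment construct for a given size.
import aesara
import aesara.tensor as at
import numpy as np

size = (2, 3)  # illustrative shape

flat_moment = at.zeros(size, dtype=aesara.config.floatX)     # as in Flat.get_moment
halfflat_moment = at.ones(size, dtype=aesara.config.floatX)  # as in HalfFlat.get_moment

# Nothing is computed until the symbolic graphs are evaluated:
assert np.array_equal(flat_moment.eval(), np.zeros(size))
assert np.array_equal(halfflat_moment.eval(), np.ones(size))
```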
27 changes: 27 additions & 0 deletions pymc3/distributions/distribution.py
@@ -19,13 +19,15 @@
import warnings

from abc import ABCMeta
from functools import singledispatch
from typing import Optional

import aesara
import aesara.tensor as at

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.var import TensorVariable

from pymc3.aesaraf import change_rv_size
from pymc3.distributions import _logcdf, _logp
@@ -107,6 +109,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
value_var = rvs_to_values.get(var, var)
return class_logcdf(value_var, *dist_params, **kwargs)

class_initval = clsdict.get("get_moment")
if class_initval:

@_get_moment.register(rv_type)
def get_moment(op, rv, size, *rv_inputs):
return class_initval(rv, size, *rv_inputs)

# Register the Aesara `RandomVariable` type as a subclass of this
# `Distribution` type.
new_cls.register(rv_type)
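As an aside, the dispatch mechanism wired up in the hunk above is plain `functools.singledispatch`: one fallback implementation plus per-type registrations, selected by the type of the first argument (here the RV's Op). Below is a self-contained toy sketch of that pattern, with made-up class and function names; the distribution.py diff continues after it.

```python
# Toy illustration of the singledispatch pattern (hypothetical names, not PyMC3 code).
from functools import singledispatch


class FlatRV:      # stands in for an Aesara RandomVariable op type
    pass


class HalfFlatRV:  # ditto
    pass


@singledispatch
def toy_get_moment(op, size):
    """Fallback: no moment known for this op type."""
    return None


@toy_get_moment.register(FlatRV)
def _(op, size):
    return [0.0] * size


@toy_get_moment.register(HalfFlatRV)
def _(op, size):
    return [1.0] * size


print(toy_get_moment(FlatRV(), 3))      # [0.0, 0.0, 0.0]
print(toy_get_moment(HalfFlatRV(), 2))  # [1.0, 1.0]
print(toy_get_moment(object(), 2))      # None (falls back)
```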
@@ -328,6 +337,24 @@ def dist(
return rv_out


@singledispatch
def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable:
"""Fallback method for creating an initial value for a random variable.

Parameters are the same as for the `.dist()` method.
"""
return None


def get_moment(rv: TensorVariable) -> TensorVariable:
"""Fallback method for creating an initial value for a random variable.

Parameters are the same as for the `.dist()` method.
"""
size = rv.owner.inputs[1]
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])

Review thread attached to the docstring above:

Member: I know this was merged already, but this part of the docstring is wrong.

Member: OK, we should fix that then. CC @kc611

kc611 (Contributor, Author), Sep 2, 2021: Ah yes, I missed that; those docstrings were supposed to be removed. I'm not sure what docstring will go in its place, though. Maybe I should just remove them for now? We can add a proper explanation when we give get_moment a proper entry point in the initval framework (if that's being planned).

Member: Yes, then just remove them for now.

kc611 (Contributor, Author): Did it in #4979.


class NoDistribution(Distribution):
def __init__(
self,
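Putting the pieces together, here is a minimal usage sketch (assuming pymc3 at the state of this PR): `get_moment` unpacks `rv.owner` to recover the op, the symbolic size, and the distribution parameters, then dispatches on the op's type; RVs without a registered moment hit the singledispatch fallback and return `None`.

```python
# Minimal usage sketch (assumes pymc3 at this PR's state).
import pymc3 as pm

from pymc3.distributions.distribution import get_moment

rv = pm.HalfFlat.dist(size=(3,))
moment = get_moment(rv)       # symbolic ones of shape (3,)
print(moment.eval())          # [1. 1. 1.]

rv = pm.Flat.dist()
print(get_moment(rv).eval())  # 0.0 (scalar)

# A distribution without a registered get_moment hits the fallback:
rv = pm.Normal.dist()
print(get_moment(rv))         # None
```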
14 changes: 14 additions & 0 deletions pymc3/tests/test_initvals.py
@@ -16,6 +16,8 @@

import pymc3 as pm

from pymc3.distributions.distribution import get_moment


def transform_fwd(rv, expected_untransformed):
return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval()
@@ -89,3 +91,15 @@ def test_automatically_assigned_test_values(self):
rv = pm.HalfFlat.dist()
assert hasattr(rv.tag, "test_value")
pass


class TestMoment:
def test_basic(self):
rv = pm.Flat.dist()
assert get_moment(rv).eval() == np.zeros(())
rv = pm.HalfFlat.dist()
assert get_moment(rv).eval() == np.ones(())
rv = pm.Flat.dist(size=(2, 4))
assert np.all(get_moment(rv).eval() == np.zeros((2, 4)))
rv = pm.HalfFlat.dist(size=(2, 4))
assert np.all(get_moment(rv).eval() == np.ones((2, 4)))
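
Beyond defining `get_moment` on a Distribution class (which the metaclass picks up automatically via `clsdict.get("get_moment")`), a moment could also be registered directly on an RV op type through the singledispatch fallback. The following is a hypothetical sketch, not something this PR does; the `NormalRV` import path and the choice of the mean as the moment are assumptions for illustration.

```python
# Hypothetical sketch: registering a moment for an existing RV op type directly
# on the singledispatch fallback (assumed import path; not part of this PR).
import aesara
import aesara.tensor as at
import pymc3 as pm

from aesara.tensor.random.basic import NormalRV  # assumed location of the Normal op
from pymc3.distributions.distribution import _get_moment, get_moment


@_get_moment.register(NormalRV)
def normal_moment(op, rv, size, mu, sigma):
    # Broadcast the mean parameter to the requested size.
    return at.zeros(size, dtype=aesara.config.floatX) + mu


rv = pm.Normal.dist(mu=2.0, sigma=1.0, size=(3,))
print(get_moment(rv).eval())  # [2. 2. 2.]
```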