diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index cbb2371368..d5d5dd39ed 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -336,6 +336,11 @@ def logcdf(value, lower, upper): ), ) + def get_moment(value, size, lower, upper): + lower = at.full(size, lower, dtype=aesara.config.floatX) + upper = at.full(size, upper, dtype=aesara.config.floatX) + return (lower + upper) / 2 + class FlatRV(RandomVariable): name = "flat" @@ -366,7 +371,7 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(0.0)) return res - def get_moment(rv, size, *rv_inputs) -> np.ndarray: + def get_moment(rv, size, *rv_inputs): return at.zeros(size, dtype=aesara.config.floatX) def logp(value): @@ -431,7 +436,7 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(1.0)) return res - def get_moment(value_var, size, *rv_inputs) -> np.ndarray: + def get_moment(value_var, size, *rv_inputs): return at.ones(size, dtype=aesara.config.floatX) def logp(value): @@ -588,6 +593,9 @@ def logcdf(value, mu, sigma): 0 < sigma, ) + def get_moment(value_var, size, mu, sigma): + return at.full(size, mu, dtype=aesara.config.floatX) + class TruncatedNormalRV(RandomVariable): name = "truncated_normal" diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d3269f6d28..2f2714edd6 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -394,6 +394,10 @@ def logcdf(value, p): p <= 1, ) + def get_moment(value, size, p): + p = at.full(size, p) + return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value)) + def _distr_parameters_for_repr(self): return ["p"] diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2ffa092cd6..5f7cd7bc12 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -351,7 +351,9 @@ def dist( @singledispatch def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable: - return None + raise NotImplementedError( + f"Random variable {rv} of type {op} has no get_moment implementation." + ) def get_moment(rv: TensorVariable) -> TensorVariable: diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index 9ebfb98e0d..6b4ef717a4 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aesara.tensor as at import numpy as np import pytest @@ -95,6 +96,11 @@ def test_automatically_assigned_test_values(self): class TestMoment: def test_basic(self): + # Standard distributions + rv = pm.Normal.dist(mu=2.3) + np.testing.assert_allclose(get_moment(rv).eval(), 2.3) + + # Special distributions rv = pm.Flat.dist() assert get_moment(rv).eval() == np.zeros(()) rv = pm.HalfFlat.dist() @@ -103,3 +109,33 @@ def test_basic(self): assert np.all(get_moment(rv).eval() == np.zeros((2, 4))) rv = pm.HalfFlat.dist(size=(2, 4)) assert np.all(get_moment(rv).eval() == np.ones((2, 4))) + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_numeric_moment_shape(self, rv_cls): + rv = rv_cls.dist(shape=(2,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (2,) + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_symbolic_moment_shape(self, rv_cls): + s = at.scalar() + rv = rv_cls.dist(shape=(s,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) + pass + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_moment_from_dims(self, rv_cls): + with pm.Model( + coords={ + "year": [2019, 2020, 2021, 2022], + "city": ["Bonn", "Paris", "Lisbon"], + } + ): + rv = rv_cls("rv", dims=("year", "city")) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (4, 3) + pass