Skip to content

Commit c868a84

Browse files
authored
Infer logprob of sum of Normals (#8067)
1 parent cce84d5 commit c868a84

File tree

5 files changed

+168
-0
lines changed

5 files changed

+168
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ jobs:
119119
tests/backends/test_zarr.py
120120
tests/distributions/test_truncated.py
121121
tests/logprob/test_abstract.py
122+
tests/logprob/test_arithmetic.py
122123
tests/logprob/test_basic.py
123124
tests/logprob/test_binary.py
124125
tests/logprob/test_checks.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
# Add rewrites to the DBs
4949
import pymc.logprob.binary
5050
import pymc.logprob.censoring
51+
import pymc.logprob.arithmetic
5152
import pymc.logprob.cumsum
5253
import pymc.logprob.checks
5354
import pymc.logprob.linalg

pymc/logprob/arithmetic.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
"""Measurable rewrites for arithmetic operations."""
37+
38+
from pytensor import tensor as pt
39+
from pytensor.graph.basic import Apply
40+
from pytensor.graph.fg import FunctionGraph
41+
from pytensor.graph.rewriting.basic import node_rewriter
42+
from pytensor.tensor.math import Sum
43+
from pytensor.tensor.random.basic import NormalRV
44+
from pytensor.tensor.type_other import NoneTypeT
45+
from pytensor.tensor.variable import TensorVariable
46+
47+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
48+
49+
50+
@node_rewriter([Sum])
51+
def sum_of_normals(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
52+
[base_var] = node.inputs
53+
if base_var.owner is None:
54+
return None
55+
56+
latent_op = base_var.owner.op
57+
if not isinstance(latent_op, NormalRV):
58+
return None
59+
60+
rng, size, mu, sigma = base_var.owner.inputs
61+
62+
if isinstance(size.type, NoneTypeT):
63+
mu_b, sigma_b = pt.broadcast_arrays(mu, sigma)
64+
else:
65+
mu_b = pt.broadcast_to(mu, size) # type: ignore[arg-type]
66+
sigma_b = pt.broadcast_to(sigma, size) # type: ignore[arg-type]
67+
68+
axis = node.op.axis
69+
mu_sum = pt.sum(mu_b, axis=axis)
70+
sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b), axis=axis))
71+
72+
sum_rv = latent_op(mu_sum, sigma_sum, rng=rng, size=None)
73+
return [sum_rv]
74+
75+
76+
measurable_ir_rewrites_db.register(
77+
"sum_of_normals",
78+
sum_of_normals,
79+
"basic",
80+
"arithmetic",
81+
)

tests/logprob/test_arithmetic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
import numpy as np
38+
import pytest
39+
40+
from pytensor import tensor as pt
41+
42+
from pymc.logprob.basic import logp
43+
44+
45+
@pytest.mark.parametrize("axis", (None, 0))
46+
def test_sum_of_normals_logprob(axis):
47+
mu = pt.constant([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]])
48+
sigma = pt.constant([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]])
49+
50+
x_rv = pt.random.normal(mu, sigma, name="x")
51+
x_sum = pt.sum(x_rv, axis=axis)
52+
x_sum_vv = pt.scalar("x_sum")
53+
54+
sum_logp = logp(x_sum, x_sum_vv)
55+
56+
ref_mu = pt.sum(mu, axis=axis)
57+
ref_sigma = pt.sqrt(pt.sum(pt.square(sigma), axis=axis))
58+
ref_rv = pt.random.normal(ref_mu, ref_sigma, name="ref")
59+
ref_vv = pt.scalar("ref_vv")
60+
ref_logp = logp(ref_rv, ref_vv)
61+
62+
test_val = 0.5
63+
np.testing.assert_allclose(
64+
sum_logp.eval({x_sum_vv: test_val}),
65+
ref_logp.eval({ref_vv: test_val}),
66+
)

tests/model/transform/test_conditioning.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,25 @@ def test_observe_deterministic():
108108
pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)
109109

110110

111+
def test_observe_sum_normal():
112+
with pm.Model() as m_old:
113+
x = pm.Normal("x")
114+
y = pm.Normal.dist(mu=x, sigma=1.0, shape=(5,))
115+
y_sum = pm.Deterministic("y_sum", pm.math.sum(y))
116+
117+
m_new = observe(m_old, {y_sum: 2.0})
118+
119+
with pm.Model() as m_ref:
120+
x = pm.Normal("x")
121+
pm.Normal("y_sum", mu=5.0 * x, sigma=np.sqrt(5.0), observed=2.0)
122+
123+
test_point = {"x": 0.3}
124+
np.testing.assert_allclose(
125+
m_new.compile_logp()(test_point),
126+
m_ref.compile_logp()(test_point),
127+
)
128+
129+
111130
def test_observe_dims():
112131
with pm.Model(coords={"test_dim": range(5)}) as m_old:
113132
x = pm.Normal("x", dims="test_dim")

0 commit comments

Comments
 (0)