Skip to content

Commit 5b68edc

Browse files
ricardoV94twiecki
authored andcommitted
Fix nested and single output IfElse logp
1 parent a30e0d4 commit 5b68edc

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

pymc/logprob/mixture.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -547,4 +547,7 @@ def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
547547
logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
548548
logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)
549549

550-
return ifelse(if_var, logps_then, logps_else)
550+
logps = ifelse(if_var, logps_then, logps_else)
551+
if len(logps) == 1:
552+
return logps[0]
553+
return logps

tests/logprob/test_mixture.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
as_index_constant,
5353
)
5454

55-
from pymc.logprob.basic import factorized_joint_logprob
55+
from pymc.logprob.basic import factorized_joint_logprob, logp
5656
from pymc.logprob.mixture import MixtureRV, expand_indices
5757
from pymc.logprob.rewriting import construct_ir_fgraph
5858
from pymc.logprob.utils import dirac_delta
@@ -1112,3 +1112,23 @@ def test_joint_logprob_subtensor():
11121112
logp_vals = logp_vals_fn(A_idx_value, I_value)
11131113

11141114
np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
1115+
1116+
1117+
def test_nested_ifelse():
1118+
idx = pt.scalar("idx", dtype=int)
1119+
1120+
dist0 = pt.random.normal(-5, 1)
1121+
dist1 = pt.random.normal(0, 1)
1122+
dist2 = pt.random.normal(5, 1)
1123+
mix = ifelse(pt.eq(idx, 0), dist0, ifelse(pt.eq(idx, 1), dist1, dist2))
1124+
mix.name = "mix"
1125+
1126+
value = mix.clone()
1127+
mix_logp = logp(mix, value)
1128+
assert mix_logp.name == "mix_logprob"
1129+
mix_logp_fn = pytensor.function([idx, value], mix_logp)
1130+
1131+
test_value = 0.25
1132+
np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1))
1133+
np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1))
1134+
np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1))

0 commit comments

Comments
 (0)