|
52 | 52 | as_index_constant,
|
53 | 53 | )
|
54 | 54 |
|
55 |
| -from pymc.logprob.basic import factorized_joint_logprob |
| 55 | +from pymc.logprob.basic import factorized_joint_logprob, logp |
56 | 56 | from pymc.logprob.mixture import MixtureRV, expand_indices
|
57 | 57 | from pymc.logprob.rewriting import construct_ir_fgraph
|
58 | 58 | from pymc.logprob.utils import dirac_delta
|
@@ -1112,3 +1112,23 @@ def test_joint_logprob_subtensor():
|
1112 | 1112 | logp_vals = logp_vals_fn(A_idx_value, I_value)
|
1113 | 1113 |
|
1114 | 1114 | np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
|
| 1115 | + |
| 1116 | + |
| 1117 | +def test_nested_ifelse(): |
| 1118 | + idx = pt.scalar("idx", dtype=int) |
| 1119 | + |
| 1120 | + dist0 = pt.random.normal(-5, 1) |
| 1121 | + dist1 = pt.random.normal(0, 1) |
| 1122 | + dist2 = pt.random.normal(5, 1) |
| 1123 | + mix = ifelse(pt.eq(idx, 0), dist0, ifelse(pt.eq(idx, 1), dist1, dist2)) |
| 1124 | + mix.name = "mix" |
| 1125 | + |
| 1126 | + value = mix.clone() |
| 1127 | + mix_logp = logp(mix, value) |
| 1128 | + assert mix_logp.name == "mix_logprob" |
| 1129 | + mix_logp_fn = pytensor.function([idx, value], mix_logp) |
| 1130 | + |
| 1131 | + test_value = 0.25 |
| 1132 | + np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1)) |
| 1133 | + np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1)) |
| 1134 | + np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1)) |
0 commit comments