Skip to content

Commit 6b486b9

Browse files
committed
Fix truncated rejection sampling for scalar RVs
1 parent efa0d34 commit 6b486b9

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

pymc/distributions/truncated.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,14 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
211211
# Fallback to rejection sampling
212212
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
213213
next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
214-
truncated_rv = pt.set_subtensor(
215-
truncated_rv[reject_draws],
216-
new_truncated_rv[reject_draws],
217-
)
214+
# Avoid scalar boolean indexing
215+
if truncated_rv.type.ndim == 0:
216+
truncated_rv = new_truncated_rv
217+
else:
218+
truncated_rv = pt.set_subtensor(
219+
truncated_rv[reject_draws],
220+
new_truncated_rv[reject_draws],
221+
)
218222
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))
219223

220224
return (

tests/distributions/test_truncated.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,12 @@ def test_truncation_specialized_op(shape_info):
101101

102102
@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
103103
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
104-
def test_truncation_continuous_random(op_type, lower, upper):
104+
@pytest.mark.parametrize("scalar", [True, False])
105+
def test_truncation_continuous_random(op_type, lower, upper, scalar):
105106
loc = 0.15
106107
scale = 10
107108
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
108-
x = normal_op(loc, scale, name="x", size=100)
109+
x = normal_op(loc, scale, name="x", size=() if scalar else (100,))
109110

110111
xt = Truncated.dist(x, lower=lower, upper=upper)
111112
assert isinstance(xt.owner.op, TruncatedRV)
@@ -134,7 +135,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
134135
assert np.unique(xt_draws).size == xt_draws.size
135136
else:
136137
with pytest.raises(TruncationError, match="^Truncation did not converge"):
137-
draw(xt)
138+
draw(xt, draws=100 if scalar else 1)
138139

139140

140141
@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])

0 commit comments

Comments
 (0)