Skip to content

Commit 53298ab

Browse files
committed
Avoid the nested scans in gamma implementation
1 parent 0934e57 commit 53298ab

File tree

1 file changed

+15
-33
lines changed

1 file changed

+15
-33
lines changed

pymc/distributions/rewrites/reparametrization.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ def inv_gamma_reparametrization(fgraph, node):
132132

133133
def gamma_reparametrization_impl(rng, size, shape, scale):
134134
# We'll implement the Marsaglia-Tsang boosted algorithm to sample
135-
# We follow the implementation used in jax
135+
# https://dl.acm.org/doi/epdf/10.1145/358407.358414
136+
# We follow the algorithm from section 3 without squeezing.
137+
# The squeeze algorithm from section 4 is more efficient if we can avoid
138+
# computing all branches of the if-else clauses, which is not possible
139+
# when using pytensor switches.
136140
# For context, shape is equal to alpha in all of the following math.
137141
# Sampling for alpha >= 1 is done with a rejection algorithm that finishes in constant time.
138142
# For alpha < 1, we need to boost the samples using:
@@ -141,9 +145,6 @@ def gamma_reparametrization_impl(rng, size, shape, scale):
141145
# log(Gamma(alpha, 1)) = log(Gamma(alpha + 1, 1)) + log(Uniform(0, 1)) / alpha
142146
# We can note that log(Uniform(0, 1)) = -Exponential(1)
143147
# and we have to guard agains the case where the exponential sample is equal to 0.
144-
#
145-
# Jax's implementation of the rejection algorithm does not match with what's
146-
# written in wikipedia.
147148
assert size == (), (
148149
"Gamma reparametrization requires that you first apply the local_rv_size_lift "
149150
"rewrite in order to have size equal to an empty tuple."
@@ -164,40 +165,21 @@ def rejection_step(output, chosen, rng, c, d, alpha):
164165
ones_like(alpha),
165166
rng=rng,
166167
).owner.outputs
167-
168-
def inner_step(x, v, chosen, rng, c):
169-
next_rng, X = NormalRV()(
170-
zeros_like(c),
171-
ones_like(c),
172-
rng=rng,
173-
).owner.outputs
174-
v = 1 + c * X
175-
indicators = and_(pt.ge(v, 0), ~chosen)
176-
chosen = switch(indicators, ones_like(chosen, dtype="bool"), chosen)
177-
return (X, v, chosen, next_rng), until(pt.all(chosen))
178-
179-
xs, vs, _, next_rng = scan(
180-
inner_step,
181-
outputs_info=[
182-
zeros_like(alpha), # X
183-
-ones_like(alpha), # V
184-
zeros_like(alpha, dtype="bool"), # Chosen
185-
next_rng,
186-
],
187-
non_sequences=[c],
188-
n_steps=10_000,
189-
return_updates=False,
190-
)
191-
192-
x = xs[-1]
193-
v = vs[-1]
168+
next_rng, x = NormalRV()(
169+
zeros_like(c),
170+
ones_like(c),
171+
rng=next_rng,
172+
).owner.outputs
173+
V = (1 + c * x) ** 3
194174

195175
X = x * x
196-
V = v**3
197176

198177
indicators = and_(
199178
~chosen,
200-
~and_(pt.gt(U, 1 - 0.0331 * X * X), pt.gt(log(U), X / 2 + d * (1 - V + log(V)))),
179+
and_(
180+
V > 0,
181+
pt.lt(log(U), X / 2 + d * (1 - V + log(V))),
182+
),
201183
)
202184
chosen = switch(indicators, ones_like(chosen, dtype="bool"), chosen)
203185
output = switch(indicators, V, output)

0 commit comments

Comments
 (0)