@@ -132,7 +132,11 @@ def inv_gamma_reparametrization(fgraph, node):
132132
133133def gamma_reparametrization_impl (rng , size , shape , scale ):
134134 # We'll implement the Marsaglia-Tsang boosted algorithm to sample
135- # We follow the implementation used in jax
135+ # https://dl.acm.org/doi/epdf/10.1145/358407.358414
136+ # We follow the algorithm from section 3 without squeezing.
137+ # The squeeze algorithm from section 4 is more efficient if we can avoid
138+ # computing all branches of the if-else clauses, which is not possible
139+ # when using pytensor switches.
136140 # For context, shape is equal to alpha in all of the following math.
137141 # Sampling for alpha >= 1 is done with a rejection algorithm that finishes in constant time.
138142 # For alpha < 1, we need to boost the samples using:
@@ -141,9 +145,6 @@ def gamma_reparametrization_impl(rng, size, shape, scale):
141145 # log(Gamma(alpha, 1)) = log(Gamma(alpha + 1, 1)) + log(Uniform(0, 1)) / alpha
142146 # We can note that log(Uniform(0, 1)) = -Exponential(1)
143147 # and we have to guard agains the case where the exponential sample is equal to 0.
144- #
145- # Jax's implementation of the rejection algorithm does not match with what's
146- # written in wikipedia.
147148 assert size == (), (
148149 "Gamma reparametrization requires that you first apply the local_rv_size_lift "
149150 "rewrite in order to have size equal to an empty tuple."
@@ -164,40 +165,21 @@ def rejection_step(output, chosen, rng, c, d, alpha):
164165 ones_like (alpha ),
165166 rng = rng ,
166167 ).owner .outputs
167-
168- def inner_step (x , v , chosen , rng , c ):
169- next_rng , X = NormalRV ()(
170- zeros_like (c ),
171- ones_like (c ),
172- rng = rng ,
173- ).owner .outputs
174- v = 1 + c * X
175- indicators = and_ (pt .ge (v , 0 ), ~ chosen )
176- chosen = switch (indicators , ones_like (chosen , dtype = "bool" ), chosen )
177- return (X , v , chosen , next_rng ), until (pt .all (chosen ))
178-
179- xs , vs , _ , next_rng = scan (
180- inner_step ,
181- outputs_info = [
182- zeros_like (alpha ), # X
183- - ones_like (alpha ), # V
184- zeros_like (alpha , dtype = "bool" ), # Chosen
185- next_rng ,
186- ],
187- non_sequences = [c ],
188- n_steps = 10_000 ,
189- return_updates = False ,
190- )
191-
192- x = xs [- 1 ]
193- v = vs [- 1 ]
168+ next_rng , x = NormalRV ()(
169+ zeros_like (c ),
170+ ones_like (c ),
171+ rng = next_rng ,
172+ ).owner .outputs
173+ V = (1 + c * x ) ** 3
194174
195175 X = x * x
196- V = v ** 3
197176
198177 indicators = and_ (
199178 ~ chosen ,
200- ~ and_ (pt .gt (U , 1 - 0.0331 * X * X ), pt .gt (log (U ), X / 2 + d * (1 - V + log (V )))),
179+ and_ (
180+ V > 0 ,
181+ pt .lt (log (U ), X / 2 + d * (1 - V + log (V ))),
182+ ),
201183 )
202184 chosen = switch (indicators , ones_like (chosen , dtype = "bool" ), chosen )
203185 output = switch (indicators , V , output )
0 commit comments