diff --git a/pymc/step_methods/mlda.py b/pymc/step_methods/mlda.py
index d66013189d..eb29244f2d 100644
--- a/pymc/step_methods/mlda.py
+++ b/pymc/step_methods/mlda.py
@@ -66,22 +66,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """
@@ -136,22 +128,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """Skips resetting of tuned sampler parameters
@@ -556,7 +540,7 @@ def __init__(
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
         shared = pm.make_shared_replacements(initial_values, vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
+        self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)

         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -749,7 +733,9 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             accept = np.float64(0.0)
             skipped_logp = True
         else:
-            accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
+            # NB! The order and sign of the first term are swapped compared
+            # to the convention to make sure the proposal is evaluated last.
+            accept = -self.delta_logp(q0.data, q.data) + self.delta_logp_below(q0.data, q.data)
             skipped_logp = False

         # Accept/reject sample - next sample is stored in q_new
@@ -954,19 +940,6 @@ def update(self, x):
         self.t += 1


-def delta_logp_inverse(point, logp, vars, shared):
-    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
-
-    tensor_type = inarray0.type
-    inarray1 = tensor_type("inarray1")
-
-    logp1 = pm.CallableTensor(logp0)(inarray1)
-
-    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
-    f.trust_input = True
-    return f
-
-
 def extract_Q_estimate(trace, levels):
     """
     Returns expectation and standard error of quantity of interest,
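
For reviewers, a minimal sketch of the sign convention the new wrapper and the rewritten `accept` expression rely on. It assumes, as in pymc's existing `delta_logp` helper, that the compiled function satisfies `delta_logp(a, b) == logp(a) - logp(b)`; the plain-Python `logp_fine`/`logp_coarse` functions below are hypothetical stand-ins for the compiled Aesara functions and are not part of this patch.

```python
import numpy as np

# Hypothetical stand-ins for the compiled Aesara log-density functions.
def logp_fine(x):
    return -0.5 * np.sum(x ** 2)           # current-level model

def logp_coarse(x):
    return -0.5 * np.sum((x - 0.1) ** 2)   # below-level (coarser) model

# Assumed convention of the compiled helper: f(a, b) == logp(a) - logp(b).
def delta_logp(a, b):
    return logp_fine(a) - logp_fine(b)

def delta_logp_below(a, b):
    return logp_coarse(a) - logp_coarse(b)

q0 = np.array([0.1, -0.3])  # current point
q = np.array([0.4, 0.2])    # proposal coming from the coarser chain

# The wrapper installed in __init__ when variance reduction is on:
# numerically identical to delta_logp(q, q0), but the proposal now sits
# in the argument slot that, per the NB comment in the patch, is
# evaluated last by the compiled function.
wrapped = lambda q, q0: -delta_logp(q0, q)
assert np.isclose(wrapped(q, q0), delta_logp(q, q0))

# The acceptance quantity from astep is then the usual delayed-acceptance
# log-ratio: [logp_fine(q) - logp_fine(q0)] - [logp_coarse(q) - logp_coarse(q0)].
accept = -delta_logp(q0, q) + delta_logp_below(q0, q)
expected = (logp_fine(q) - logp_fine(q0)) - (logp_coarse(q) - logp_coarse(q0))
assert np.isclose(accept, expected)
```

Because the wrapped and unwrapped forms agree in value, removing `delta_logp_inverse` and reusing the shared `delta_logp` helper changes only the evaluation order, which is what the variance-reduction bookkeeping on `self.model` depends on.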