Fixed bug in delta_logp for MLDA that broke AEM and VR #5104

Merged

43 changes: 8 additions & 35 deletions pymc/step_methods/mlda.py
@@ -66,22 +66,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """
@@ -136,22 +128,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """Skips resetting of tuned sampler parameters
@@ -556,7 +540,7 @@ def __init__(
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
         shared = pm.make_shared_replacements(initial_values, vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
+        self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)

         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -749,7 +733,9 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             accept = np.float64(0.0)
             skipped_logp = True
         else:
-            accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
+            # NB! The order and sign of the first term are swapped compared
+            # to the convention to make sure the proposal is evaluated last.
+            accept = -self.delta_logp(q0.data, q.data) + self.delta_logp_below(q0.data, q.data)
             skipped_logp = False

         # Accept/reject sample - next sample is stored in q_new
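
Expanding the two compiled terms makes the sign flip easy to verify: with the convention delta_logp(a, b) == logp(a) - logp(b) on both levels, the new expression is exactly the two-level MLDA log acceptance ratio, in which the below-level (coarse) density enters with the opposite sign to the current-level one. A small numerical sketch, with hypothetical stand-ins for the compiled log-densities:

import numpy as np

logp_fine = lambda x: -0.5 * np.sum(x ** 2)            # current level
logp_coarse = lambda x: -0.5 * np.sum((1.1 * x) ** 2)  # below level

def delta(logp, a, b):
    # Same convention as the compiled functions: logp(a) - logp(b).
    return logp(a) - logp(b)

q0, q = np.array([0.3]), np.array([0.5])

# accept = -delta_logp(q0, q) + delta_logp_below(q0, q)
accept = -delta(logp_fine, q0, q) + delta(logp_coarse, q0, q)

# ... which equals the usual two-level MLDA log acceptance ratio:
# [logp_fine(q) - logp_fine(q0)] - [logp_coarse(q) - logp_coarse(q0)]
expected = (logp_fine(q) - logp_fine(q0)) - (logp_coarse(q) - logp_coarse(q0))
assert np.isclose(accept, expected)
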
@@ -954,19 +940,6 @@ def update(self, x):
         self.t += 1


-def delta_logp_inverse(point, logp, vars, shared):
-    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
-
-    tensor_type = inarray0.type
-    inarray1 = tensor_type("inarray1")
-
-    logp1 = pm.CallableTensor(logp0)(inarray1)
-
-    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
-    f.trust_input = True
-    return f
-
-
 def extract_Q_estimate(trace, levels):
     """
     Returns expectation and standard error of quantity of interest,
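
For reference, the helper the step methods now share is the existing delta_logp factory from pymc/step_methods/metropolis.py, sketched below from memory (the in-tree version is authoritative, and the import path for compile_rv_inplace is an assumption). It is identical to the removed delta_logp_inverse except that the compiled output is written logp1 - logp0 rather than -logp0 + logp1; the two expressions agree mathematically, and the evaluation-order concern that motivated the duplicate (making sure the proposal's log-density is evaluated last, per the NB comment in astep) is now handled explicitly at the call sites.

import pymc as pm
from pymc.aesaraf import compile_rv_inplace  # import path assumed

def delta_logp(point, logp, vars, shared):
    # Join the non-shared inputs into a single raveled vector input.
    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)

    tensor_type = inarray0.type
    inarray1 = tensor_type("inarray1")

    # Re-apply the same graph to a second input vector.
    logp1 = pm.CallableTensor(logp0)(inarray1)

    # f(q, q0) = logp(q) - logp(q0)
    f = compile_rv_inplace([inarray1, inarray0], logp1 - logp0)
    f.trust_input = True
    return f
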