Fix make_initial_point_expression recursion bug #5170

Merged (1 commit) on Nov 12, 2021
30 changes: 15 additions & 15 deletions pymc/initial_point.py
@@ -291,6 +291,7 @@ def make_initial_point_expression(
             jitter.name = f"{variable.name}_jitter"
             value = value + jitter
 
+        value = value.astype(variable.dtype)
         initial_values_transformed.append(value)
 
         if transform is not None:
@@ -310,18 +311,17 @@
     initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
     initial_values_transformed_clone = copy_graph.outputs[-n_variables:]
 
-    # In the order the variables were created, replace each previous variable
-    # with the init_point for that variable.
-    initial_values = []
-    initial_values_transformed = []
-
-    for i in range(n_variables):
-        outputs = [initial_values_clone[i], initial_values_transformed_clone[i]]
-        graph = FunctionGraph(outputs=outputs, clone=False)
-        graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True)
-        initial_values.append(graph.outputs[0])
-        initial_values_transformed.append(graph.outputs[1])
-
-    if return_transformed:
-        return initial_values_transformed
-    return initial_values
+    # We now replace all rvs by the respective initial_point expressions
+    # in the constrained (untransformed) space. We do this in reverse topological
+    # order, so that later nodes do not reintroduce expressions with earlier
+    # rvs that would need to once again be replaced by their initial_points
+    graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
+    replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
+    graph.replace_all(replacements, import_missing=True)
+
+    if not return_transformed:
+        return graph.outputs
+    # Because the unconstrained (transformed) expressions are a subgraph of the
+    # constrained initial point they were also automatically updated inplace
+    # when calling graph.replace_all above, so we don't need to do anything else
+    return initial_values_transformed_clone
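
For illustration, the behaviour restored by this change can be exercised with a small chain of dependent `initval="prior"` variables. The snippet below is a minimal sketch reusing the setup and expected values of the regression test further down (see issue #5168); it is not part of the PR itself. With the reverse-topological replacement, each variable's initial point is resolved from its parents' initial points rather than by repeatedly re-expanding earlier variables.

```python
import numpy as np
import pymc as pm
from pymc.initial_point import make_initial_point_fn

# Sketch based on the new regression test: a chain of rvs whose "prior"
# initvals depend on the previous rv's value.
with pm.Model() as pmodel:
    one = pm.LogNormal("one", mu=np.log(1), sd=1e-5, initval="prior")
    two = pm.LogNormal("two", mu=np.log(one * 2), sd=1e-5, initval="prior")
    three = pm.LogNormal("three", mu=np.log(two * 2), sd=1e-5, initval="prior")

# make_initial_point_fn returns a function of the seed; with the fix the
# resulting initial values should be close to 1, 2 and 4 respectively.
fn = make_initial_point_fn(model=pmodel, return_transformed=False)
print(fn(0))
```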
37 changes: 33 additions & 4 deletions pymc/tests/test_initial_point.py
@@ -74,17 +74,46 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
     def test_dependent_initvals(self):
         with pm.Model() as pmodel:
             L = pm.Uniform("L", 0, 1, initval=0.5)
-            B = pm.Uniform("B", lower=L, upper=2, initval=1.25)
+            U = pm.Uniform("U", lower=9, upper=10, initval=9.5)
+            B1 = pm.Uniform("B1", lower=L, upper=U, initval=5)
+            B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2)
             ip = pmodel.recompute_initial_point(seed=0)
             assert ip["L_interval__"] == 0
-            assert ip["B_interval__"] == 0
+            assert ip["U_interval__"] == 0
+            assert ip["B1_interval__"] == 0
+            assert ip["B2_interval__"] == 0
 
             # Modify initval of L and re-evaluate
             pmodel.initial_values[L] = 0.9
+            pmodel.initial_values[U] = 9.9
             ip = pmodel.recompute_initial_point(seed=0)
-            assert ip["B_interval__"] < 0
-            pass
+            assert ip["B1_interval__"] < 0
+            assert ip["B2_interval__"] == 0
 
+    def test_nested_initvals(self):
+        # See issue #5168
+        with pm.Model() as pmodel:
+            one = pm.LogNormal("one", mu=np.log(1), sd=1e-5, initval="prior")
+            two = pm.Lognormal("two", mu=np.log(one * 2), sd=1e-5, initval="prior")
+            three = pm.LogNormal("three", mu=np.log(two * 2), sd=1e-5, initval="prior")
+            four = pm.LogNormal("four", mu=np.log(three * 2), sd=1e-5, initval="prior")
+            five = pm.LogNormal("five", mu=np.log(four * 2), sd=1e-5, initval="prior")
+            six = pm.LogNormal("six", mu=np.log(five * 2), sd=1e-5, initval="prior")
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
+        assert np.allclose(np.exp(ip_vals), [1, 2, 4, 8, 16, 32], rtol=1e-3)
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
+        assert np.allclose(ip_vals, [1, 2, 4, 8, 16, 32], rtol=1e-3)
+
+        pmodel.initial_values[four] = 1
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
+        assert np.allclose(np.exp(ip_vals), [1, 2, 4, 1, 2, 4], rtol=1e-3)
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
+        assert np.allclose(ip_vals, [1, 2, 4, 1, 2, 4], rtol=1e-3)
+
     def test_initval_resizing(self):
         with pm.Model() as pmodel:
             data = aesara.shared(np.arange(4))