diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index f25aed8f20..b5f29a7e98 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -291,6 +291,7 @@ def make_initial_point_expression(
             jitter.name = f"{variable.name}_jitter"
             value = value + jitter
 
+        value = value.astype(variable.dtype)
         initial_values_transformed.append(value)
 
         if transform is not None:
@@ -310,18 +311,17 @@ def make_initial_point_expression(
     initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
     initial_values_transformed_clone = copy_graph.outputs[-n_variables:]
 
-    # In the order the variables were created, replace each previous variable
-    # with the init_point for that variable.
-    initial_values = []
-    initial_values_transformed = []
-
-    for i in range(n_variables):
-        outputs = [initial_values_clone[i], initial_values_transformed_clone[i]]
-        graph = FunctionGraph(outputs=outputs, clone=False)
-        graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True)
-        initial_values.append(graph.outputs[0])
-        initial_values_transformed.append(graph.outputs[1])
-
-    if return_transformed:
-        return initial_values_transformed
-    return initial_values
+    # We now replace all rvs by the respective initial_point expressions
+    # in the constrained (untransformed) space. We do this in reverse topological
+    # order, so that later nodes do not reintroduce expressions with earlier
+    # rvs that would need to once again be replaced by their initial_points
+    graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
+    replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
+    graph.replace_all(replacements, import_missing=True)
+
+    if not return_transformed:
+        return graph.outputs
+    # Because the unconstrained (transformed) expressions are a subgraph of the
+    # constrained initial point they were also automatically updated inplace
+    # when calling graph.replace_all above, so we don't need to do anything else
+    return initial_values_transformed_clone
diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py
index 1d3275de92..899209399e 100644
--- a/pymc/tests/test_initial_point.py
+++ b/pymc/tests/test_initial_point.py
@@ -74,17 +74,46 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
     def test_dependent_initvals(self):
         with pm.Model() as pmodel:
             L = pm.Uniform("L", 0, 1, initval=0.5)
-            B = pm.Uniform("B", lower=L, upper=2, initval=1.25)
+            U = pm.Uniform("U", lower=9, upper=10, initval=9.5)
+            B1 = pm.Uniform("B1", lower=L, upper=U, initval=5)
+            B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2)
             ip = pmodel.recompute_initial_point(seed=0)
             assert ip["L_interval__"] == 0
-            assert ip["B_interval__"] == 0
+            assert ip["U_interval__"] == 0
+            assert ip["B1_interval__"] == 0
+            assert ip["B2_interval__"] == 0
 
-            # Modify initval of L and re-evaluate
-            pmodel.initial_values[L] = 0.9
+            # Modify initval of U and re-evaluate
+            pmodel.initial_values[U] = 9.9
             ip = pmodel.recompute_initial_point(seed=0)
-            assert ip["B_interval__"] < 0
+            assert ip["B1_interval__"] < 0
+            assert ip["B2_interval__"] == 0
         pass
 
+    def test_nested_initvals(self):
+        # See issue #5168
+        with pm.Model() as pmodel:
+            one = pm.LogNormal("one", mu=np.log(1), sd=1e-5, initval="prior")
+            two = pm.LogNormal("two", mu=np.log(one * 2), sd=1e-5, initval="prior")
+            three = pm.LogNormal("three", mu=np.log(two * 2), sd=1e-5, initval="prior")
+            four = pm.LogNormal("four", mu=np.log(three * 2), sd=1e-5, initval="prior")
+            five = pm.LogNormal("five", mu=np.log(four * 2), sd=1e-5, initval="prior")
+            six = pm.LogNormal("six", mu=np.log(five * 2), sd=1e-5, initval="prior")
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
+        assert np.allclose(np.exp(ip_vals), [1, 2, 4, 8, 16, 32], rtol=1e-3)
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
+        assert np.allclose(ip_vals, [1, 2, 4, 8, 16, 32], rtol=1e-3)
+
+        pmodel.initial_values[four] = 1
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
+        assert np.allclose(np.exp(ip_vals), [1, 2, 4, 1, 2, 4], rtol=1e-3)
+
+        ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
+        assert np.allclose(ip_vals, [1, 2, 4, 1, 2, 4], rtol=1e-3)
+
     def test_initval_resizing(self):
         with pm.Model() as pmodel:
             data = aesara.shared(np.arange(4))