Skip to content

Commit 3aa5c54

Browse files
Evaluate initial values lazily
Related to pymc-devs#4924
1 parent bcc40ce commit 3aa5c54

File tree

3 files changed

+64
-19
lines changed

3 files changed

+64
-19
lines changed

pymc3/model.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -945,32 +945,43 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
945945
Returns
946946
-------
947947
initial_point : dict
948-
Maps free variable names to transformed, numeric initial values.
948+
Maps transformed free variable names to transformed, numeric initial values.
949949
"""
950-
self._initial_point_cache = Point(list(self.initial_values.items()), model=self)
950+
numeric_initvals = {}
951+
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
952+
for rv_value, initval in self.initial_values.items():
953+
rv_var = self.values_to_rvs[rv_value]
954+
transform = getattr(rv_value.tag, "transform", None)
955+
if isinstance(initval, np.ndarray) and transform is None:
956+
# Only untransformed, numeric initvals can be taken as they are.
957+
numeric_initvals[rv_value] = initval
958+
else:
959+
# Evaluate initvals that are None, symbolic or need to be transformed.
960+
# They can depend on other initvals from higher up in the graph,
961+
# which are therefore fed to the evaluation as "givens".
962+
test_value = getattr(rv_var.tag, "test_value", None)
963+
numeric_initvals[rv_value] = self._eval_initval(
964+
rv_var, initval, test_value, transform, given=numeric_initvals
965+
)
966+
967+
# Cache the evaluation results for next time.
968+
self._initial_point_cache = Point(list(numeric_initvals.items()), model=self)
951969
return self._initial_point_cache
952970

953971
@property
954-
def initial_values(self) -> Dict[TensorVariable, np.ndarray]:
955-
"""Maps transformed variables to initial values.
972+
def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]:
973+
"""Maps transformed variables to initial value placeholders.
956974
957975
⚠ The keys are NOT the objects returned by `pm.Normal(...)`.
958-
For a name-based dictionary use the `initial_point` property.
976+
For a name-based dictionary use the `get_initial_point()` method.
959977
"""
960978
return self._initial_values
961979

962980
def set_initval(self, rv_var, initval):
963981
if initval is not None:
964982
initval = rv_var.type.filter(initval)
965983

966-
test_value = getattr(rv_var.tag, "test_value", None)
967-
968984
rv_value_var = self.rvs_to_values[rv_var]
969-
transform = getattr(rv_value_var.tag, "transform", None)
970-
971-
if initval is None or transform:
972-
initval = self._eval_initval(rv_var, initval, test_value, transform)
973-
974985
self.initial_values[rv_value_var] = initval
975986

976987
def _eval_initval(
@@ -979,6 +990,7 @@ def _eval_initval(
979990
initval: Optional[Variable],
980991
test_value: Optional[np.ndarray],
981992
transform: Optional[Transform],
993+
given: Optional[Dict[TensorVariable, np.ndarray]] = None,
982994
) -> np.ndarray:
983995
"""Sample/evaluate an initial value using the existing initial values,
984996
and with the least effect on the RNGs involved (i.e. no in-placing).
@@ -997,6 +1009,8 @@ def _eval_initval(
9971009
transform : optional, Transform
9981010
A transformation associated with the random variable.
9991011
Transformations are automatically applied to initial values.
1012+
given : optional, dict
1013+
Numeric initial values to be used for givens instead of `self.initial_values`.
10001014
10011015
Returns
10021016
-------
@@ -1007,6 +1021,9 @@ def _eval_initval(
10071021
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
10081022
mode = Mode(linker=mode.linker, optimizer=opt_qry)
10091023

1024+
if given is None:
1025+
given = self.initial_values
1026+
10101027
if transform:
10111028
if initval is not None:
10121029
value = initval
@@ -1023,9 +1040,7 @@ def initval_to_rvval(value_var, value):
10231040
else:
10241041
return initval
10251042

1026-
givens = {
1027-
self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items()
1028-
}
1043+
givens = {self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in given.items()}
10291044
initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore")
10301045
try:
10311046
initval = initval_fn()

pymc3/tests/test_initvals.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import aesara
1415
import numpy as np
1516
import pytest
1617

@@ -37,7 +38,8 @@ def test_new_warnings(self):
3738
with pm.Model() as pmodel:
3839
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
3940
rv = pm.Uniform("u", 0, 1, testval=0.75)
40-
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
41+
initial_point = pmodel.recompute_initial_point()
42+
assert initial_point["u_interval__"] == transform_fwd(rv, 0.75)
4143
assert not hasattr(rv.tag, "test_value")
4244
pass
4345

@@ -82,6 +84,33 @@ def test_falls_back_to_test_value(self):
8284
assert iv == 0.6
8385
pass
8486

87+
def test_dependent_initvals(self):
88+
with pm.Model() as pmodel:
89+
L = pm.Uniform("L", 0, 1, initval=0.5)
90+
B = pm.Uniform("B", lower=L, upper=2, initval=1.25)
91+
ip = pmodel.recompute_initial_point()
92+
assert ip["L_interval__"] == 0
93+
assert ip["B_interval__"] == 0
94+
95+
# Modify initval of L and re-evaluate
96+
pmodel.initial_values[pmodel.rvs_to_values[L]] = 0.9
97+
ip = pmodel.recompute_initial_point()
98+
assert ip["B_interval__"] < 0
99+
pass
100+
101+
def test_initval_resizing(self):
102+
with pm.Model() as pmodel:
103+
data = aesara.shared(np.arange(4))
104+
rv = pm.Uniform("u", lower=data, upper=10)
105+
106+
ip = pmodel.recompute_initial_point()
107+
assert np.shape(ip["u_interval__"]) == (4,)
108+
109+
data.set_value(np.arange(5))
110+
ip = pmodel.recompute_initial_point()
111+
assert np.shape(ip["u_interval__"]) == (5,)
112+
pass
113+
85114

86115
class TestSpecialDistributions:
87116
def test_automatically_assigned_test_values(self):

pymc3/tests/test_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,8 @@ def test_initial_point():
516516

517517
assert model.rvs_to_values[a] in model.initial_values
518518
assert model.rvs_to_values[x] in model.initial_values
519-
assert model.initial_values[b_value_var] == b_initval_trans
519+
assert model.initial_values[b_value_var] == b_initval
520+
assert model.recompute_initial_point()["b_interval__"] == b_initval_trans
520521
assert model.initial_values[model.rvs_to_values[y]] == y_initval
521522

522523

@@ -641,8 +642,8 @@ def test_set_initval():
641642
value = pm.NegativeBinomial("value", mu=mu, alpha=alpha)
642643

643644
assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]]))
644-
np.testing.assert_almost_equal(model.initial_values[model.rvs_to_values[alpha]], np.log(100))
645-
assert 50 < model.initial_values[model.rvs_to_values[value]] < 150
645+
np.testing.assert_array_equal(model.initial_values[model.rvs_to_values[alpha]], np.array(100))
646+
assert model.initial_values[model.rvs_to_values[value]] is None
646647

647648
# `Flat` cannot be sampled, so let's make sure that doesn't break initial
648649
# value computations

0 commit comments

Comments
 (0)