@@ -945,32 +945,43 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
945
945
Returns
946
946
-------
947
947
initial_point : dict
948
- Maps free variable names to transformed, numeric initial values.
948
+ Maps transformed free variable names to transformed, numeric initial values.
949
949
"""
950
- self ._initial_point_cache = Point (list (self .initial_values .items ()), model = self )
950
+ numeric_initvals = {}
951
+ # The entries in `initial_values` are already in topological order and can be evaluated one by one.
952
+ for rv_value , initval in self .initial_values .items ():
953
+ rv_var = self .values_to_rvs [rv_value ]
954
+ transform = getattr (rv_value .tag , "transform" , None )
955
+ if isinstance (initval , np .ndarray ) and transform is None :
956
+ # Only untransformed, numeric initvals can be taken as they are.
957
+ numeric_initvals [rv_value ] = initval
958
+ else :
959
+ # Evaluate initvals that are None, symbolic or need to be transformed.
960
+ # They can depend on other initvals from higher up in the graph,
961
+ # which are therefore fed to the evaluation as "givens".
962
+ test_value = getattr (rv_var .tag , "test_value" , None )
963
+ numeric_initvals [rv_value ] = self ._eval_initval (
964
+ rv_var , initval , test_value , transform , given = numeric_initvals
965
+ )
966
+
967
+ # Cache the evaluation results for next time.
968
+ self ._initial_point_cache = Point (list (numeric_initvals .items ()), model = self )
951
969
return self ._initial_point_cache
952
970
953
971
@property
954
- def initial_values (self ) -> Dict [TensorVariable , np .ndarray ]:
955
- """Maps transformed variables to initial values .
972
+ def initial_values (self ) -> Dict [TensorVariable , Optional [ Union [ np .ndarray , Variable ]] ]:
973
+ """Maps transformed variables to initial value placeholders .
956
974
957
975
⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
958
- For a name-based dictionary use the `initial_point` property .
976
+ For a name-based dictionary use the `get_initial_point()` method .
959
977
"""
960
978
return self ._initial_values
961
979
962
980
def set_initval (self , rv_var , initval ):
963
981
if initval is not None :
964
982
initval = rv_var .type .filter (initval )
965
983
966
- test_value = getattr (rv_var .tag , "test_value" , None )
967
-
968
984
rv_value_var = self .rvs_to_values [rv_var ]
969
- transform = getattr (rv_value_var .tag , "transform" , None )
970
-
971
- if initval is None or transform :
972
- initval = self ._eval_initval (rv_var , initval , test_value , transform )
973
-
974
985
self .initial_values [rv_value_var ] = initval
975
986
976
987
def _eval_initval (
@@ -979,6 +990,7 @@ def _eval_initval(
979
990
initval : Optional [Variable ],
980
991
test_value : Optional [np .ndarray ],
981
992
transform : Optional [Transform ],
993
+ given : Optional [Dict [TensorVariable , np .ndarray ]] = None ,
982
994
) -> np .ndarray :
983
995
"""Sample/evaluate an initial value using the existing initial values,
984
996
and with the least effect on the RNGs involved (i.e. no in-placing).
@@ -997,6 +1009,8 @@ def _eval_initval(
997
1009
transform : optional, Transform
998
1010
A transformation associated with the random variable.
999
1011
Transformations are automatically applied to initial values.
1012
+ given : optional, dict
1013
+ Numeric initial values to be used for givens instead of `self.initial_values`.
1000
1014
1001
1015
Returns
1002
1016
-------
@@ -1007,6 +1021,9 @@ def _eval_initval(
1007
1021
opt_qry = mode .provided_optimizer .excluding ("random_make_inplace" )
1008
1022
mode = Mode (linker = mode .linker , optimizer = opt_qry )
1009
1023
1024
+ if given is None :
1025
+ given = self .initial_values
1026
+
1010
1027
if transform :
1011
1028
if initval is not None :
1012
1029
value = initval
@@ -1023,9 +1040,7 @@ def initval_to_rvval(value_var, value):
1023
1040
else :
1024
1041
return initval
1025
1042
1026
- givens = {
1027
- self .values_to_rvs [k ]: initval_to_rvval (k , v ) for k , v in self .initial_values .items ()
1028
- }
1043
+ givens = {self .values_to_rvs [k ]: initval_to_rvval (k , v ) for k , v in given .items ()}
1029
1044
initval_fn = aesara .function ([], rv_var , mode = mode , givens = givens , on_unused_input = "ignore" )
1030
1045
try :
1031
1046
initval = initval_fn ()
0 commit comments