import xarray

from pytensor.graph.basic import Variable
+from pytensor.graph.replace import graph_replace
+from pytensor.tensor.shape import unbroadcast

import pymc as pm

@@ -1002,7 +1004,7 @@ def set_size_and_deterministic(
        """

        flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
-        node_out = pytensor.clone_replace(node, flat2rand)
+        node_out = graph_replace(node, flat2rand, strict=False)
        assert not (
            set(makeiter(self.input)) & set(pytensor.graph.graph_inputs(makeiter(node_out)))
        )
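Every hunk in this diff makes the same substitution: `pytensor.clone_replace` becomes `graph_replace(..., strict=False)`. A minimal sketch of why `strict=False` is needed, assuming the `pytensor.graph.replace.graph_replace` API imported above (the variables `x`, `y`, `unused`, and `out` are illustrative, not from the PR):

import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

x = pt.scalar("x")
y = pt.scalar("y")
unused = pt.scalar("unused")
out = x + 1  # a graph that only depends on x

# clone_replace silently skipped replacement keys that do not occur in the graph.
# graph_replace raises by default in that case; strict=False restores the old
# ignore-unused-keys behaviour.
new_out = graph_replace(out, {x: y, unused: y}, strict=False)
# graph_replace(out, {x: y, unused: y})  # strict=True (the default) would raise here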
@@ -1012,7 +1014,7 @@ def set_size_and_deterministic(

    def to_flat_input(self, node):
        """*Dev* - replace vars with flattened view stored in `self.inputs`"""
-        return pytensor.clone_replace(node, self.replacements)
+        return graph_replace(node, self.replacements, strict=False)

    def symbolic_sample_over_posterior(self, node):
        """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1023,7 +1025,7 @@ def symbolic_sample_over_posterior(self, node):
        random = pt.specify_shape(random, self.symbolic_initial.type.shape)

        def sample(post, *_):
-            return pytensor.clone_replace(node, {self.input: post})
+            return graph_replace(node, {self.input: post}, strict=False)

        nodes, _ = pytensor.scan(
            sample, random, non_sequences=_known_scan_ignored_inputs(makeiter(random))
@@ -1038,7 +1040,7 @@ def symbolic_single_sample(self, node):
        """
        node = self.to_flat_input(node)
        random = self.symbolic_random.astype(self.symbolic_initial.dtype)
-        return pytensor.clone_replace(node, {self.input: random[0]})
+        return graph_replace(node, {self.input: random[0]}, strict=False)

    def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
        """*Dev* - creates correct replacements for initial depending on
@@ -1059,8 +1061,15 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
        """
        initial = self._new_initial(s, d, more_replacements)
        initial = pt.specify_shape(initial, self.symbolic_initial.type.shape)
+        # The static shape of initial may be more precise than self.symbolic_initial,
+        # and reveal previously unknown broadcastable dimensions. We have to mask those again.
+        if initial.type.broadcastable != self.symbolic_initial.type.broadcastable:
+            unbroadcast_axes = (
+                i for i, b in enumerate(self.symbolic_initial.type.broadcastable) if not b
+            )
+            initial = unbroadcast(initial, *unbroadcast_axes)
        if more_replacements:
-            initial = pytensor.clone_replace(initial, more_replacements)
+            initial = graph_replace(initial, more_replacements, strict=False)
        return {self.symbolic_initial: initial}

    @node_property
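The `unbroadcast` guard added above compensates for `pt.specify_shape` making a dimension statically known as length 1, which also marks it broadcastable. A minimal sketch of that effect, with hypothetical shapes chosen purely for illustration:

import pytensor.tensor as pt
from pytensor.tensor.shape import unbroadcast

initial = pt.matrix("initial")                # static shape (None, None), nothing broadcastable
precise = pt.specify_shape(initial, (1, 3))   # static shape (1, 3): axis 0 is now broadcastable
assert precise.type.broadcastable == (True, False)

# Masking the revealed axis keeps the type compatible with the non-broadcastable
# placeholder (here, self.symbolic_initial) that this variable replaces in the graph.
masked = unbroadcast(precise, 0)
assert masked.type.broadcastable == (False, False)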
@@ -1394,17 +1403,17 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
        _node = node
        optimizations = self.get_optimization_replacements(s, d)
        flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
-        node = pytensor.clone_replace(node, optimizations)
-        node = pytensor.clone_replace(node, flat2rand)
+        node = graph_replace(node, optimizations, strict=False)
+        node = graph_replace(node, flat2rand, strict=False)
        assert not (set(self.symbolic_randoms) & set(pytensor.graph.graph_inputs(makeiter(node))))
        try_to_set_test_value(_node, node, s)
        return node

    def to_flat_input(self, node, more_replacements=None):
        """*Dev* - replace vars with flattened view stored in `self.inputs`"""
        more_replacements = more_replacements or {}
-        node = pytensor.clone_replace(node, more_replacements)
-        return pytensor.clone_replace(node, self.replacements)
+        node = graph_replace(node, more_replacements, strict=False)
+        return graph_replace(node, self.replacements, strict=False)

    def symbolic_sample_over_posterior(self, node, more_replacements=None):
        """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1413,7 +1422,7 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None):
        node = self.to_flat_input(node)

        def sample(*post):
-            return pytensor.clone_replace(node, dict(zip(self.inputs, post)))
+            return graph_replace(node, dict(zip(self.inputs, post)), strict=False)

        nodes, _ = pytensor.scan(
            sample, self.symbolic_randoms, non_sequences=_known_scan_ignored_inputs(makeiter(node))
@@ -1429,7 +1438,7 @@ def symbolic_single_sample(self, node, more_replacements=None):
        node = self.to_flat_input(node, more_replacements=more_replacements)
        post = [v[0] for v in self.symbolic_randoms]
        inp = self.inputs
-        return pytensor.clone_replace(node, dict(zip(inp, post)))
+        return graph_replace(node, dict(zip(inp, post)), strict=False)

    def get_optimization_replacements(self, s, d):
        """*Dev* - optimizations for logP. If sample size is static and equal to 1:
@@ -1463,7 +1472,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
        """
        node_in = node
        if more_replacements:
-            node = pytensor.clone_replace(node, more_replacements)
+            node = graph_replace(node, more_replacements, strict=False)
        if not isinstance(node, (list, tuple)):
            node = [node]
        node = self.model.replace_rvs_by_values(node)