Commit 5c600c7

Use graph_replace instead of clone_replace in VI
Parent: 1a06d50
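
This commit swaps pytensor.clone_replace for pytensor.graph.replace.graph_replace throughout the variational inference (VI) machinery, always passing strict=False. Below is a minimal sketch of the difference, assuming a recent PyTensor where graph_replace lives in pytensor.graph.replace and where strict=False skips replacement pairs whose key does not actually occur in the graph; the variables x, y, z and unused are purely illustrative, not PyMC API.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

x = pt.vector("x")
y = pt.exp(x)
z = y + 1

# clone_replace clones the whole graph while substituting y -> x ** 2.
z_clone = pytensor.clone_replace(z, {y: x**2})

# graph_replace performs the same substitution through FunctionGraph rewiring.
# With strict=False, a replacement whose key never occurs in the graph
# (here `unused`) is skipped rather than treated as an error.
unused = pt.scalar("unused")
z_graph = graph_replace(z, {y: x**2, unused: x.sum()}, strict=False)

x_val = np.array([1.0, 2.0], dtype=x.dtype)
print(z_clone.eval({x: x_val}))  # [2. 5.]
print(z_graph.eval({x: x_val}))  # [2. 5.]

Reading the diff, strict=False appears to be what preserves clone_replace's permissive behaviour for nodes that do not contain a given replacement target.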

File tree: 3 files changed (+27, -15 lines)

pymc/variational/approximations.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from arviz import InferenceData
 from pytensor import tensor as pt
 from pytensor.graph.basic import Variable
+from pytensor.graph.replace import graph_replace
 from pytensor.tensor.var import TensorVariable
 
 import pymc as pm
@@ -390,7 +391,7 @@ def evaluate_over_trace(self, node):
         node = self.to_flat_input(node)
 
         def sample(post, *_):
-            return pytensor.clone_replace(node, {self.input: post})
+            return graph_replace(node, {self.input: post}, strict=False)
 
         nodes, _ = pytensor.scan(
             sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node))
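
Several of the rewritten call sites, like evaluate_over_trace above, do the substitution inside a pytensor.scan step so the expression is evaluated once per stored posterior draw. Here is a toy sketch of that pattern with made-up shapes; flat_input, node and histogram are stand-ins for self.input, the node argument and self.histogram, and only graph_replace and scan are real PyTensor calls.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

# Stand-ins for the approximation's flattened input and a quantity built from it.
flat_input = pt.vector("flat_input")
node = (flat_input**2).sum()

# Stand-in for self.histogram: one flattened posterior draw per row.
draws = np.random.randn(100, 3).astype(flat_input.dtype)
histogram = pytensor.shared(draws, name="histogram")

def sample(post, *_):
    # Substitute the current draw for the flat input inside node's graph.
    return graph_replace(node, {flat_input: post}, strict=False)

# scan maps `sample` over the rows of the histogram: one value of node per draw.
nodes, _ = pytensor.scan(sample, histogram)
print(nodes.eval().shape)  # (100,)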

pymc/variational/opvi.py

Lines changed: 21 additions & 12 deletions
@@ -59,6 +59,8 @@
 import xarray
 
 from pytensor.graph.basic import Variable
+from pytensor.graph.replace import graph_replace
+from pytensor.tensor.shape import unbroadcast
 
 import pymc as pm
 
@@ -1002,7 +1004,7 @@ def set_size_and_deterministic(
         """
 
         flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
-        node_out = pytensor.clone_replace(node, flat2rand)
+        node_out = graph_replace(node, flat2rand, strict=False)
         assert not (
             set(makeiter(self.input)) & set(pytensor.graph.graph_inputs(makeiter(node_out)))
         )
@@ -1012,7 +1014,7 @@ def set_size_and_deterministic(
 
     def to_flat_input(self, node):
         """*Dev* - replace vars with flattened view stored in `self.inputs`"""
-        return pytensor.clone_replace(node, self.replacements)
+        return graph_replace(node, self.replacements, strict=False)
 
     def symbolic_sample_over_posterior(self, node):
         """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1023,7 +1025,7 @@ def symbolic_sample_over_posterior(self, node):
         random = pt.specify_shape(random, self.symbolic_initial.type.shape)
 
         def sample(post, *_):
-            return pytensor.clone_replace(node, {self.input: post})
+            return graph_replace(node, {self.input: post}, strict=False)
 
         nodes, _ = pytensor.scan(
             sample, random, non_sequences=_known_scan_ignored_inputs(makeiter(random))
@@ -1038,7 +1040,7 @@ def symbolic_single_sample(self, node):
         """
         node = self.to_flat_input(node)
         random = self.symbolic_random.astype(self.symbolic_initial.dtype)
-        return pytensor.clone_replace(node, {self.input: random[0]})
+        return graph_replace(node, {self.input: random[0]}, strict=False)
 
     def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
         """*Dev* - creates correct replacements for initial depending on
@@ -1059,8 +1061,15 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
         """
         initial = self._new_initial(s, d, more_replacements)
         initial = pt.specify_shape(initial, self.symbolic_initial.type.shape)
+        # The static shape of initial may be more precise than self.symbolic_initial,
+        # and reveal previously unknown broadcastable dimensions. We have to mask those again.
+        if initial.type.broadcastable != self.symbolic_initial.type.broadcastable:
+            unbroadcast_axes = (
+                i for i, b in enumerate(self.symbolic_initial.type.broadcastable) if not b
+            )
+            initial = unbroadcast(initial, *unbroadcast_axes)
         if more_replacements:
-            initial = pytensor.clone_replace(initial, more_replacements)
+            initial = graph_replace(initial, more_replacements, strict=False)
         return {self.symbolic_initial: initial}
 
     @node_property
@@ -1394,17 +1403,17 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
         _node = node
         optimizations = self.get_optimization_replacements(s, d)
         flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
-        node = pytensor.clone_replace(node, optimizations)
-        node = pytensor.clone_replace(node, flat2rand)
+        node = graph_replace(node, optimizations, strict=False)
+        node = graph_replace(node, flat2rand, strict=False)
         assert not (set(self.symbolic_randoms) & set(pytensor.graph.graph_inputs(makeiter(node))))
         try_to_set_test_value(_node, node, s)
         return node
 
     def to_flat_input(self, node, more_replacements=None):
         """*Dev* - replace vars with flattened view stored in `self.inputs`"""
         more_replacements = more_replacements or {}
-        node = pytensor.clone_replace(node, more_replacements)
-        return pytensor.clone_replace(node, self.replacements)
+        node = graph_replace(node, more_replacements, strict=False)
+        return graph_replace(node, self.replacements, strict=False)
 
     def symbolic_sample_over_posterior(self, node, more_replacements=None):
         """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1413,7 +1422,7 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None):
         node = self.to_flat_input(node)
 
         def sample(*post):
-            return pytensor.clone_replace(node, dict(zip(self.inputs, post)))
+            return graph_replace(node, dict(zip(self.inputs, post)), strict=False)
 
         nodes, _ = pytensor.scan(
             sample, self.symbolic_randoms, non_sequences=_known_scan_ignored_inputs(makeiter(node))
@@ -1429,7 +1438,7 @@ def symbolic_single_sample(self, node, more_replacements=None):
         node = self.to_flat_input(node, more_replacements=more_replacements)
         post = [v[0] for v in self.symbolic_randoms]
         inp = self.inputs
-        return pytensor.clone_replace(node, dict(zip(inp, post)))
+        return graph_replace(node, dict(zip(inp, post)), strict=False)
 
     def get_optimization_replacements(self, s, d):
         """*Dev* - optimizations for logP. If sample size is static and equal to 1:
@@ -1463,7 +1472,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=None):
         """
         node_in = node
         if more_replacements:
-            node = pytensor.clone_replace(node, more_replacements)
+            node = graph_replace(node, more_replacements, strict=False)
         if not isinstance(node, (list, tuple)):
             node = [node]
         node = self.model.replace_rvs_by_values(node)
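
Besides the clone_replace swaps, this file adds a broadcastability guard in make_size_and_deterministic_replacements: pt.specify_shape can pin a dimension of initial to a static 1, which marks it broadcastable even though self.symbolic_initial is not, and unbroadcast masks that again so the replacement stays type-compatible. A small sketch of the effect, assuming current PyTensor semantics where a static shape of 1 implies a broadcastable dimension; symbolic_initial and initial here are throwaway variables, not the real attributes.

import pytensor.tensor as pt
from pytensor.tensor.shape import unbroadcast

# The symbolic placeholder has unknown (non-broadcastable) dimensions.
symbolic_initial = pt.matrix("symbolic_initial")
print(symbolic_initial.type.broadcastable)  # (False, False)

# specify_shape can pin the first dimension to 1, revealing a broadcastable axis.
initial = pt.specify_shape(pt.matrix("initial"), (1, 3))
print(initial.type.broadcastable)  # (True, False)

# Mask the axes that symbolic_initial treats as non-broadcastable, mirroring
# the guard added to make_size_and_deterministic_replacements.
if initial.type.broadcastable != symbolic_initial.type.broadcastable:
    unbroadcast_axes = (
        i for i, b in enumerate(symbolic_initial.type.broadcastable) if not b
    )
    initial = unbroadcast(initial, *unbroadcast_axes)
print(initial.type.broadcastable)  # (False, False)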

pymc/variational/stein.py

Lines changed: 4 additions & 2 deletions
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytensor
 import pytensor.tensor as pt
 
+from pytensor.graph.replace import graph_replace
+
 from pymc.pytensorf import floatX
 from pymc.util import WithMemoization, locally_cachedmethod
 from pymc.variational.opvi import node_property
@@ -85,9 +86,10 @@ def dxkxy(self):
     def logp_norm(self):
         sized_symbolic_logp = self.approx.sized_symbolic_logp
         if self.use_histogram:
-            sized_symbolic_logp = pytensor.clone_replace(
+            sized_symbolic_logp = graph_replace(
                 sized_symbolic_logp,
                 dict(zip(self.approx.symbolic_randoms, self.approx.collect("histogram"))),
+                strict=False,
             )
         return sized_symbolic_logp / self.approx.symbolic_normalizing_constant
 
