Skip to content

Walking nested graph rewriter #556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
Expand Down Expand Up @@ -2020,7 +2020,7 @@ def apply(self, fgraph, start_from=None):
io_t = time.perf_counter() - t0

def importer(node):
if node is not current_node:
if node is not current_node and not isinstance(node.op, HasInnerGraph):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was for debugging I assume?

q.append(node)

u = self.attach_updater(
Expand All @@ -2030,6 +2030,7 @@ def importer(node):
try:
t0 = time.perf_counter()
while q:
# breakpoint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# breakpoint

if self.order == "out_to_in":
node = q.pop()
else:
Expand Down
135 changes: 135 additions & 0 deletions pytensor/graph/rewriting/nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Optional

from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Apply, io_toposort
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import NodeRewriter, WalkingGraphRewriter
from pytensor.scan.op import Scan


class WalkingNestedGraphRewriter(WalkingGraphRewriter):
def process_node(
    self,
    fgraph: FunctionGraph,
    node: Apply,
    node_rewriter: Optional[NodeRewriter] = None,
):
    """Process ``node``, descending into its inner graph when it has one.

    Nodes whose op is a ``Scan`` or an ``OpFromGraph`` are handed to
    ``process_scan_node``; every other node is processed by the parent
    class.

    Parameters
    ----------
    fgraph
        The graph that contains ``node``.
    node
        The node to process.
    node_rewriter
        Optional rewriter to use for this call instead of
        ``self.node_rewriter``.

    Returns
    -------
    Whatever the selected processing routine returns (``process_scan_node``
    returns ``True`` when a replacement was applied and ``False`` otherwise).
    """
    # NOTE(review): the PR discussion asked for this dispatch to test the
    # generic ``HasInnerGraph`` interface instead of naming specific ops.
    # It is left as-is because ``transform_scan_node`` only knows how to
    # rebuild ``Scan`` and ``OpFromGraph``.
    if not isinstance(node.op, (Scan, OpFromGraph)):
        return super().process_node(fgraph, node, node_rewriter=node_rewriter)

    if not node_rewriter:
        return self.process_scan_node(fgraph, node)

    # ``process_scan_node`` reads the rewriter from ``self.node_rewriter``.
    # Install the per-call override only for the duration of this call so it
    # does not leak into the processing of later nodes.
    previous_rewriter = self.node_rewriter
    self.node_rewriter = node_rewriter
    try:
        return self.process_scan_node(fgraph, node)
    finally:
        self.node_rewriter = previous_rewriter

def process_scan_node(self, fgraph: FunctionGraph, node: Apply):
    """Rewrite the inner graph of ``node`` and install the result in ``fgraph``.

    Parameters
    ----------
    fgraph
        The graph that contains ``node``.
    node
        A node whose op carries an inner graph.

    Returns
    -------
    bool
        ``True`` when the outputs of ``node`` were replaced in ``fgraph``;
        ``False`` when there was nothing to replace or when a failure was
        reported to ``self.failure_callback``.

    Raises
    ------
    Exception
        Any error raised while transforming the node or while applying the
        replacements is re-raised when ``self.failure_callback`` is ``None``.
    """
    try:
        replacements = self.transform_scan_node(fgraph, node)
    except Exception as e:
        if self.failure_callback is None:
            raise
        self.failure_callback(
            e,
            self,
            [(x, None) for x in node.outputs],
            self.node_rewriter,  # type: ignore
            node,
        )
        return False

    if replacements is False or replacements is None:
        return False

    # Materialize the pairs: a bare ``zip`` iterator would be consumed by the
    # fgraph call below, leaving nothing for ``failure_callback`` to report.
    repl_pairs = list(zip(node.outputs, replacements))
    try:
        fgraph.replace_all_validate_remove(  # type: ignore
            repl_pairs,
            reason=self.node_rewriter,
            remove=[],
        )
    except Exception as e:
        # This means the replacements were rejected by the fgraph.
        #
        # This is not supposed to happen. The default failure_callback
        # will print a traceback as a warning.
        if self.failure_callback is None:
            raise
        self.failure_callback(
            e,
            self,
            repl_pairs,  # type: ignore
            self.node_rewriter,  # type: ignore
            node,
        )
        return False
    return True

def transform_scan_node(self, fgraph, node):
    """Apply ``self.node_rewriter`` to the inner graph of ``node``.

    Every node of the inner graph is visited in topological order. Nested
    ``Scan``/``OpFromGraph`` nodes are transformed recursively; all other
    nodes are passed to ``self.node_rewriter.transform``. When at least one
    effective replacement was collected, the inner outputs are cloned with
    those replacements and an equivalent op is rebuilt around them.

    Parameters
    ----------
    fgraph
        The graph that contains ``node``. It is not used directly here.
    node
        A node whose op is a ``Scan`` or an ``OpFromGraph``.

    Returns
    -------
    list or None
        The outputs of the rebuilt op applied to ``node.inputs``, or ``None``
        when the inner graph was left unchanged.

    Raises
    ------
    TypeError
        If the node rewriter returns something other than ``False``,
        ``None``, a list, a tuple or a dict.
    ValueError
        If the node rewriter returns a list/tuple whose length differs from
        the number of outputs of the rewritten inner node.
    """
    # NOTE(review): the PR discussion suggested replacing this manual walk
    # with running the rewriter's own ``apply`` on the op's inner fgraph
    # (roughly ``self.apply(node.op.fgraph)`` from ``process_node``).
    op = node.op
    node_inputs, node_outputs = op.inner_inputs, op.inner_outputs

    givens = {}
    for nd in io_toposort(node_inputs, node_outputs):
        if isinstance(nd.op, (Scan, OpFromGraph)):
            # The recursive call returns ``None`` when the nested inner graph
            # is unchanged, otherwise the full list of new outputs (possibly
            # more than one), so it must not be unpacked as a 1-element list.
            new_outputs = self.transform_scan_node(op.fgraph, nd)
            if new_outputs is None:
                continue
            pairs = zip(nd.outputs, new_outputs)
        else:
            pairs = self._inner_replacement_pairs(op.fgraph, nd)

        # Skip ``None`` (no replacement provided) and identity replacements:
        # neither changes the graph, and ``None`` is not a valid replacement
        # value for ``clone_replace``.
        givens.update(
            (old, new) for old, new in pairs if new is not None and new is not old
        )

    if not givens:
        return None

    op_outs = clone_replace(node_outputs, replace=givens)
    if isinstance(op, Scan):
        nwScan = Scan(
            node_inputs,
            op_outs,
            op.info,
            mode=op.mode,
            profile=op.profile,
            truncate_gradient=op.truncate_gradient,
            name=op.name,
            allow_gc=op.allow_gc,
        )
        nw_node = nwScan(*(node.inputs), return_list=True)
    else:
        nwOpFromGraph = OpFromGraph(
            node_inputs,
            op_outs,
            op.is_inline,
            op.lop_overrides,
            op.grad_overrides,
            op.rop_overrides,
            connection_pattern=op._connection_pattern,
            name=op.name,
            **op.kwargs,
        )
        nw_node = nwOpFromGraph(*(node.inputs), return_list=True)
    return nw_node

def _inner_replacement_pairs(self, inner_fgraph, nd):
    """Return ``(old, new)`` pairs proposed by the node rewriter for ``nd``."""
    replacements = self.node_rewriter.transform(inner_fgraph, nd)
    if replacements is False or replacements is None:
        return []
    if isinstance(replacements, dict):
        # The "remove" entry lists variables rather than a replacement, so it
        # is not a substitution and is left out of the pairs.
        return [
            (key, value)
            for key, value in replacements.items()
            if not key == "remove"
        ]
    if isinstance(replacements, (list, tuple)):
        if len(nd.outputs) != len(replacements):
            raise ValueError(
                f"Node rewriter {self.node_rewriter} gave wrong number of replacements"
            )
        return list(zip(nd.outputs, replacements))
    raise TypeError(
        f"Node rewriter {self.node_rewriter} gave wrong type of replacement. "
        f"Expected list, tuple or dict; got {replacements}"
    )
94 changes: 94 additions & 0 deletions tests/graph/rewriting/test_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.nested import WalkingNestedGraphRewriter
from pytensor.scan import scan
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import sum
from pytensor.tensor.type import matrix, scalar, vector


class TestWalkingNestedGraphRewriter:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should write a more graph-based test that checks the graph changed as expected, rather than relying only on a counter variable that could be updated even if the rewrite itself did not work correctly.

def apply_rewrites(self, fgraph):
    """Walk ``fgraph`` with a no-op rewriter and count the Elemwise nodes it sees.

    The rewriter returns each Elemwise node's own outputs as its
    replacement; the number of such calls is returned.
    """
    elemwise_hits = 0

    class MyRwriter(NodeRewriter):
        def transform(self, fgraph, node):
            nonlocal elemwise_hits
            if not isinstance(node.op, Elemwise):
                return None
            elemwise_hits += 1
            return node.outputs

    WalkingNestedGraphRewriter(MyRwriter()).apply(fgraph)
    return elemwise_hits

def test_rewrite_in_scan(self):
    """A single Scan: expect one hit in its inner graph and one outside."""

    def add_one(x_t):
        return x_t + 1

    x_0 = vector("x_0")
    scan_out, _ = scan(
        add_one,
        outputs_info=None,
        sequences=x_0,
    )
    total = sum(scan_out) + 1
    fg = FunctionGraph([x_0], [total], clone=False)

    # one replacement in the scan inner graph and one in the outer graph
    assert self.apply_rewrites(fg) == 2

def test_rewrite_in_nested_scan(self):
    """A Scan nested in a Scan: expect one hit at each nesting level."""

    def inner_step(row_elem):
        return row_elem + 1

    def outer_step(row):
        inner_out, _ = scan(
            fn=inner_step,
            sequences=row,
            outputs_info=None,
        )
        return inner_out + 1

    x_0 = matrix("x_0")
    outer_out, _ = scan(
        fn=outer_step,
        sequences=x_0,
        outputs_info=None,
    )
    fg = FunctionGraph([x_0], [outer_out], clone=False)

    # one replacement in the inner scan and one in the outer scan
    assert self.apply_rewrites(fg) == 2

def test_rewrite_op_from_graph(self):
    """Two applications of one OpFromGraph plus an outer addition."""
    x = scalar("x")
    y = scalar("y")
    z = scalar("z")
    inner_expr = x + y * z
    op = OpFromGraph([x, y, z], [inner_expr])
    combined = op(x, y, z) + op(z, y, x)
    fg = FunctionGraph([x, y, z], [combined], clone=False)

    # two rewrites in the inner graph of each OpFromGraph application
    # and one in the outer graph
    assert self.apply_rewrites(fg) == 5

def test_rewrite_nested_op_from_graph(self):
    """An OpFromGraph whose inner graph applies another OpFromGraph twice."""
    x = scalar("x")
    y = scalar("y")
    z = scalar("z")
    innermost = OpFromGraph([x, y], [x + y])
    product = innermost(x, y) * innermost(x, y)
    wrapper = OpFromGraph([x, y], [product])
    outer_expr = wrapper(x, y) + z
    fg = FunctionGraph([x, y, z], [outer_expr], clone=False)

    # two rewrites in the innermost OpFromGraph, one in the second
    # OpFromGraph, and one in the outer graph
    assert self.apply_rewrites(fg) == 4