-
Notifications
You must be signed in to change notification settings - Fork 130
Walking nested graph rewriter #556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -31,7 +31,7 @@ | |||
) | ||||
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder | ||||
from pytensor.graph.fg import FunctionGraph | ||||
from pytensor.graph.op import Op | ||||
from pytensor.graph.op import HasInnerGraph, Op | ||||
from pytensor.graph.utils import AssocList, InconsistencyError | ||||
from pytensor.misc.ordered_set import OrderedSet | ||||
from pytensor.utils import flatten | ||||
|
@@ -2020,7 +2020,7 @@ def apply(self, fgraph, start_from=None): | |||
io_t = time.perf_counter() - t0 | ||||
|
||||
def importer(node): | ||||
if node is not current_node: | ||||
if node is not current_node and not isinstance(node.op, HasInnerGraph): | ||||
q.append(node) | ||||
|
||||
u = self.attach_updater( | ||||
|
@@ -2030,6 +2030,7 @@ def importer(node): | |||
try: | ||||
t0 = time.perf_counter() | ||||
while q: | ||||
# breakpoint | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
if self.order == "out_to_in": | ||||
node = q.pop() | ||||
else: | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from typing import Optional | ||
|
||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.graph.basic import Apply, io_toposort | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.replace import clone_replace | ||
from pytensor.graph.rewriting.basic import NodeRewriter, WalkingGraphRewriter | ||
from pytensor.scan.op import Scan | ||
|
||
|
||
class WalkingNestedGraphRewriter(WalkingGraphRewriter): | ||
def process_node( | ||
self, | ||
fgraph: FunctionGraph, | ||
node: Apply, | ||
node_rewriter: Optional[NodeRewriter] = None, | ||
): | ||
self.node_rewriter = node_rewriter or self.node_rewriter | ||
if isinstance(node.op, (Scan, OpFromGraph)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should just check |
||
return self.process_scan_node(fgraph, node) | ||
else: | ||
return super().process_node(fgraph, node) | ||
|
||
def process_scan_node(self, fgraph: FunctionGraph, node: Apply): | ||
try: | ||
replacements = self.transform_scan_node(fgraph, node) | ||
except Exception as e: | ||
if self.failure_callback is not None: | ||
self.failure_callback( | ||
e, | ||
self, | ||
[(x, None) for x in node.outputs], | ||
self.node_rewriter, # type: ignore | ||
node, | ||
) | ||
return False | ||
else: | ||
raise | ||
if replacements is False or replacements is None: | ||
return False | ||
|
||
repl_pairs = zip(node.outputs, replacements) | ||
try: | ||
fgraph.replace_all_validate_remove( # type: ignore | ||
repl_pairs, | ||
reason=self.node_rewriter, | ||
remove=[], | ||
) | ||
return True | ||
except Exception as e: | ||
# This means the replacements were rejected by the fgraph. | ||
# | ||
# This is not supposed to happen. The default failure_callback | ||
# will print a traceback as a warning. | ||
if self.failure_callback is not None: | ||
self.failure_callback( | ||
e, | ||
self, | ||
repl_pairs, # type: ignore | ||
self.node_rewriter, # type: ignore | ||
node, | ||
) | ||
return False | ||
else: | ||
raise | ||
|
||
def transform_scan_node(self, fgraph, node): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should try a simpler approach. A WalkingGraphRewriter already has all the logic to apply rewrites over a graph. Can we just use it in the inner graph of the Op? Perhaps inside if isinstance(node.op, HasInnerGraphOp):
self.apply(node.op.fgraph) |
||
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs | ||
local_fgraph_topo = io_toposort(node_inputs, node_outputs) | ||
op = node.op | ||
|
||
givens = dict() | ||
to_remove_set = set() | ||
for nd in local_fgraph_topo: | ||
if nd not in to_remove_set: | ||
if isinstance(nd.op, (Scan, OpFromGraph)): | ||
[new_node] = self.transform_scan_node(node.op.fgraph, nd) | ||
if new_node is not None: | ||
givens.update(zip(nd.outputs, new_node.owner.outputs)) | ||
to_remove_set.add(nd) | ||
else: | ||
replacements = self.node_rewriter.transform(node.op.fgraph, nd) | ||
if replacements is False or replacements is None: | ||
pass | ||
elif not isinstance(replacements, (tuple, list, dict)): | ||
raise TypeError( | ||
f"Node rewriter {self.node_rewriter} gave wrong type of replacement. " | ||
f"Expected list, tuple or dict; got {replacements}" | ||
) | ||
elif isinstance(replacements, (list, tuple)): | ||
if len(nd.outputs) != len(replacements): | ||
raise ValueError( | ||
f"Node rewriter {self.node_rewriter} gave wrong number of replacements" | ||
) | ||
givens.update(zip(nd.outputs, replacements)) | ||
to_remove_set.add(nd) | ||
elif isinstance(replacements, dict): | ||
to_remove_set.add(nd) | ||
for key, value in replacements.items(): | ||
if key == "remove": | ||
for item in value: | ||
givens[item] = None | ||
else: | ||
givens[key] = value | ||
|
||
if len(to_remove_set) == 0: | ||
return None | ||
op_outs = clone_replace(node_outputs, replace=givens) | ||
if isinstance(op, Scan): | ||
nwScan = Scan( | ||
node_inputs, | ||
op_outs, | ||
op.info, | ||
mode=op.mode, | ||
profile=op.profile, | ||
truncate_gradient=op.truncate_gradient, | ||
name=op.name, | ||
allow_gc=op.allow_gc, | ||
) | ||
nw_node = nwScan(*(node.inputs), return_list=True) | ||
|
||
else: | ||
nwOpFromGraph = OpFromGraph( | ||
node_inputs, | ||
op_outs, | ||
op.is_inline, | ||
op.lop_overrides, | ||
op.grad_overrides, | ||
op.rop_overrides, | ||
connection_pattern=op._connection_pattern, | ||
name=op.name, | ||
**op.kwargs, | ||
) | ||
nw_node = nwOpFromGraph(*(node.inputs), return_list=True) | ||
return nw_node |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.rewriting.basic import NodeRewriter | ||
from pytensor.graph.rewriting.nested import WalkingNestedGraphRewriter | ||
from pytensor.scan import scan | ||
from pytensor.tensor.elemwise import Elemwise | ||
from pytensor.tensor.math import sum | ||
from pytensor.tensor.type import matrix, scalar, vector | ||
|
||
|
||
class TestWalkingNestedGraphRewriter: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should do a more graph based test, and make sure the graph changed as we expected, and not just a counter variable that could be updated even without the changes working correctly |
||
def apply_rewrites(self, fgraph): | ||
cnt = 0 | ||
|
||
class MyRwriter(NodeRewriter): | ||
def transform(self, fgraph, node): | ||
nonlocal cnt | ||
if isinstance(node.op, Elemwise): | ||
cnt += 1 | ||
return node.outputs | ||
|
||
node_rewriter = MyRwriter() | ||
rewriter = WalkingNestedGraphRewriter(node_rewriter) | ||
rewriter.apply(fgraph) | ||
return cnt | ||
|
||
def test_rewrite_in_scan(self): | ||
def scan_step(x_0): | ||
x = x_0 + 1 | ||
return x | ||
|
||
x_0 = vector("x_0") | ||
result, _ = scan( | ||
scan_step, | ||
outputs_info=None, | ||
sequences=x_0, | ||
) | ||
x = sum(result) + 1 | ||
graph = FunctionGraph([x_0], [x], clone=False) | ||
|
||
rewrites_cnt = self.apply_rewrites(graph) | ||
|
||
# one replacemnt in the scan inner grap and one in outer graph | ||
assert rewrites_cnt == 2 | ||
|
||
def test_rewrite_in_nested_scan(self): | ||
def inner_scan_step(x_0): | ||
x = x_0 + 1 | ||
return x | ||
|
||
def outer_scan_step(x_0): | ||
x, _ = scan( | ||
fn=inner_scan_step, | ||
sequences=x_0, | ||
outputs_info=None, | ||
) | ||
x = x + 1 | ||
return x | ||
|
||
x_0 = matrix("x_0") | ||
result, _ = scan( | ||
fn=outer_scan_step, | ||
sequences=x_0, | ||
outputs_info=None, | ||
) | ||
|
||
graph = FunctionGraph([x_0], [result], clone=False) | ||
rewrites_cnt = self.apply_rewrites(graph) | ||
# one replacemnt in the inner scan and one in outer scan | ||
assert rewrites_cnt == 2 | ||
|
||
def test_rewrite_op_from_graph(self): | ||
x, y, z = scalar("x"), scalar("y"), scalar("z") | ||
e = x + y * z | ||
op = OpFromGraph([x, y, z], [e]) | ||
e2 = op(x, y, z) + op(z, y, x) | ||
graph = FunctionGraph([x, y, z], [e2], clone=False) | ||
|
||
rewrites_cnt = self.apply_rewrites(graph) | ||
# two rewrites in each OpFromGraph inner graphs and one in outer graph | ||
assert rewrites_cnt == 5 | ||
|
||
def test_rewrite_nested_op_from_graph(self): | ||
x, y, z = scalar("x"), scalar("y"), scalar("z") | ||
e = x + y | ||
op = OpFromGraph([x, y], [e]) | ||
e2 = op(x, y) * op(x, y) | ||
op2 = OpFromGraph([x, y], [e2]) | ||
e3 = op2(x, y) + z | ||
graph = FunctionGraph([x, y, z], [e3], clone=False) | ||
|
||
rewrites_cnt = self.apply_rewrites(graph) | ||
# two rewrites in inner most OpFromGraph, one in second OpFromGraph, and one in outer graph | ||
assert rewrites_cnt == 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change was for debugging I assume?