From 2c0ea95617f750ffe552da3ab3e475760814dda1 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Wed, 6 Dec 2023 14:45:14 +0300 Subject: [PATCH 1/2] add initial implementation of node rewriter --- pytensor/graph/rewriting/basic.py | 117 ++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 4d4cfe02f5..a471fb8444 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast import pytensor +from pytensor.compile.builders import OpFromGraph from pytensor.configdefaults import config from pytensor.graph import destroyhandler as dh from pytensor.graph.basic import ( @@ -32,8 +33,10 @@ from pytensor.graph.features import AlreadyThere, Feature, NodeFinder from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.replace import clone_replace from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet +from pytensor.scan import Scan from pytensor.utils import flatten @@ -2141,6 +2144,120 @@ def walking_rewriter( out2in = partial(walking_rewriter, "out_to_in") +class WalkingNestedGraphRewriter(WalkingGraphRewriter): + def process_node(self, fgraph: FunctionGraph, node: Apply): + if isinstance(node.op, (Scan, OpFromGraph)): + return self.process_scan_node(fgraph, node) + else: + return super().process_node(fgraph, node) + + def process_scan_node(self, fgraph: FunctionGraph, node: Apply): + try: + replacements = self.transform_scan_node(fgraph, node) + except Exception as e: + if self.failure_callback is not None: + self.failure_callback( + e, self, [(x, None) for x in node.outputs], node_rewriter, node + ) + return False + else: + raise + if replacements is False or replacements is None: + return False + + repl_pairs = zip(node.outputs, replacements) + try: + fgraph.replace_all_validate_remove( # type: ignore + repl_pairs, + reason=node_rewriter, + ) + return True + except Exception as e: + # This means the replacements were rejected by the fgraph. + # + # This is not supposed to happen. The default failure_callback + # will print a traceback as a warning. + if self.failure_callback is not None: + self.failure_callback(e, self, repl_pairs, node_rewriter, node) + return False + else: + raise + + def transform_scan_node( + self, + fgraph: FunctionGraph, + node: Union[Scan, OpFromGraph], + ) -> Union[Scan, OpFromGraph, None]: + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + op = node.op + + givens = dict() + to_remove_set = set() + for nd in local_fgraph_topo: + if nd not in to_remove_set: + if isinstance(nd.op, (Scan, OpFromGraph)): + new_node = self.transform_scan_node(node.op.inner_graph, nd) + if new_node is not None: + givens.update(zip(nd.outputs, new_node.owner.outputs)) + to_remove_set.add(nd) + else: + replacements = self.transform(node.op.inner_graph, nd) + if replacements is False or replacements is None: + pass + elif not isinstance(replacements, (tuple, list, dict)): + raise TypeError( + f"Node rewriter {node_rewriter} gave wrong type of replacement. " + f"Expected list, tuple or dict; got {replacements}" + ) + elif isinstance(replacements, (list, tuple)): + if len(nd.outputs) != len(replacements): + raise ValueError( + f"Node rewriter {node_rewriter} gave wrong number of replacements" + ) + givens.update(zip(nd.outputs, replacements)) + to_remove_set.add(nd) + elif isinstance(replacements, dict): + to_remove_set.add(nd) + for key, value in replacements.items(): + if key == "remove": + for item in value: + givens[item] = None + else: + givens[key] = value + + if len(to_remove_set) == 0: + return None + op_outs = clone_replace(node_outputs, replace=givens) + if isinstance(op, Scan): + nwScan = Scan( + node_inputs, + op_outs, + op.info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + nw_node = nwScan(*(node.inputs), return_list=True) + + else: + nwOpFromGraph = OpFromGraph( + node_inputs, + op_outs, + op.is_inline, + op.lop_overrides, + op.grad_overrides, + op.rop_overrides, + connection_pattern=op._connection_pattern, + name=op.name, + **op.kwargs, + ) + nw_node = nwOpFromGraph(*(node.inputs), return_list=True) + return nw_node + + class OpKeyGraphRewriter(NodeProcessingGraphRewriter): r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s. From 58f66ef0d63f1a08da8af21d4a30a97197252d7a Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Thu, 14 Dec 2023 19:47:59 +0300 Subject: [PATCH 2/2] move implementation in separate module and add tests --- pytensor/graph/rewriting/basic.py | 122 +----------------------- pytensor/graph/rewriting/nested.py | 135 +++++++++++++++++++++++++++ tests/graph/rewriting/test_nested.py | 94 +++++++++++++++++++ 3 files changed, 232 insertions(+), 119 deletions(-) create mode 100644 pytensor/graph/rewriting/nested.py create mode 100644 tests/graph/rewriting/test_nested.py diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index a471fb8444..26c04ca0a7 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast import pytensor -from pytensor.compile.builders import OpFromGraph from pytensor.configdefaults import config from pytensor.graph import destroyhandler as dh from pytensor.graph.basic import ( @@ -32,11 +31,9 @@ ) from pytensor.graph.features import AlreadyThere, Feature, NodeFinder from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import Op -from pytensor.graph.replace import clone_replace +from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet -from pytensor.scan import Scan from pytensor.utils import flatten @@ -2023,7 +2020,7 @@ def apply(self, fgraph, start_from=None): io_t = time.perf_counter() - t0 def importer(node): - if node is not current_node: + if node is not current_node and not isinstance(node.op, HasInnerGraph): q.append(node) u = self.attach_updater( @@ -2033,6 +2030,7 @@ def importer(node): try: t0 = time.perf_counter() while q: + # breakpoint if self.order == "out_to_in": node = q.pop() else: @@ -2144,120 +2142,6 @@ def walking_rewriter( out2in = partial(walking_rewriter, "out_to_in") -class WalkingNestedGraphRewriter(WalkingGraphRewriter): - def process_node(self, fgraph: FunctionGraph, node: Apply): - if isinstance(node.op, (Scan, OpFromGraph)): - return self.process_scan_node(fgraph, node) - else: - return super().process_node(fgraph, node) - - def process_scan_node(self, fgraph: FunctionGraph, node: Apply): - try: - replacements = self.transform_scan_node(fgraph, node) - except Exception as e: - if self.failure_callback is not None: - self.failure_callback( - e, self, [(x, None) for x in node.outputs], node_rewriter, node - ) - return False - else: - raise - if replacements is False or replacements is None: - return False - - repl_pairs = zip(node.outputs, replacements) - try: - fgraph.replace_all_validate_remove( # type: ignore - repl_pairs, - reason=node_rewriter, - ) - return True - except Exception as e: - # This means the replacements were rejected by the fgraph. - # - # This is not supposed to happen. The default failure_callback - # will print a traceback as a warning. - if self.failure_callback is not None: - self.failure_callback(e, self, repl_pairs, node_rewriter, node) - return False - else: - raise - - def transform_scan_node( - self, - fgraph: FunctionGraph, - node: Union[Scan, OpFromGraph], - ) -> Union[Scan, OpFromGraph, None]: - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - local_fgraph_topo = io_toposort(node_inputs, node_outputs) - op = node.op - - givens = dict() - to_remove_set = set() - for nd in local_fgraph_topo: - if nd not in to_remove_set: - if isinstance(nd.op, (Scan, OpFromGraph)): - new_node = self.transform_scan_node(node.op.inner_graph, nd) - if new_node is not None: - givens.update(zip(nd.outputs, new_node.owner.outputs)) - to_remove_set.add(nd) - else: - replacements = self.transform(node.op.inner_graph, nd) - if replacements is False or replacements is None: - pass - elif not isinstance(replacements, (tuple, list, dict)): - raise TypeError( - f"Node rewriter {node_rewriter} gave wrong type of replacement. " - f"Expected list, tuple or dict; got {replacements}" - ) - elif isinstance(replacements, (list, tuple)): - if len(nd.outputs) != len(replacements): - raise ValueError( - f"Node rewriter {node_rewriter} gave wrong number of replacements" - ) - givens.update(zip(nd.outputs, replacements)) - to_remove_set.add(nd) - elif isinstance(replacements, dict): - to_remove_set.add(nd) - for key, value in replacements.items(): - if key == "remove": - for item in value: - givens[item] = None - else: - givens[key] = value - - if len(to_remove_set) == 0: - return None - op_outs = clone_replace(node_outputs, replace=givens) - if isinstance(op, Scan): - nwScan = Scan( - node_inputs, - op_outs, - op.info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - name=op.name, - allow_gc=op.allow_gc, - ) - nw_node = nwScan(*(node.inputs), return_list=True) - - else: - nwOpFromGraph = OpFromGraph( - node_inputs, - op_outs, - op.is_inline, - op.lop_overrides, - op.grad_overrides, - op.rop_overrides, - connection_pattern=op._connection_pattern, - name=op.name, - **op.kwargs, - ) - nw_node = nwOpFromGraph(*(node.inputs), return_list=True) - return nw_node - - class OpKeyGraphRewriter(NodeProcessingGraphRewriter): r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s. diff --git a/pytensor/graph/rewriting/nested.py b/pytensor/graph/rewriting/nested.py new file mode 100644 index 0000000000..66e5565e56 --- /dev/null +++ b/pytensor/graph/rewriting/nested.py @@ -0,0 +1,135 @@ +from typing import Optional + +from pytensor.compile.builders import OpFromGraph +from pytensor.graph.basic import Apply, io_toposort +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import clone_replace +from pytensor.graph.rewriting.basic import NodeRewriter, WalkingGraphRewriter +from pytensor.scan.op import Scan + + +class WalkingNestedGraphRewriter(WalkingGraphRewriter): + def process_node( + self, + fgraph: FunctionGraph, + node: Apply, + node_rewriter: Optional[NodeRewriter] = None, + ): + self.node_rewriter = node_rewriter or self.node_rewriter + if isinstance(node.op, (Scan, OpFromGraph)): + return self.process_scan_node(fgraph, node) + else: + return super().process_node(fgraph, node) + + def process_scan_node(self, fgraph: FunctionGraph, node: Apply): + try: + replacements = self.transform_scan_node(fgraph, node) + except Exception as e: + if self.failure_callback is not None: + self.failure_callback( + e, + self, + [(x, None) for x in node.outputs], + self.node_rewriter, # type: ignore + node, + ) + return False + else: + raise + if replacements is False or replacements is None: + return False + + repl_pairs = zip(node.outputs, replacements) + try: + fgraph.replace_all_validate_remove( # type: ignore + repl_pairs, + reason=self.node_rewriter, + remove=[], + ) + return True + except Exception as e: + # This means the replacements were rejected by the fgraph. + # + # This is not supposed to happen. The default failure_callback + # will print a traceback as a warning. + if self.failure_callback is not None: + self.failure_callback( + e, + self, + repl_pairs, # type: ignore + self.node_rewriter, # type: ignore + node, + ) + return False + else: + raise + + def transform_scan_node(self, fgraph, node): + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + op = node.op + + givens = dict() + to_remove_set = set() + for nd in local_fgraph_topo: + if nd not in to_remove_set: + if isinstance(nd.op, (Scan, OpFromGraph)): + [new_node] = self.transform_scan_node(node.op.fgraph, nd) + if new_node is not None: + givens.update(zip(nd.outputs, new_node.owner.outputs)) + to_remove_set.add(nd) + else: + replacements = self.node_rewriter.transform(node.op.fgraph, nd) + if replacements is False or replacements is None: + pass + elif not isinstance(replacements, (tuple, list, dict)): + raise TypeError( + f"Node rewriter {self.node_rewriter} gave wrong type of replacement. " + f"Expected list, tuple or dict; got {replacements}" + ) + elif isinstance(replacements, (list, tuple)): + if len(nd.outputs) != len(replacements): + raise ValueError( + f"Node rewriter {self.node_rewriter} gave wrong number of replacements" + ) + givens.update(zip(nd.outputs, replacements)) + to_remove_set.add(nd) + elif isinstance(replacements, dict): + to_remove_set.add(nd) + for key, value in replacements.items(): + if key == "remove": + for item in value: + givens[item] = None + else: + givens[key] = value + + if len(to_remove_set) == 0: + return None + op_outs = clone_replace(node_outputs, replace=givens) + if isinstance(op, Scan): + nwScan = Scan( + node_inputs, + op_outs, + op.info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + nw_node = nwScan(*(node.inputs), return_list=True) + + else: + nwOpFromGraph = OpFromGraph( + node_inputs, + op_outs, + op.is_inline, + op.lop_overrides, + op.grad_overrides, + op.rop_overrides, + connection_pattern=op._connection_pattern, + name=op.name, + **op.kwargs, + ) + nw_node = nwOpFromGraph(*(node.inputs), return_list=True) + return nw_node diff --git a/tests/graph/rewriting/test_nested.py b/tests/graph/rewriting/test_nested.py new file mode 100644 index 0000000000..4a68ad61e7 --- /dev/null +++ b/tests/graph/rewriting/test_nested.py @@ -0,0 +1,94 @@ +from pytensor.compile.builders import OpFromGraph +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.nested import WalkingNestedGraphRewriter +from pytensor.scan import scan +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import sum +from pytensor.tensor.type import matrix, scalar, vector + + +class TestWalkingNestedGraphRewriter: + def apply_rewrites(self, fgraph): + cnt = 0 + + class MyRwriter(NodeRewriter): + def transform(self, fgraph, node): + nonlocal cnt + if isinstance(node.op, Elemwise): + cnt += 1 + return node.outputs + + node_rewriter = MyRwriter() + rewriter = WalkingNestedGraphRewriter(node_rewriter) + rewriter.apply(fgraph) + return cnt + + def test_rewrite_in_scan(self): + def scan_step(x_0): + x = x_0 + 1 + return x + + x_0 = vector("x_0") + result, _ = scan( + scan_step, + outputs_info=None, + sequences=x_0, + ) + x = sum(result) + 1 + graph = FunctionGraph([x_0], [x], clone=False) + + rewrites_cnt = self.apply_rewrites(graph) + + # one replacemnt in the scan inner grap and one in outer graph + assert rewrites_cnt == 2 + + def test_rewrite_in_nested_scan(self): + def inner_scan_step(x_0): + x = x_0 + 1 + return x + + def outer_scan_step(x_0): + x, _ = scan( + fn=inner_scan_step, + sequences=x_0, + outputs_info=None, + ) + x = x + 1 + return x + + x_0 = matrix("x_0") + result, _ = scan( + fn=outer_scan_step, + sequences=x_0, + outputs_info=None, + ) + + graph = FunctionGraph([x_0], [result], clone=False) + rewrites_cnt = self.apply_rewrites(graph) + # one replacemnt in the inner scan and one in outer scan + assert rewrites_cnt == 2 + + def test_rewrite_op_from_graph(self): + x, y, z = scalar("x"), scalar("y"), scalar("z") + e = x + y * z + op = OpFromGraph([x, y, z], [e]) + e2 = op(x, y, z) + op(z, y, x) + graph = FunctionGraph([x, y, z], [e2], clone=False) + + rewrites_cnt = self.apply_rewrites(graph) + # two rewrites in each OpFromGraph inner graphs and one in outer graph + assert rewrites_cnt == 5 + + def test_rewrite_nested_op_from_graph(self): + x, y, z = scalar("x"), scalar("y"), scalar("z") + e = x + y + op = OpFromGraph([x, y], [e]) + e2 = op(x, y) * op(x, y) + op2 = OpFromGraph([x, y], [e2]) + e3 = op2(x, y) + z + graph = FunctionGraph([x, y, z], [e3], clone=False) + + rewrites_cnt = self.apply_rewrites(graph) + # two rewrites in inner most OpFromGraph, one in second OpFromGraph, and one in outer graph + assert rewrites_cnt == 4