Walking nested graph rewriter #556

Open: wants to merge 2 commits into base branch main

aerubanov
Contributor

Implements WalkingNestedGraphRewriter, which applies a node_rewriter to the nodes of nested graphs (the inner graphs of Scan and OpFromGraph). Closes #529.

@aerubanov
Contributor Author

Initially I placed WalkingNestedGraphRewriter in pytensor/graph/rewriting/basic.py, but that caused a cyclic import through OpFromGraph, so I moved it into a separate module. I am also not sure about the tests; maybe we need to add more scenarios or use a different approach.

@aerubanov
Contributor Author

These changes should also allow us to simplify the implementation of the default moment for CustomSymbolicDist from pymc-devs/pymc#6873.

@aerubanov
Contributor Author

@ricardoV94 could you please take a look at this PR when you have some time?

@aerubanov
Contributor Author

@ricardoV94 just a friendly reminder about this PR )

Member

@ricardoV94 left a comment

@aerubanov sorry for the delay. I left some comments. I think we should consider making the class simpler and more general if we can.

The gist of it is that we have a GraphRewriter class that applies a NodeRewriter over the nodes of an fgraph, and sometimes nodes have an internal fgraph of their own. Can we just reuse the same class to work on these internal fgraphs?

The main concern is that by default FunctionGraph does not clone inner graphs, because until now rewrites did not usually mutate these in place, but instead recreated Ops with new inner graphs.

The largest part of the new code avoids this by pretending to apply the rewrites, keeping the results, and then applying them all at once. But this is problematic, because some NodeRewriters make decisions based on the current state of the fgraph, and that "current state" may be stale: we may have already decided to apply another rewrite, just not done it yet. Had the graph already been changed, the NodeRewriter might not have applied itself.
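
To make the staleness concern concrete, here is a minimal pure-Python sketch. It is a toy model and not PyTensor code: the graph is a plain list of op names, and `remove_if_duplicated`, `rewrite_immediately`, and `rewrite_deferred` are hypothetical names invented for this illustration. The toy rewriter drops a node only if an equal node exists elsewhere in the graph, so its decision depends on the graph's current state.

```python
def remove_if_duplicated(graph, node):
    """Toy node rewriter: say "remove" when an equal node exists elsewhere in `graph`."""
    return sum(1 for other in graph if other == node) > 1


def rewrite_immediately(graph):
    """Apply each decision right away, so later decisions see the updated graph."""
    graph = list(graph)  # work on a copy, leave the caller's list untouched
    i = 0
    while i < len(graph):
        if remove_if_duplicated(graph, graph[i]):
            del graph[i]  # the graph changes before the next node is visited
        else:
            i += 1
    return graph


def rewrite_deferred(graph):
    """Collect all decisions against a frozen snapshot, then apply them at once."""
    snapshot = list(graph)
    to_remove = {
        i for i, node in enumerate(snapshot) if remove_if_duplicated(snapshot, node)
    }
    return [node for i, node in enumerate(snapshot) if i not in to_remove]


toy_graph = ["x", "exp", "exp", "y"]
immediate_result = rewrite_immediately(toy_graph)  # keeps one "exp"
deferred_result = rewrite_deferred(toy_graph)      # drops both "exp" nodes
```

In the deferred variant both `exp` nodes see each other in the stale snapshot, so both are scheduled for removal and the computation is lost. With immediate application the second `exp` no longer has a duplicate and is kept.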

A not very efficient solution would be to clone each inner fgraph when we find it, apply the GraphRewriter, and, if it changed, recreate the Op with the new inner fgraph. Maybe we can force all Ops that inherit from HasInnerGraph to implement a recreate_with_new_fgraph method that accepts a new fgraph and recreates the Op accordingly?

This would also fix the scope issue here, where our general rewrite has to know what a Scan or OpFromGraph is in order to support them.
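
The clone, rewrite, and recreate flow proposed above could look roughly like the following pure-Python sketch. Everything here is an assumption made for illustration: `HasInnerGraphSketch`, `ToyScan`, `rewrite_inner_graph`, and `fuse_double_neg` are toy stand-ins, the inner graph is a tuple of op names, and `recreate_with_new_fgraph` is only the method name floated in this comment, not an existing PyTensor API.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class HasInnerGraphSketch(ABC):
    """Toy stand-in for an Op base class that owns an inner graph."""

    inner: tuple

    @abstractmethod
    def recreate_with_new_fgraph(self, new_inner):
        """Return a new Op of the same kind built around `new_inner`."""


@dataclass(frozen=True)
class ToyScan(HasInnerGraphSketch):
    inner: tuple
    n_steps: int

    def recreate_with_new_fgraph(self, new_inner):
        # Only the Op itself knows which extra properties must be carried over.
        return ToyScan(inner=tuple(new_inner), n_steps=self.n_steps)


def fuse_double_neg(graph):
    """Toy graph rewriter: drop adjacent ("neg", "neg") pairs in place.

    Returns True if the graph was changed.
    """
    changed = False
    i = 0
    while i + 1 < len(graph):
        if graph[i] == "neg" and graph[i + 1] == "neg":
            del graph[i : i + 2]
            changed = True
        else:
            i += 1
    return changed


def rewrite_inner_graph(op, graph_rewriter):
    """Clone the inner graph, rewrite the clone, recreate the Op only on change."""
    if not isinstance(op, HasInnerGraphSketch):
        return op  # nothing nested to rewrite
    working_copy = list(op.inner)  # the clone; the original Op is never mutated
    changed = graph_rewriter(working_copy)
    return op.recreate_with_new_fgraph(working_copy) if changed else op


old_op = ToyScan(inner=("neg", "neg", "exp"), n_steps=3)
new_op = rewrite_inner_graph(old_op, fuse_double_neg)
```

Note that `rewrite_inner_graph` only checks the abstract base class, so in this sketch the general rewriter never needs to name a concrete Op type.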

@@ -2030,6 +2030,7 @@ def importer(node):
try:
t0 = time.perf_counter()
while q:
+ # breakpoint
Member

Suggested change
# breakpoint

@@ -2020,7 +2020,7 @@ def apply(self, fgraph, start_from=None):
io_t = time.perf_counter() - t0

def importer(node):
- if node is not current_node:
+ if node is not current_node and not isinstance(node.op, HasInnerGraph):
Member

This change was for debugging, I assume?

node_rewriter: Optional[NodeRewriter] = None,
):
self.node_rewriter = node_rewriter or self.node_rewriter
if isinstance(node.op, (Scan, OpFromGraph)):
Member

We should just check for HasInnerGraph and not restrict the kind of Op. A rewrite at the PyTensor level should be more general than what we were working with in PyMC.

else:
raise

def transform_scan_node(self, fgraph, node):
Member

I think we should try a simpler approach. A WalkingGraphRewriter already has all the logic to apply rewrites over a graph. Can we just use it on the inner graph of the Op?

Perhaps inside process_node we could have something like:

if isinstance(node.op, HasInnerGraph):
    self.apply(node.op.fgraph)
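
A slightly fuller sketch of that recursion, as a pure-Python toy model: `ToyNode`, `ToyWalkingRewriter`, and `sqr_to_mul` are hypothetical names, graphs are plain lists, and the real WalkingGraphRewriter internals are not reproduced here. It only shows how `process_node` could call `apply` on an inner graph before rewriting the node itself.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    op: str
    inner: Optional[List["ToyNode"]] = None  # set only for ops with an inner graph


class ToyWalkingRewriter:
    def __init__(self, node_rewriter):
        self.node_rewriter = node_rewriter

    def apply(self, graph):
        """Rewrite `graph` in place and return the number of replacements made."""
        return sum(self.process_node(graph, i, node) for i, node in enumerate(graph))

    def process_node(self, graph, i, node):
        n_replaced = 0
        if node.inner is not None:
            # Reuse the same walking logic on the nested graph.
            n_replaced += self.apply(node.inner)
        replacement = self.node_rewriter(node)
        if replacement is not None:
            graph[i] = replacement
            n_replaced += 1
        return n_replaced


def sqr_to_mul(node):
    """Toy node rewriter: replace "sqr" with "mul", leave everything else alone."""
    return ToyNode("mul") if node.op == "sqr" else None


outer = [
    ToyNode("sqr"),
    ToyNode("scan", inner=[ToyNode("sqr"), ToyNode("scan", inner=[ToyNode("sqr")])]),
]
n_replaced = ToyWalkingRewriter(sqr_to_mul).apply(outer)
```

This toy version mutates the inner lists in place. In PyTensor that is exactly the point raised in the review summary about FunctionGraph not cloning inner graphs, so a real implementation would presumably need to clone first or recreate the Op.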

from pytensor.tensor.type import matrix, scalar, vector


class TestWalkingNestedGraphRewriter:
Member

We should write a more graph-based test that checks the graph changed as we expected, rather than relying on a counter variable that could be updated even if the changes are not working correctly.
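
The difference between the two test styles can be shown with a small pure-Python sketch. All names here (`broken_rewriter`, `working_rewriter`, `counter_check`, `graph_check`) are hypothetical and the graph is a list of op names. In PyTensor itself the graph-based check would presumably compare the rewritten outputs against an expected graph, for example with `equal_computations`.

```python
def broken_rewriter(graph, stats):
    """Counts every match but forgets to actually replace the node."""
    for node in graph:
        if node == "sqr":
            stats["applied"] += 1
    return list(graph)


def working_rewriter(graph, stats):
    """Counts every match and replaces "sqr" with "mul"."""
    rewritten = []
    for node in graph:
        if node == "sqr":
            stats["applied"] += 1
            rewritten.append("mul")
        else:
            rewritten.append(node)
    return rewritten


def counter_check(rewriter):
    """Counter-style test: only looks at how often the rewriter fired."""
    stats = {"applied": 0}
    rewriter(["x", "sqr"], stats)
    return stats["applied"] == 1


def graph_check(rewriter):
    """Graph-style test: compares the rewritten graph with the expected one."""
    stats = {"applied": 0}
    return rewriter(["x", "sqr"], stats) == ["x", "mul"]
```

Here `counter_check` accepts both rewriters, while `graph_check` accepts only the one that really changes the graph.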

Successfully merging this pull request may close these issues.

Add WalkingNestedGraphRewriter to apply node rewrites to Scans and OpFromGraph inner graphs