-
Notifications
You must be signed in to change notification settings - Fork 130
Walking nested graph rewriter #556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Initially I placed |
Also this changes should allow us to simplify implementation of default moment for |
@ricardoV94 could you please take a look on this PR when you will have some time? |
@ricardoV94 just friendly reminder about PR ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@aerubanov sorry for the delay. I left some comments. I think we should consider make the class simpler and more general if we can.
The gist of it is we have a GraphRewriter class that applies a NodeRewriter over nodes of an fgraph, and sometimes some nodes have an internal fgraph of their own. Can we just reuse the same class to work on this internal fgraphs?
The main concern is that by default FunctionGraph
does not clone inner graphs, because rewrites until now were not usually mutating these inplace, but recreating Ops with new inner graphs.
The largest part of the new code avoids this by pretending to apply the rewrites, keeping the results and then applying them all at once. But this is problematic, because some NodeRewriter make decisions based on the current state of the fgraph, but the "current state" may be stale in the sense that we have already decided to apply another rewrite, just not done it yet. Had the graph been changed already, the NodeRewriter might not have applied itself.
A not very efficient solution would be to clone inner fgraph when we find them, apply the GraphRewriter, and if it changed, then recreate the new Op with the new inner fgraph. Maybe we can force all Ops that inherit from HasInnerGraphOp to implement a recreate_with_new_fgraph
that accepts a new fgraph and recreates itself accordingly?
This would also fix the scope issue here, where our general rewrite has to know what a Scan or OpFromGraph are to be able to support them.
@@ -2030,6 +2030,7 @@ def importer(node): | |||
try: | |||
t0 = time.perf_counter() | |||
while q: | |||
# breakpoint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# breakpoint |
@@ -2020,7 +2020,7 @@ def apply(self, fgraph, start_from=None): | |||
io_t = time.perf_counter() - t0 | |||
|
|||
def importer(node): | |||
if node is not current_node: | |||
if node is not current_node and not isinstance(node.op, HasInnerGraph): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change was for debugging I assume?
node_rewriter: Optional[NodeRewriter] = None, | ||
): | ||
self.node_rewriter = node_rewriter or self.node_rewriter | ||
if isinstance(node.op, (Scan, OpFromGraph)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should just check HasInnerGraphOp
, and not restrict the kind of Op. A rewrite at the PyTensor level should be more general than what we were working with in PyMC
else: | ||
raise | ||
|
||
def transform_scan_node(self, fgraph, node): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should try a simpler approach. A WalkingGraphRewriter already has all the logic to apply rewrites over a graph. Can we just use it in the inner graph of the Op?
Perhaps inside process_node
we have something like:
if isinstance(node.op, HasInnerGraphOp):
self.apply(node.op.fgraph)
from pytensor.tensor.type import matrix, scalar, vector | ||
|
||
|
||
class TestWalkingNestedGraphRewriter: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should do a more graph based test, and make sure the graph changed as we expected, and not just a counter variable that could be updated even without the changes working correctly
Implementation of
WalkingNestedGraphRewriter
which can apply node_rewriter for nodes of nested graphs (Scan
andOpFromGraph
). Close #529.