Add rewrite for Sum(MakeVector) #346
Conversation
This could close #59 if it applies to Join as well, and to all CAReduce Ops (not just Sum).
The more general CAReduce case does look a bit trickier than I thought, mostly because it can be pretty unclear what dtypes things should have.
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #346      +/-   ##
==========================================
+ Coverage   80.40%   80.44%   +0.03%
==========================================
  Files         156      156
  Lines       45401    45481      +80
  Branches    11106    11139      +33
==========================================
+ Hits        36505    36587      +82
+ Misses       6689     6687       -2
  Partials     2207     2207
Hmm, can't you infer it from the node output type? For multiple inputs you could chain with Python's . If the end result doesn't have the same dtype, you can also abort at the end of the rewrite.
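A rough sketch of that "build it, then abort if the dtype is off" idea for a hypothetical CAReduce(MakeVector) rewrite (the helper name and structure below are illustrative, not code from this PR):

```python
# Hypothetical sketch: build the candidate replacement first, then bail out
# if its dtype differs from the dtype the node's output advertises.
def _finish_rewrite(node, candidate):
    if candidate.dtype != node.outputs[0].dtype:
        return None  # abort the rewrite and keep the original graph
    return [candidate]
```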
I guess that's doable, but could we leave that to another PR if someone has time? I think the Sum case is by far the most common one.
Can definitely be done in a separate PR. I think
pytensor/tensor/rewriting/basic.py (Outdated)

    return [array]

    # If this is not the case the sum is invalid
    assert node.op.axis is None or node.op.axis == (0,)
Is axis=(-1,) invalid?
fixed
pytensor/tensor/rewriting/basic.py (Outdated)

    element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)
Does add work with 0 or 1 inputs? I know MakeVector allows that.
I think this will work with one input, but fail with zero.
True. I added a special case.
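For reference, a minimal sketch of what such a special case could look like, reusing the names from the snippet above (the helper itself is illustrative, not the PR's code):

```python
import numpy as np
import pytensor.tensor as pt

def _sum_of_vector_elements(elements, acc_dtype, out_dtype):
    # Illustrative helper: an empty MakeVector sums to a 0-d zero of the
    # output dtype, so pt.add is never called with zero inputs.
    if len(elements) == 0:
        return pt.constant(np.zeros((), dtype=out_dtype))
    acc = pt.add(*[pt.cast(v, acc_dtype) for v in elements])
    return pt.cast(acc, out_dtype)
```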
Some of these issues about ... I wonder if all these knobs are needed in practice; it makes it really cumbersome to work with.
My current understanding (from the docstrings, though looking at the perform / c_code is probably wiser) is that
Maybe this is like a mini reduce+elemwise optimization concern? When you don't reduce all axes, you would need an Elemwise cast afterwards (so another loop). If that's the only reason, it means we could simplify the Op once we implement #224.
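To make the knobs concrete, here is a small example of a partial reduction with distinct accumulation and output dtypes (assuming, per the docstrings, that pt.sum forwards dtype and acc_dtype to the underlying reduction):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x", dtype="float32")

# Reduce only one axis while accumulating in float64 but returning float32.
# Per the comment above, the cast back to float32 is handled inside the
# reduction Op rather than by a separate Elemwise afterwards.
s = pt.sum(x, axis=0, acc_dtype="float64", dtype="float32")
pytensor.dprint(s)
```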
Force-pushed from 4b47472 to 9f6f048
Force-pushed from 9f6f048 to 7429e94
I've seen graphs like this quite a bit recently, and we can avoid the allocation by removing the MakeVector call.
This now gets rewritten to a direct sum of the vector's elements (with casts where needed); see the sketch below.
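For illustration, a small reconstruction of the pattern (my own toy example, not the exact graph from the PR description):

```python
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.scalar("y")
z = pt.scalar("z")

# Building a vector only to immediately reduce it: MakeVector allocates a
# length-3 array that Sum then collapses again.
out = pt.stack([x, y, z]).sum()

# With the rewrite, the compiled graph should boil down to x + y + z
# (plus casts where needed), with no intermediate vector allocation.
fn = pytensor.function([x, y, z], out)
pytensor.dprint(fn)
```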
In the same models I've also come across cases like this that we could also rewrite, but that is not included here (there is already a rewrite for Subtensor(MakeVector), but not for IncSubtensor(MakeVector)).
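As an illustration of that second pattern (again a toy example of my own, not the graph from those models):

```python
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.scalar("y")

# IncSubtensor(MakeVector): build a vector, then overwrite one entry.
# In principle this could be rewritten to a MakeVector of the final entries
# directly, analogous to the existing Subtensor(MakeVector) rewrite.
vec = pt.stack([x, y])
out = pt.set_subtensor(vec[0], 5.0)
pytensor.dprint(out)
```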