
Add rewrite for Sum(MakeVector) #346


Merged
merged 3 commits into main from rewrite-sum-makevector on Jul 12, 2023

Conversation

aseyboldt
Member

I've seen graphs like these quite a bit recently, and we can avoid the allocation by removing the MakeVector call:

Sum{axes=None} [id A]
 └─ MakeVector{dtype='int64'} [id B]
    ├─ Subtensor{i} [id C]
    │  ├─ Shape [id D]
    │  │  └─ x [id E]
    │  └─ 0 [id F]
    └─ Add [id G]
       ├─ Subtensor{i} [id H]
       │  ├─ Shape [id I]
       │  │  └─ x [id E]
       │  └─ 1 [id J]
       └─ 1 [id K]

This now gets rewritten to:

Add [id A] 2
 ├─ 1 [id B]
 ├─ Shape_i{0} [id C] 1
 │  └─ x [id D]
 └─ Shape_i{1} [id E] 0
    └─ x [id D]
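
For reference, a minimal snippet that should produce a graph like the one above (a sketch only: pt.stack lowers scalar inputs to MakeVector, and x is just a placeholder matrix):

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# Stacking scalar shape components builds a MakeVector node
v = pt.stack([x.shape[0], x.shape[1] + 1])
out = v.sum()

# After compilation, the rewrite should replace Sum(MakeVector) with
# the chain of Adds shown above
fn = pytensor.function([x], out)
pytensor.dprint(fn)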

In the same models I've also come across cases like the one below, which we could rewrite as well, but that is not included here (there is already a rewrite for Subtensor(MakeVector), but not for IncSubtensor(MakeVector)):

IncSubtensor{i} [id A] 4
 ├─ MakeVector{dtype='int64'} [id B] 3
 │  ├─ Shape_i{0} [id C] 2
 │  │  └─ x [id D]
 │  └─ Add [id E] 1
 │     ├─ 1 [id F]
 │     └─ Shape_i{1} [id G] 0
 │        └─ x [id D]
 ├─ 2 [id H]
 └─ 0 [id I]
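
A sketch of how such a graph can arise, and what an IncSubtensor(MakeVector) rewrite could do with it (illustrative only, not part of this PR):

import pytensor.tensor as pt

x = pt.matrix("x")
v = pt.stack([x.shape[0], x.shape[1] + 1])
# Increment entry 0 of the stacked vector by 2, as in the graph above
out = pt.inc_subtensor(v[0], 2)

# A rewrite could push the increment into the corresponding MakeVector
# input instead, giving MakeVector(x.shape[0] + 2, x.shape[1] + 1)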

@ricardoV94
Member

This could close #59 if it applies to Join as well, and to all CAReduce Ops (not just Sum).

@aseyboldt
Member Author

The more general CAReduce case does look a bit trickier than I thought, mostly because which dtypes things should have can be pretty unclear: acc_dtype and output_dtype are sometimes None, and on top of that the corresponding scalar Ops only take two inputs. Maybe generalizing those to allow multiple inputs would actually be nice.

@codecov-commenter

codecov-commenter commented Jun 16, 2023

Codecov Report

Merging #346 (7429e94) into main (5c87d74) will increase coverage by 0.03%.
The diff coverage is 93.40%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #346      +/-   ##
==========================================
+ Coverage   80.40%   80.44%   +0.03%     
==========================================
  Files         156      156              
  Lines       45401    45481      +80     
  Branches    11106    11139      +33     
==========================================
+ Hits        36505    36587      +82     
+ Misses       6689     6687       -2     
  Partials     2207     2207              
Impacted Files Coverage Δ
pytensor/scalar/basic.py 80.16% <66.66%> (-0.04%) ⬇️
pytensor/tensor/rewriting/shape.py 81.20% <83.33%> (+0.10%) ⬆️
pytensor/link/jax/dispatch/scalar.py 98.41% <100.00%> (+0.01%) ⬆️
pytensor/scan/rewriting.py 79.94% <100.00%> (ø)
pytensor/tensor/basic.py 90.77% <100.00%> (ø)
pytensor/tensor/rewriting/basic.py 93.52% <100.00%> (+0.61%) ⬆️
pytensor/tensor/rewriting/elemwise.py 88.99% <100.00%> (+0.49%) ⬆️
pytensor/tensor/rewriting/math.py 86.36% <100.00%> (+0.31%) ⬆️
pytensor/tensor/var.py 87.76% <100.00%> (+0.19%) ⬆️

@ricardoV94
Member

ricardoV94 commented Jun 17, 2023

The more general CAReduce case does look a bit trickier than I thought, mostly because which dtypes things should have can be pretty unclear: acc_dtype and output_dtype are sometimes None, and on top of that the corresponding scalar Ops only take two inputs. Maybe generalizing those to allow multiple inputs would actually be nice.

Hmm, can't you infer it from the node output type?

For multiple inputs you could chain them with Python's reduce, no? The initial value (the identity) also takes care of the empty and single-element cases, just like CAReduce does.

If the end result doesn't have the same dtype, you can also abort at the end of the rewrite.
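
Something along these lines (a rough sketch; binary_op and identity are made-up names standing in for the CAReduce's scalar op and its identity element):

from functools import reduce

def reduce_elements(binary_op, identity, elements):
    # Starting from the identity handles the empty and single-element
    # cases the same way CAReduce does (e.g. add/0 for Sum, mul/1 for Prod)
    return reduce(binary_op, elements, identity)

# e.g. reduce_elements(pt.add, pt.constant(0, dtype="int64"), elements)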

@aseyboldt
Member Author

I guess that's doable, but could we leave that to another PR if someone has time? I think the case of Sum is by far the most common one.
It would also require figuring out how acc_dtype is actually interpreted; from reading over it, I'm not sure we always honor it properly.
An alternative to a reduce call might also be to change the scalar functions for And, Or, etc. to accept multiple inputs. I think that would be more consistent, since Add, for instance, already does.

@ricardoV94
Member

ricardoV94 commented Jun 29, 2023

Can definitely be done in a separate PR. I think All may be quite common as well, especially in parameter checks / asserts.

return [array]

# If this is not the case the sum is invalid
assert node.op.axis is None or node.op.axis == (0,)
Member

Is axis=(-1,) invalid?

Member Author

fixed
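
For context, accepting the equivalent negative axis could look roughly like this inside the rewrite (a sketch fragment, not the exact code in the PR):

# The MakeVector output is 1d, so axis=None, (0,) and (-1,) are all
# full reductions (assumed form; the actual fix may differ)
if node.op.axis not in (None, (0,), (-1,)):
    return None  # the rewrite does not apply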

Comment on lines 1002 to 991
element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)

Member

Does add work with 0 / 1 inputs? I know MakeVector allows that

Member

I think this will work with one input, but fail with zero.

Member Author

True. I added a special case
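
Roughly along these lines (a sketch that follows the names in the quoted line; the actual code in the PR may differ):

import pytensor.tensor as pt

def sum_make_vector_elements(elements, acc_dtype, out_dtype):
    # Special-case empty (and single-element) inputs, since add() with
    # no arguments fails, while a single element only needs a cast
    if len(elements) == 0:
        return pt.zeros((), dtype=out_dtype)
    if len(elements) == 1:
        return pt.cast(elements[0], out_dtype)
    return pt.cast(
        pt.add(*[pt.cast(value, acc_dtype) for value in elements]),
        out_dtype,
    )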

@ricardoV94
Member

ricardoV94 commented Jun 30, 2023

Some of these issues about acc_dtype and dtype showed up in #361. There is also an upcast_discrete_output property that we may need to check in this PR.

I wonder if all these knobs are needed in practice; they make it really cumbersome to work with CAReduce.
Your suggestion of allowing the ScalarOps to accept multiple inputs is also worth considering going forward.

@ricardoV94
Member

ricardoV94 commented Jun 30, 2023

It would also require figuring out how acc_dtype is actually interpreted; from reading over it, I'm not sure we always honor it properly.

My current understanding (from the docstrings, though looking at the perform / c_code is probably wiser) is that acc_dtype is the dtype of the initial value and of each intermediate computation (which may require casting internally after applying the scalar_op). Then there is dtype, to which the final result of the accumulation is cast (no idea why this can't be done by an explicit cast). Also, this doesn't apply to discrete outputs without the upcast_discrete_outputs flag.
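
A small example of that reading (assuming pt.sum still exposes both keywords, as in Theano/Aesara):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
# Accumulate in float64, then cast the final sum back to float32
out = pt.sum(x, acc_dtype="float64", dtype="float32")

fn = pytensor.function([x], out)
print(fn(np.ones(3, dtype="float32")))  # 3.0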

@ricardoV94
Member

ricardoV94 commented Jun 30, 2023

no idea why this can't be done by an explicit cast

Maybe this is like a mini reduce+elemwise optimization concern? When you don't reduce all axes, you would need an Elemwise cast afterwards (so another loop). If that's the only reason, it means we could simplify the Op once we implement #224

aseyboldt force-pushed the rewrite-sum-makevector branch from 4b47472 to 9f6f048 on July 12, 2023 01:03
aseyboldt force-pushed the rewrite-sum-makevector branch from 9f6f048 to 7429e94 on July 12, 2023 01:05
Labels: enhancement (New feature or request), graph rewriting