
Allow TensorType(shape=(1,), broadcastable=(False,)) #408


Open
4 tasks
ricardoV94 opened this issue Aug 7, 2023 · 1 comment

Comments

@ricardoV94
Member

Description

This requires re-introducing broadcastable flags that are independent of the static shape. This seems necessary in order to:

  1. Not force static shape to be unknown
  2. Not change the meaning of the graph accidentally due to shape inference / rewrites

Affected Ops (anything that performs broadcasting of existing dims):

  • Elemwise
  • Alloc
  • GEMM Ops
  • Unbroadcast

This will require re-introducing Rebroadcast, which could toggle the broadcastable flags directly, independently of any static shape information gained from SpecifyShape. It would probably be better named SpecifyBroadcastable.

Elemwise outputs will probably have to be unbroadcastable along a dimension as long as at least one input is also unbroadcastable along that dimension.

import pytensor.tensor as pt

# Proposed behavior: static shape of 1 without the broadcastable flag
x = pt.vector(shape=(1,), broadcastable=(False,))
y = x + x
assert y.type.broadcastable == (False,)
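
The output rule suggested above can be sketched in plain Python. This is only an illustration of the proposed rule, not PyTensor code: the helper name `elemwise_out_broadcastable` is hypothetical, and it assumes every input has already been padded to the same number of dimensions.

```python
def elemwise_out_broadcastable(*input_patterns):
    """Combine per-input broadcastable flags into the output flags.

    Hypothetical helper: each pattern is a tuple of booleans, one per dimension.
    """
    ndim = len(input_patterns[0])
    if any(len(pattern) != ndim for pattern in input_patterns):
        raise ValueError("all inputs must have the same number of dimensions")
    # A dimension stays broadcastable only if every input is broadcastable there;
    # a single unbroadcastable input makes the output unbroadcastable.
    return tuple(all(flags) for flags in zip(*input_patterns))


# x + x with x unbroadcastable stays unbroadcastable
assert elemwise_out_broadcastable((False,), (False,)) == (False,)
# One unbroadcastable input is enough to make the output unbroadcastable
assert elemwise_out_broadcastable((True,), (False,)) == (False,)
```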
@ricardoV94
Member Author

ricardoV94 commented May 30, 2024

An alternative is to make the broadcasting behavior a function not of the static types of the Variables, but of the Ops that perform the implicit broadcasting. For example, Elemwise could have a broadcast_pattern holding the indexes of the inputs that will be broadcast along each dimension. That way, if a graph replacement or a rewrite provides an input with a more precise static shape, (1,) instead of the original (None,), the Op will not change its behavior, and we no longer need this distinction at the type level.

When an Op is first created, we would read the pattern from the static types of the inputs, so everything stays backwards compatible. This means, however, that there would no longer be a single instance of non-unary Elemwises like pt.add; they would need to be created by the helper functions (or by the __call__ method, as we did for RVs with the dtype argument).

This makes sense to me, as the same core operation is semantically quite different with and without broadcasting.
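
A rough sketch of the Op-level alternative, in plain Python rather than PyTensor. Everything here is hypothetical: the class `BroadcastingAdd`, its `from_static_shapes` constructor and its `output_shape` method are stand-ins for illustration only, and static shapes are modeled as tuples where `None` means unknown. The point is that the pattern is fixed when the Op is created, so a length-1 input seen later does not silently start broadcasting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BroadcastingAdd:
    """Hypothetical stand-in for an Elemwise that stores its own broadcast pattern."""

    # broadcast_pattern[d] holds the indexes of the inputs broadcast along dim d
    broadcast_pattern: tuple

    @classmethod
    def from_static_shapes(cls, *static_shapes):
        # Read the pattern once, from the static shapes known at creation time
        ndim = len(static_shapes[0])
        pattern = tuple(
            tuple(i for i, shape in enumerate(static_shapes) if shape[d] == 1)
            for d in range(ndim)
        )
        return cls(pattern)

    def output_shape(self, *runtime_shapes):
        out = []
        for d, broadcast_inputs in enumerate(self.broadcast_pattern):
            # Inputs marked as broadcast must really have length 1
            for i in broadcast_inputs:
                if runtime_shapes[i][d] != 1:
                    raise ValueError(f"input {i} must have length 1 along dim {d}")
            # All other inputs must agree exactly: no implicit broadcasting
            others = {
                shape[d]
                for i, shape in enumerate(runtime_shapes)
                if i not in broadcast_inputs
            }
            if len(others) > 1:
                raise ValueError(f"non-broadcast inputs disagree along dim {d}")
            out.append(others.pop() if others else 1)
        return tuple(out)


# Created from unknown static shapes: no input is marked as broadcast
add_no_bcast = BroadcastingAdd.from_static_shapes((None,), (None,))
assert add_no_bcast.output_shape((3,), (3,)) == (3,)

# Created with a static shape of 1 on the first input: it is broadcast
add_bcast = BroadcastingAdd.from_static_shapes((1,), (None,))
assert add_bcast.output_shape((1,), (3,)) == (3,)
```

With this design, calling `add_no_bcast.output_shape((1,), (3,))` raises a `ValueError` instead of broadcasting, even though the first input happens to have length 1 at runtime. That is the behavior a rewrite providing a more precise static shape should not be able to change.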
