
Broadcast matrix inputs to Gemm #986


Merged (4 commits, Jun 13, 2022)

Conversation

@ricardoV94 (Contributor) commented Jun 8, 2022

The Gemm Op is only applicable when there is no broadcasting between Z and dot(A, B) in the expression Z + dot(A, B). This PR broadcasts the matrix inputs when the Op was inserted for a mix of matrices that weren't known to be row / column matrices until runtime.

Closes #984
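The failure mode can be sketched in plain NumPy (an illustration of the idea, not Aesara's actual implementation): a fused GEMM kernel computes `alpha * dot(A, B) + beta * Z` into Z's storage, so Z must already have the full output shape; unlike the elementwise `+`, it cannot broadcast a row or column matrix on the fly. Broadcasting Z up front restores the expected result:

```python
import numpy as np

# Z turns out to be a row matrix at runtime: a strict GEMM that accumulates
# into Z's storage cannot handle this, because Z lacks the full output shape.
A = np.random.rand(3, 4)
B = np.random.rand(4, 5)
Z = np.random.rand(1, 5)

# The fix sketched here: materialize the broadcast before the fused op, so
# the GEMM-style computation sees operands of matching shape.
out_shape = (A.shape[0], B.shape[1])
Z_full = np.ascontiguousarray(np.broadcast_to(Z, out_shape))

result = 1.0 * (A @ B) + 1.0 * Z_full
assert np.allclose(result, A @ B + Z)  # matches the broadcasting "+"
```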

@ricardoV94 (Contributor, Author)

Converting to draft because of the performance concerns that this fix entails

@brandonwillard (Member)

Perhaps it makes sense to allow this type of constraint at the type level, more in line with the old broadcastable flag? Type shapes would then not only be limited to (None, int), but also allow for a special flag -1 or "not1" to indicate this dimension can be anything other than 1.
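As a rough illustration of that proposal (the `NOT1` marker and the `shape_compatible` helper below are hypothetical names invented for this sketch, not part of Aesara's type system), a type shape could mix concrete sizes, unknowns, and an "anything but 1" flag:

```python
# Hypothetical sketch of the proposed type-shape flag: alongside concrete
# ints and None ("unknown"), a NOT1 marker would assert "this dimension has
# any size except 1", ruling out broadcasting along it.
NOT1 = "not1"

def shape_compatible(type_shape, runtime_shape):
    """Check a concrete runtime shape against a (None | int | NOT1) pattern."""
    if len(type_shape) != len(runtime_shape):
        return False
    for spec, size in zip(type_shape, runtime_shape):
        if spec is None:
            continue                  # unknown dimension: anything goes
        if spec == NOT1:
            if size == 1:             # broadcastable dims are excluded
                return False
        elif spec != size:            # concrete size must match exactly
            return False
    return True

assert shape_compatible((NOT1, 5), (3, 5))      # 3 != 1: allowed
assert not shape_compatible((NOT1, 5), (1, 5))  # row matrix: rejected
```

Under such a flag, an Op like Gemm could declare at the type level that its matrix inputs never broadcast, instead of discovering the mismatch at runtime.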

The old TensorType.broadcastable is still present in exactly the same form as it was. The only differences might be in how we want to use and interpret it.

I made a remark about this (i.e. the "old") interpretation of TensorType.broadcastable recently here—among other places/times throughout our work. The problem with the "old"/strict interpretation is that it puts extra pressure on Op.make_node implementations to both infer and be accurate about the broadcast patterns/static shape information in the TensorTypes it creates. We've been dealing with the issues and limitations that arise from this interpretation all throughout this work.

The type constraints you mention are viable, but also really do require a much more clearly defined and implemented type system, and, ultimately, some basic constraint logic. My push for the broad use of miniKanren is—in part—motivated by the availability (and compartmentalization) of such features.

Regardless, why can't we broadcast all the inputs to the GEMM Op in the rewrite (or even in the Op.make_node or Op.perform/Op.c_code methods)?

@ricardoV94 (Contributor, Author)

> Regardless, why can't we broadcast all the inputs to the GEMM Op in the rewrite (or even in the Op.make_node or Op.perform/Op.c_code methods)?

I'll explore that. I didn't plan to mess with BLAS-related Ops, but here we are ^^

In any case, I feel that supporting (and enforcing) non-size-1 type shapes might come in handy in a couple of places.

@ricardoV94 (Contributor, Author)

Broadcasting the matrix inputs was not so hard in the end. Doing that now.

@codecov bot commented Jun 10, 2022

Codecov Report

Merging #986 (47cc1f0) into main (064e72f) will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##             main     #986   +/-   ##
=======================================
  Coverage   79.26%   79.26%           
=======================================
  Files         152      152           
  Lines       47927    47932    +5     
  Branches    10912    10913    +1     
=======================================
+ Hits        37990    37995    +5     
  Misses       7429     7429           
  Partials     2508     2508           
Impacted Files | Coverage Δ
aesara/tensor/blas.py | 79.71% <100.00%> (+0.09%) ⬆️

@ricardoV94 (Contributor, Author)

Tests are passing

@brandonwillard (Member) previously approved these changes on Jun 10, 2022, leaving a comment:


This looks great.

@ricardoV94 (Contributor, Author) commented Jun 13, 2022

I did some sanity checks and I am more confident that I didn't screw up anything :)

Successfully merging this pull request may close these issues.

Gemm fails with simple broadcasting case