-
Notifications
You must be signed in to change notification settings - Fork 701
Added trace and diag with batch support for linalg crate #3703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
❌ Your project check has failed because the head coverage (63.37%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #3703 +/- ##
==========================================
+ Coverage 63.28% 63.37% +0.08%
==========================================
Files 1055 1059 +4
Lines 123018 123317 +299
==========================================
+ Hits 77858 78146 +288
- Misses 45160 45171 +11 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contributing these 🙏
Shouldn't the diag
output be D - 1
instead of keeping the rank and having the last dim = 1? And trace
would also output a tensor of rank D - 1
(so for a matrix, 1D tensor but with a single value just like our reduction ops).
We cannot use const generic expr (not stable), but we could have something similar to tensor.take
where the output rank is also declared (and should always be D - 1
).
Lmk what you think.
Also, we should assert that D >= 2
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor comments on doc, otherwise LGTM!
/// Tensor with rank D - 1, where the last two matrix dimensions are replaced by a single | ||
/// dimension containing the diagonal elements | ||
/// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Tensor with rank D - 1, where the last two matrix dimensions are replaced by a single | |
/// dimension containing the diagonal elements | |
/// | |
/// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input. |
/// | ||
/// Tensor with rank D - 1, where the last two matrix dimensions are replaced by a single | ||
/// dimension containing the the trace for each matrix. | ||
/// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// | |
/// Tensor with rank D - 1, where the last two matrix dimensions are replaced by a single | |
/// dimension containing the the trace for each matrix. | |
/// | |
/// A tensor of rank `D - 1`, where the last dimension contains the sum along diagonals of the input. |
use crate::backend::Backend; | ||
use crate::tensor::Tensor; | ||
|
||
/// Computes the trace of the of square matrices. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the "square" here is over specified no? You also have tests that work on non square matrices.
Pull Request Template
Checklist
cargo run-checks
command has been executed.Related Issues/PRs
Relevant issue:
#1538
Changes
The feature description of the issue is to add common linear algebra operations. There are some existing in
the linalg crate such as vector norms and cosine similarity. I added trace and diag operation for starters. Will continue working on more operations.
Testing
Test cases covering rank 2, 3 and 4. Square tall and wide matrices and edge cases such as 1x1, single row/column with multiple data types.
Tested with: cargo test -p burn-ndarray diag && cargo test -p burn-ndarray trace
cargo run-checks passes.