New metric: ClassificationReport #3116
Conversation
for more information, see https://pre-commit.ci
Great addition!
- Remove direct imports of classification metrics at module level - Implement lazy imports using @Property for Binary/Multiclass/MultilabelClassificationReport - Move metric initialization to property getters to break circular dependencies - Maintain all existing functionality while improving import structure
for more information, see https://pre-commit.ci
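As a side note, here is a minimal sketch of the lazy-import pattern the commit above describes; the class name `LazyBinaryReport` and the exact attribute layout are illustrative assumptions, not the PR's actual code. The point is simply that the wrapped metric is only constructed on first access, so importing the module no longer imports the classification metrics (and thus avoids the circular dependency).

```python
class LazyBinaryReport:
    """Illustrative only: the underlying metric is built on first access."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        self._precision = None  # not constructed at import/initialisation time

    @property
    def precision_metric(self):
        # Importing inside the property getter resolves the dependency at call
        # time rather than at module-import time, breaking the circular import.
        if self._precision is None:
            from torchmetrics.classification import BinaryPrecision

            self._precision = BinaryPrecision(threshold=self.threshold)
        return self._precision
```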
Thank you very much! @Borda Could you please help me understand why the TM.unittests are failing? They seem to be linked with Azure and I am not familiar with the error logs there. Regarding the ruff failures, for some reason my local checks pass, so I'll look into that. But is it okay if I tend to those once the actual code is vetted? Sorry for the delay in getting back to you; I had some immediate personal work to get done. CC: @rittik9 I think it is ready now?
@aymuos15 I had a quick look and opened a PR fixing the pre-commit errors.
Doctests regarding the classification report seem to be failing; mind having a look?
Going to take a detailed look later.
fix pre-commit errors
for more information, see https://pre-commit.ci
Thanks a lot for the quick fix @rittik9. Will look into the doctests now. Sorry about the close and open; I think the close was automatically triggered by the local merge for some reason? Not sure.
src/torchmetrics/functional/classification/classification_report.py
def multiclass_classification_report(
can we also provide top_k for multiclass?
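For context on what `top_k` would mean here, this is how the parameter already behaves on an existing torchmetrics multiclass functional metric; the values below are illustrative and the example only assumes that the report would thread `top_k` through to such metrics.

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

# Per-class scores; for the first sample the true class only has the
# second-highest score, so it is wrong at top_k=1 but counted at top_k=2.
preds = torch.tensor([[0.6, 0.3, 0.1],
                      [0.2, 0.5, 0.3],
                      [0.1, 0.4, 0.5]])
target = torch.tensor([1, 0, 2])

print(multiclass_accuracy(preds, target, num_classes=3, top_k=1))  # stricter
print(multiclass_accuracy(preds, target, num_classes=3, top_k=2))  # more lenient
```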
please remove the unnecessary comments
Thank you very much for the review @rittik9. I think most of them need significant changes to be implemented. I will get this done over the weekend. Will make sure to remove the bad comments as well.
In the case of the modular interface, why not call the functions from the functional interface when computing, reducing boilerplate?
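A minimal sketch of that pattern, using a toy class (`BinaryF1ViaFunctional` is an illustrative name and `binary_f1_score` stands in for the report's own functional counterpart): the modular metric only accumulates states in `update` and delegates the math in `compute` to the functional interface.

```python
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional.classification import binary_f1_score


class BinaryF1ViaFunctional(Metric):
    """Toy metric: states accumulate in ``update``; ``compute`` delegates to the functional API."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> Tensor:
        # States may still be lists before distributed sync, so concatenate if needed.
        preds = torch.cat(self.preds) if isinstance(self.preds, list) else self.preds
        target = torch.cat(self.target) if isinstance(self.target, list) else self.target
        return binary_f1_score(preds, target)
```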
Hi @aymuos15, overall it looks pretty good to me, nice work. Just added a few comments. You can also update
…r accuracy | Revamp testing to make it more parametrised | Make the non functional part of metric more modular | add ignore_index support | top_k support for multiclass
for more information, see https://pre-commit.ci
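One of the additions listed in the commit above is ignore_index support; for reference, this is how ignore_index already behaves in an existing torchmetrics functional metric (the tensors below are illustrative), and the report is presumably expected to follow the same convention.

```python
import torch
from torchmetrics.functional.classification import multiclass_recall

preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, -1])  # -1 marks positions to be ignored (e.g. padding)

# Targets equal to ignore_index are dropped from the statistics entirely.
print(multiclass_recall(preds, target, num_classes=3, ignore_index=-1))
```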
@rittik9 Thanks again for the detailed review. I have tried to address everything bar the classwise accuracy score, which is inherently different from the scikit-learn version of the metric. Are you okay if we raise a separate issue for it, discuss exactly how we want to add it, and then handle it in a separate PR?
Overall it looks good; I just think it would be better if we could have it more as a wrapper, so you can select which metrics you want in the report (probably passed as strings of metric/class names), with precision, recall, and F-measure as the default configuration.
see suggestion by @rittik9 in #2580 (comment)
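Roughly, the wrapper idea might look like the following; this is an assumption about the proposed API, not the PR's code, and the `default_metrics` dict plus the string-based selection are hypothetical. The report would be a thin layer over a `MetricCollection`, with precision/recall/F1 as the default columns.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryF1Score, BinaryPrecision, BinaryRecall

# Hypothetical default configuration: users could override this dict (or pass
# metric names as strings) to choose which columns appear in the report.
default_metrics = MetricCollection({
    "precision": BinaryPrecision(),
    "recall": BinaryRecall(),
    "f1": BinaryF1Score(),
})

preds = torch.tensor([0, 1, 1, 0, 1, 0, 1, 1])
target = torch.tensor([0, 1, 0, 0, 1, 1, 1, 0])
print(default_metrics(preds, target))  # one entry per requested metric
```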
Thank you very much for the review. I definitely like the idea and agree with it. However, I think I would have to pretty much revamp everything. Would it be okay to keep this as the base and proceed with that separately in another PR, or should I just continue here?
you can continue here 🐰
Haha, alright. I'll take some time to understand the best way to approach this and write the tests so edge cases do not come up. Thanks again!
Hi @aymuos15, while you're at it, could you please ensure that the 'micro' averaging option is included by default for all three cases: binary, multiclass, and multilabel? In case the results don't align with sklearn, we can always verify them manually.
Yup, sure. Thank you for reminding me. Since we are now going to include everything anyway, I'll make all options exhaustive.
We can figure out how to split it into smaller pieces so it lands smoother/faster.
That sounds good to me. Could you please let me know what that would entail? This is what I had in mind:
TorchMetrics Classification Report Examples
============================================================
BINARY CLASSIFICATION
============================================================
Predictions: [0, 1, 1, 0, 1, 0, 1, 1]
True Labels: [0, 1, 0, 0, 1, 1, 1, 0]
--- CASE 1: DEFAULT METRICS (precision, recall, f1-score) ---
precision recall f1-score support
0 0.67 0.50 0.57 4
1 0.60 0.75 0.67 4
accuracy 0.62 8
micro avg 0.62 0.62 0.62 8
macro avg 0.63 0.62 0.62 8
weighted avg 0.63 0.62 0.62 8
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 0.67 0.50 0.57 0.50 4
1 0.60 0.75 0.67 0.75 4
accuracy 0.62 8
micro avg 0.62 0.62 0.62 0.62 8
macro avg 0.63 0.62 0.62 0.62 8
weighted avg 0.63 0.62 0.62 0.62 8
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 0.75 4
1 0.50 4
accuracy 0.62 8
micro avg 0.50 8
macro avg 0.62 8
weighted avg 0.62 8
Verification - Direct accuracy calculation: 0.6250
============================================================
MULTICLASS CLASSIFICATION
============================================================
Predictions: [0, 2, 1, 2, 0, 1, 2, 0]
True Labels: [0, 1, 1, 2, 0, 2, 2, 1]
--- CASE 1: DEFAULT METRICS ---
precision recall f1-score support
0 0.67 1.00 0.80 2
1 0.50 0.33 0.40 3
2 0.67 0.67 0.67 3
accuracy 0.62 8
micro avg 0.62 0.62 0.62 8
macro avg 0.61 0.67 0.62 8
weighted avg 0.60 0.63 0.60 8
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 0.67 1.00 0.80 1.00 2
1 0.50 0.33 0.40 0.33 3
2 0.67 0.67 0.67 0.67 3
accuracy 0.62 8
micro avg 0.62 0.62 0.62 0.62 8
macro avg 0.61 0.67 0.62 0.67 8
weighted avg 0.60 0.63 0.60 0.63 8
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 0.83 2
1 0.80 3
2 0.80 3
accuracy 0.62 8
micro avg 0.81 8
macro avg 0.81 8
weighted avg 0.81 8
Verification - Direct accuracy calculation: 0.6667
============================================================
MULTILABEL CLASSIFICATION
============================================================
Predictions:
Sample 1: [1, 0, 1]
Sample 2: [0, 1, 0]
Sample 3: [1, 1, 0]
Sample 4: [0, 0, 1]
True Labels:
Sample 1: [1, 0, 0]
Sample 2: [0, 1, 1]
Sample 3: [1, 1, 0]
Sample 4: [0, 0, 1]
--- CASE 1: DEFAULT METRICS ---
precision recall f1-score support
0 1.00 1.00 1.00 2
1 1.00 1.00 1.00 2
2 0.50 0.50 0.50 2
micro avg 0.83 0.83 0.83 6
macro avg 0.83 0.83 0.83 6
weighted avg 0.83 0.83 0.83 6
samples avg 0.88 0.88 0.83 6
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 1.00 1.00 1.00 1.00 2
1 1.00 1.00 1.00 1.00 2
2 0.50 0.50 0.50 0.50 2
micro avg 0.83 0.83 0.83 0.83 6
macro avg 0.83 0.83 0.83 0.83 6
weighted avg 0.83 0.83 0.83 0.83 6
samples avg 0.88 0.88 0.83 0.83 6
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 1.00 2
1 1.00 2
2 0.50 2
micro avg 0.83 6
macro avg 0.83 6
weighted avg 0.83 6
samples avg 6
Verification - Direct accuracy calculation: 0.8333
Essentially, is the above okay? I have not committed anything yet because the code is very messy. Once we agree on a path forward, I will trigger the next commit, if that's okay.
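For anyone who wants to double-check the binary "default metrics" table above, the numbers can be reproduced with scikit-learn's own report (this snippet is just a cross-check, not part of the PR); the proposed torchmetrics output is meant to line up with it.

```python
from sklearn.metrics import classification_report

preds = [0, 1, 1, 0, 1, 0, 1, 1]
target = [0, 1, 0, 0, 1, 1, 1, 0]

# Prints per-class precision/recall/f1/support plus accuracy, macro and weighted averages.
print(classification_report(y_true=target, y_pred=preds, digits=2))
```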
I rather had in mind that we trim this PR to keep the printing functions and table formatting, and in the following PR have the extension of the metric Collection, either as a new subclass or a new method...
Ah okay! I will push a commit for the formatting and printing tonight. Thank you.
Keeping in mind the below rigidness for the micro avg:
# Determine if micro average should be shown in the report based on classification task
# Following scikit-learn's logic:
# - Show for multilabel classification (always)
# - Show for multiclass when using a subset of classes
# - Don't show for binary classification (micro avg is same as accuracy)
# - Don't show for full multiclass classification with all classes (micro avg is same as accuracy)
show_micro_avg = False
is_multilabel = task == ClassificationTask.MULTILABEL
After going through everything again, I think the PR as of right now does exactly the bare minimum. The main additions after the initial few commits were ignore_index and the micro avg. So just a few quick questions:
Thank you very much!
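The "micro avg is same as accuracy" remark above can be verified quickly with existing torchmetrics functionals; the snippet below reuses the multiclass example data from earlier in the thread and is only a sanity check, not part of the PR.

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy, multiclass_precision

preds = torch.tensor([0, 2, 1, 2, 0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 2, 0, 2, 2, 1])

# With all classes present, micro-averaged precision collapses to plain
# accuracy, which is why scikit-learn hides the micro-avg row in this case.
acc = multiclass_accuracy(preds, target, num_classes=3, average="micro")
micro_p = multiclass_precision(preds, target, num_classes=3, average="micro")
assert torch.isclose(acc, micro_p)
```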
Still, the Module-like metric is not derived from the collection, which would save some computations.
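The computation saving referred to here is presumably the compute-group mechanism of `MetricCollection`, sketched below (tensors are illustrative): metrics that share the same underlying state are grouped so the shared statistics are only computed once per update rather than once per metric.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryF1Score, BinaryPrecision, BinaryRecall

# All three metrics build on the same confusion-matrix counts, so the
# collection can share that state across them.
collection = MetricCollection([BinaryPrecision(), BinaryRecall(), BinaryF1Score()])
collection.update(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0]))
print(collection.compute_groups)  # e.g. {0: ['BinaryPrecision', 'BinaryRecall', 'BinaryF1Score']}
print(collection.compute())
```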
Okay, thank you. So just to confirm, I would only push
What does this PR do?
Fixes #2580
Before submitting
PR review
Classification report is a nice-to-have metric. The motivation can easily be drawn from its scikit-learn counterpart. Having it ready for tensors is definitely a good step forward.
Notes:
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Was a really nice way of getting familiar with the codebase :)
📚 Documentation preview 📚: https://torchmetrics--3116.org.readthedocs.build/en/3116/