[torchlib] Fix scatter reduce on error cases #2287
Conversation
❌ 3 Tests Failed:
Pull Request Overview
This PR fixes error cases in the scatter reduce operation by correcting the tensor initialization values for the BFLOAT16 and BOOL dtypes and for the "none" reduction mode.
- Fixes the BFLOAT16 branch in the max reduction case by using torch.finfo(torch.bfloat16).max.
- Adds proper handling for the BOOL type in both the min (False) and max (True) branches, and ensures a consistent tensor type is returned in the "none" reduction case.
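The initialization rules the overview describes can be sketched as a small dispatch function. This is a hedged illustration, not the actual onnxscript code: the function name `scatter_fill_value` and the string-keyed dtypes are invented for this sketch, the real code builds `ir.tensor` values in onnxscript/function_libs/torch_lib/ops/core.py, and the bfloat16 constant is hard-coded to the value `torch.finfo(torch.bfloat16).max` reports.

```python
# Hedged sketch of the fill-value rules this PR describes (hypothetical helper;
# the real code constructs ir.tensor values rather than Python scalars).
BFLOAT16_MAX = 3.3895313892515355e+38  # what torch.finfo(torch.bfloat16).max reports


def scatter_fill_value(dtype: str, reduce: str):
    """Pick the scalar used to initialize the output before a scatter-reduce."""
    if reduce == "min":
        if dtype == "BOOL":
            return False  # per the PR: BOOL gets False in the min branch
    elif reduce == "max":
        if dtype == "BOOL":
            return True  # per the PR: BOOL gets True in the max branch
        if dtype == "BFLOAT16":
            return BFLOAT16_MAX  # the BFLOAT16 fix in the max reduction case
    elif reduce == "none":
        return 0  # per the PR: wrapped in a typed tensor, not a bare literal
    raise NotImplementedError(f"{dtype}/{reduce}")
```

Dtype/reduce combinations not covered by the PR's description fall through to the `NotImplementedError`, so the sketch stays honest about what the source actually states.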
Comments suppressed due to low confidence (3)
onnxscript/function_libs/torch_lib/ops/core.py:7630
- The addition of boolean handling in the 'min' branch for scatter reduce appears to resolve the error with BOOL types; please confirm that using 'False' is semantically correct for a minimum reduction with boolean values.
elif dtype == ir.DataType.BOOL:
    value = ir.tensor([False], dtype=dtype)
onnxscript/function_libs/torch_lib/ops/core.py:7643
- Changing the BFLOAT16 case to use the maximum value for a max reduction correctly addresses the error, but please double-check that this change aligns with the intended behavior compared to the PyTorch semantics.
value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype)
onnxscript/function_libs/torch_lib/ops/core.py:7656
- Replacing the literal 0 with a tensor wrapping ensures consistency in type handling; please verify that this change maintains the expected behavior in downstream operations.
value = ir.tensor([0], dtype=dtype)
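The type-consistency point in this last comment can be shown with a small numpy stand-in (numpy is an assumption here, used only for illustration; the actual change wraps the value with onnxscript's `ir.tensor`): a bare Python literal `0` carries no dtype, so whatever consumes it must infer one, while a one-element array with an explicit dtype pins the type.

```python
import numpy as np

# A bare Python literal has no dtype of its own; numpy infers a default
# integer type, which may not match the tensor it is later combined with.
inferred = np.asarray(0)

# Wrapping the value with an explicit dtype (analogous to
# ir.tensor([0], dtype=dtype) in the PR) keeps the type consistent downstream.
typed = np.array([0], dtype=np.float16)
```

Arithmetic between `typed` and another float16 tensor stays in float16, whereas mixing in an untyped scalar leaves the result type up to promotion rules.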
Fix three errors
Fix a case for bfloat16 where min should be max.