Add rewrite for argmax/argmin of monotonic functions #1869
base: main
Conversation
Implements graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, flips operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
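The rewrites rest on a simple ordering fact, which can be sanity-checked with NumPy alone (an illustration of the four identities, not the PyTensor rewrite itself):

```python
import numpy as np

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

# Increasing f preserves ordering, so the arg-extremum index is unchanged.
assert np.argmax(np.exp(x)) == np.argmax(x)
assert np.argmin(np.exp(x)) == np.argmin(x)

# Decreasing f reverses ordering, so argmax and argmin swap.
assert np.argmax(-x) == np.argmin(x)
assert np.argmin(-x) == np.argmax(x)
```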
Pull request overview
This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).
Changes:
- Adds `MONOTONIC_INCREASING` and `MONOTONIC_DECREASING` tuples to classify scalar operations by monotonicity
- Implements a `local_argmax_argmin_monotonic` rewriter that optimizes argmax/argmin of monotonic functions
- Adds a comprehensive test suite with parametrized tests for different axis values
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `pytensor/tensor/rewriting/math.py` | Adds monotonic function classifications and implements the core rewrite logic for argmax/argmin optimization |
| `tests/tensor/rewriting/test_math.py` | Adds test class with parametrized tests for increasing and decreasing monotonic functions |
```python
def test_argmax_decreasing_functions(self, axis):
    """Test argmax(f_dec(x)) -> argmin(x) for monotonic decreasing f."""
    x = pt.vector("x")
    test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

    mode = get_default_mode()

    for f in [pt.neg, lambda z: -z]:
        unrewritten = pt.argmax(f(x), axis=axis)
        expected = pt.argmin(x, axis=axis)

        fn_unrewritten = function([x], unrewritten, mode=mode)
        fn_expected = function([x], expected, mode=mode)

        result_unrewritten = fn_unrewritten(test_val)
        result_expected = fn_expected(test_val)

        assert result_unrewritten == result_expected, (
            f"argmax(neg(x), axis={axis}) failed: "
            f"got {result_unrewritten}, expected {result_expected}"
        )
```
Copilot AI · Feb 1, 2026
The tests for decreasing functions should verify that the rewrite was applied (i.e., that the monotonic function was eliminated from the graph), similar to how the tests for increasing functions check this on lines 5056-5064 and 5089-5097. This ensures the optimization actually occurred and not just that the results are numerically equivalent.
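A hedged sketch of the check this comment asks for. In a real PyTensor test the graph's apply nodes come from `fn.maker.fgraph.toposort()`; here the graph is mocked with plain classes so the assertion logic itself is runnable (all names below are illustrative, not the project's test code):

```python
# Sketch: assert that a rewrite eliminated an op from the compiled graph.
# In a PyTensor test, `nodes` would be fn.maker.fgraph.toposort().

class Exp: ...        # stand-in for the Elemwise Exp scalar op
class Argmax: ...     # stand-in for the Argmax op

class FakeNode:
    """Minimal mock of an apply node carrying an op."""
    def __init__(self, op):
        self.op = op

def contains_op(nodes, op_cls):
    """True if any node in the graph applies an op of type op_cls."""
    return any(isinstance(n.op, op_cls) for n in nodes)

# Unrewritten graph for argmax(exp(x)): Exp is still present.
before = [FakeNode(Exp()), FakeNode(Argmax())]
# Rewritten graph for argmax(x): the monotonic function is gone.
after = [FakeNode(Argmax())]

assert contains_op(before, Exp)
assert not contains_op(after, Exp)
```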
```python
def test_argmin_decreasing_functions(self, axis):
    """Test argmin(f_dec(x)) -> argmax(x) for monotonic decreasing f."""
    x = pt.vector("x")
    test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

    mode = get_default_mode()

    for f in [pt.neg, lambda z: -z]:
        unrewritten = pt.argmin(f(x), axis=axis)
        expected = pt.argmax(x, axis=axis)

        fn_unrewritten = function([x], unrewritten, mode=mode)
        fn_expected = function([x], expected, mode=mode)

        result_unrewritten = fn_unrewritten(test_val)
        result_expected = fn_expected(test_val)

        assert result_unrewritten == result_expected, (
            f"argmin(neg(x), axis={axis}) failed: "
            f"got {result_unrewritten}, expected {result_expected}"
        )
```
Copilot AI · Feb 1, 2026
The tests for decreasing functions should verify that the rewrite was applied (i.e., that the monotonic function was eliminated from the graph), similar to how the tests for increasing functions check this on lines 5056-5064 and 5089-5097. This ensures the optimization actually occurred and not just that the results are numerically equivalent.
```python
MONOTONIC_INCREASING = (
    ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
```
Copilot AI · Feb 1, 2026
ps.Tan should not be included in MONOTONIC_INCREASING. The tangent function is periodic and only monotonic within each period (e.g., on (-π/2, π/2)). For arrays spanning a discontinuity, argmax(tan(x)) ≠ argmax(x). For example, with x = [1, 2], tan(1) ≈ 1.56 but tan(2) ≈ -2.19, so argmax(tan(x)) = 0 while argmax(x) = 1, and this optimization would be incorrect.
Suggested change:
```diff
-    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
+    ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,
```
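The failure can be checked with NumPy: two inputs straddling tan's discontinuity at π/2 break the ordering.

```python
import numpy as np

# tan is only monotonic within (-π/2, π/2); across the discontinuity
# at π/2 the ordering breaks.
x = np.array([1.0, 2.0])          # 1.0 < π/2 < 2.0

assert np.argmax(x) == 1          # 2.0 is the larger input,
assert np.argmax(np.tan(x)) == 0  # but tan(1) > tan(2)
```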
```python
    ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh
)

MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
```
Copilot AI · Feb 1, 2026
ps.Reciprocal should not be included in MONOTONIC_DECREASING. The reciprocal function (1/x) is not globally monotonic - it's discontinuous at 0 and changes sign. For example, if x = [-2, -1, 1, 2], then 1/x = [-0.5, -1, 1, 0.5], giving argmax(1/x) = 2 but argmin(x) = 0, which are different. The function is only monotonically decreasing on each of (0, ∞) and (-∞, 0) separately, not globally.
Suggested change:
```diff
-MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
+MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)
```
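The reciprocal counterexample from the comment, checked directly with NumPy:

```python
import numpy as np

# 1/x is decreasing on (0, inf) and on (-inf, 0) separately,
# but not globally: it changes sign across 0.
x = np.array([-2.0, -1.0, 1.0, 2.0])
# 1.0 / x is [-0.5, -1.0, 1.0, 0.5]

assert np.argmax(1.0 / x) == 2    # 1/1 = 1.0 is the maximum,
assert np.argmin(x) == 0          # but the minimum input is -2.0
# so argmax(1/x) != argmin(x): the proposed flip would be wrong here
```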
| """Check if node represents argmin by detecting Argmax(Neg(...))""" | ||
| if not isinstance(node.op, Argmax): | ||
| return False | ||
|
|
||
| input_node = node.inputs[0] | ||
| if not input_node.owner: | ||
| return False | ||
|
|
||
| # argmin(x) becomes Argmax(Neg(x)) or Argmax(imax - x) or Argmax(~x) | ||
| inner_op = input_node.owner.op | ||
| if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg): | ||
| return True | ||
|
|
Copilot AI · Feb 1, 2026
The comment suggests that _is_argmin handles Argmax(imax - x) and Argmax(~x) patterns (for uint and bool dtypes), but the implementation only checks for Argmax(Neg(x)) (line 3910). While this may not be a practical issue (since most monotonic functions return floats), the comment should be updated to accurately reflect what the function actually detects, or the implementation should be extended to handle these additional patterns.
| """Check if node represents argmin by detecting Argmax(Neg(...))""" | |
| if not isinstance(node.op, Argmax): | |
| return False | |
| input_node = node.inputs[0] | |
| if not input_node.owner: | |
| return False | |
| # argmin(x) becomes Argmax(Neg(x)) or Argmax(imax - x) or Argmax(~x) | |
| inner_op = input_node.owner.op | |
| if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg): | |
| return True | |
| """Check if node represents argmin by detecting Argmax(Neg(x)).""" | |
| if not isinstance(node.op, Argmax): | |
| return False | |
| input_node = node.inputs[0] | |
| if not input_node.owner: | |
| return False | |
| # Currently we only treat argmin(x) represented as Argmax(Neg(x)) as argmin | |
| inner_op = input_node.owner.op | |
| if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg): | |
| return True |
Add rewrite for argmax/argmin of monotonic functions
Summary
This PR implements a graph rewrite that optimizes `argmax`/`argmin` operations applied to monotonic functions by eliminating unnecessary function evaluations.

Motivation
Computing `argmax(exp(x))` is wasteful because the exponential computation doesn't affect which index holds the maximum value; we only care about relative ordering. Since monotonic functions preserve ordering, we can skip the expensive function application entirely.

Implementation
New rewrite: `local_argmax_argmin_monotonic`

The rewrite handles four transformation paths based on function monotonicity:
Monotonically Increasing Functions
- `argmax(f(x)) → argmax(x)`
- `argmin(f(x)) → argmin(x)`

Supported increasing functions: `Exp`, `Exp2`, `Expm1`, `Log`, `Log2`, `Log10`, `Log1p`, `Sqrt`, `Deg2Rad`, `Rad2Deg`, `ArcSin`, `Tan`, `ArcTan`, `ArcCosh`, `Sinh`, `ArcSinh`, `Tanh`, `ArcTanh`

Monotonically Decreasing Functions
- `argmax(f(x)) → argmin(x)`
- `argmin(f(x)) → argmax(x)`

Supported decreasing functions: `Neg`, `Reciprocal`, `ArcCos`

Key Features
- Handles `argmin`, which is internally represented as `Argmax(Neg(...))` in PyTensor
- Works for any axis specification (`None`, `0`, `-1`, etc.)
- Uses `Elemwise` wrapper detection to identify scalar operations
- Preserves debugging information via `copy_stack_trace`

Changes
`pytensor/tensor/rewriting/math.py`
- Added `MONOTONIC_INCREASING` tuple containing 18 monotonically increasing scalar operations
- Added `MONOTONIC_DECREASING` tuple containing 3 monotonically decreasing scalar operations
- Added `_is_argmin()` helper function to detect argmin patterns (handles the `Argmax(Neg(...))` representation)
- Added `local_argmax_argmin_monotonic()` rewriter with the `@register_canonicalize` decorator

`tests/tensor/rewriting/test_math.py`
- Added `TestArgmaxArgminMonotonic` test class with comprehensive coverage:
  - `test_argmax_increasing_functions` - tests the rewrite for increasing functions with argmax
  - `test_argmin_increasing_functions` - tests the rewrite for increasing functions with argmin
  - `test_argmax_decreasing_functions` - tests the rewrite for decreasing functions with argmax (flips to argmin)
  - `test_argmin_decreasing_functions` - tests the rewrite for decreasing functions with argmin (flips to argmax)
  - Tests are parametrized over axis values (`None`, `0`, `-1`)

Example
Performance Impact
This rewrite provides significant speedups when:
- the wrapped function (e.g. `exp`, `log`) is expensive to evaluate
- the input arrays are large, since the elementwise function application over the whole array is skipped
Testing
All tests pass with various configurations:
- axis values (`None`, `0`, `-1`)

The rewrite correctly handles edge cases including:
- the `Argmax(Neg(...))` representation for `argmin`