Skip to content

Conversation

@Jasjeet-Singh-S
Copy link

Add rewrite for argmax/argmin of monotonic functions

Summary

This PR implements a graph rewrite that optimizes argmax/argmin operations applied to monotonic functions by eliminating unnecessary function evaluations.

Motivation

Computing argmax(exp(x)) is wasteful because the exponential computation doesn't affect which index has the maximum value - we only care about relative ordering. Since monotonic functions preserve ordering, we can skip the expensive function application entirely.

Implementation

New rewrite: local_argmax_argmin_monotonic

The rewrite handles four transformation paths based on function monotonicity:

Monotonically Increasing Functions

  • argmax(f(x)) → argmax(x)
  • argmin(f(x)) → argmin(x)

Supported increasing functions: Exp, Exp2, Expm1, Log, Log2, Log10, Log1p, Sqrt, Deg2Rad, Rad2Deg, ArcSin, Tan, ArcTan, ArcCosh, Sinh, ArcSinh, Tanh, ArcTanh

Monotonically Decreasing Functions

  • argmax(f(x)) → argmin(x)
  • argmin(f(x)) → argmax(x)

Supported decreasing functions: Neg, Reciprocal, ArcCos

Key Features

  • Handles PyTensor's internal representation: Correctly processes argmin which is internally represented as Argmax(Neg(...)) in PyTensor
  • Preserves axis parameter: Works correctly with different axis specifications (None, 0, -1, etc.)
  • Robust pattern matching: Uses Elemwise wrapper detection to identify scalar operations
  • Stack trace preservation: Maintains debugging information via copy_stack_trace

Changes

pytensor/tensor/rewriting/math.py

  • Added MONOTONIC_INCREASING tuple containing 18 monotonically increasing scalar operations
  • Added MONOTONIC_DECREASING tuple containing 3 monotonically decreasing scalar operations
  • Implemented _is_argmin() helper function to detect argmin patterns (handles Argmax(Neg(...)) representation)
  • Implemented local_argmax_argmin_monotonic() rewriter with @register_canonicalize decorator

tests/tensor/rewriting/test_math.py

  • Added TestArgmaxArgminMonotonic test class with comprehensive coverage:
    • test_argmax_increasing_functions - Tests rewrite for increasing functions with argmax
    • test_argmin_increasing_functions - Tests rewrite for increasing functions with argmin
    • test_argmax_decreasing_functions - Tests rewrite for decreasing functions with argmax (flips to argmin)
    • test_argmin_decreasing_functions - Tests rewrite for decreasing functions with argmin (flips to argmax)
  • All tests parametrized over multiple axis values (None, 0, -1)
  • Tests verify both numerical correctness and graph structure optimization

Example

import pytensor.tensor as pt
import numpy as np

x = pt.vector('x')
y = pt.argmax(pt.exp(x))  # Before: computes exp then argmax
                           # After: computes argmax directly

# The rewrite eliminates the expensive exp() computation
# since argmax(exp(x)) = argmax(x) for monotonic functions

Performance Impact

This rewrite provides significant speedups when:

  • Computing argmax/argmin of exponentials, logarithms, or other monotonic transformations
  • Working with large arrays where the eliminated operations would be expensive
  • The monotonic function application is the dominant computational cost

Testing

All tests pass with various configurations:

  • Multiple monotonic functions (18 increasing, 3 decreasing)
  • Different axis specifications (None, 0, -1)
  • Numerical correctness verification against expected results
  • Graph structure validation to ensure rewrites are applied correctly

The rewrite correctly handles edge cases including:

  • PyTensor's internal Argmax(Neg(...)) representation for argmin
  • Broadcasting and dimension handling
  • Proper flipping between argmax/argmin for decreasing functions

Implements graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, flips operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
Copilot AI review requested due to automatic review settings February 1, 2026 17:22
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a graph rewrite optimization that eliminates unnecessary function evaluations when computing argmax or argmin of monotonic functions. The optimization leverages the property that monotonic functions preserve ordering, so argmax(exp(x)) can be simplified to argmax(x).

Changes:

  • Adds MONOTONIC_INCREASING and MONOTONIC_DECREASING tuples to classify scalar operations by monotonicity
  • Implements local_argmax_argmin_monotonic rewriter that optimizes argmax/argmin of monotonic functions
  • Adds comprehensive test suite with parametrized tests for different axis values

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
pytensor/tensor/rewriting/math.py Adds monotonic function classifications and implements the core rewrite logic for argmax/argmin optimization
tests/tensor/rewriting/test_math.py Adds test class with parametrized tests for increasing and decreasing monotonic functions

Comment on lines +5100 to +5120
def test_argmax_decreasing_functions(self, axis):
"""Test argmax(f_dec(x)) -> argmin(x) for monotonic decreasing f."""
x = pt.vector("x")
test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

mode = get_default_mode()

for f in [pt.neg, lambda z: -z]:
unrewritten = pt.argmax(f(x), axis=axis)
expected = pt.argmin(x, axis=axis)

fn_unrewritten = function([x], unrewritten, mode=mode)
fn_expected = function([x], expected, mode=mode)

result_unrewritten = fn_unrewritten(test_val)
result_expected = fn_expected(test_val)

assert result_unrewritten == result_expected, (
f"argmax(neg(x), axis={axis}) failed: "
f"got {result_unrewritten}, expected {result_expected}"
)
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests for decreasing functions should verify that the rewrite was applied (i.e., that the monotonic function was eliminated from the graph), similar to how the tests for increasing functions check this on lines 5056-5064 and 5089-5097. This ensures the optimization actually occurred and not just that the results are numerically equivalent.

Copilot uses AI. Check for mistakes.
Comment on lines +5123 to +5143
def test_argmin_decreasing_functions(self, axis):
"""Test argmin(f_dec(x)) -> argmax(x) for monotonic decreasing f."""
x = pt.vector("x")
test_val = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

mode = get_default_mode()

for f in [pt.neg, lambda z: -z]:
unrewritten = pt.argmin(f(x), axis=axis)
expected = pt.argmax(x, axis=axis)

fn_unrewritten = function([x], unrewritten, mode=mode)
fn_expected = function([x], expected, mode=mode)

result_unrewritten = fn_unrewritten(test_val)
result_expected = fn_expected(test_val)

assert result_unrewritten == result_expected, (
f"argmin(neg(x), axis={axis}) failed: "
f"got {result_unrewritten}, expected {result_expected}"
) No newline at end of file
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests for decreasing functions should verify that the rewrite was applied (i.e., that the monotonic function was eliminated from the graph), similar to how the tests for increasing functions check this on lines 5056-5064 and 5089-5097. This ensures the optimization actually occurred and not just that the results are numerically equivalent.

Copilot uses AI. Check for mistakes.

MONOTONIC_INCREASING = (
ps.Exp, ps.Exp2, ps.Expm1, ps.Log, ps.Log2, ps.Log10, ps.Log1p,
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ps.Tan should not be included in MONOTONIC_INCREASING. The tangent function is periodic and only monotonic within each period (e.g., on (-π/2, π/2)). For arrays spanning multiple periods, argmax(tan(x)) ≠ argmax(x). For example, with x = [0, π], tan(0) = 0 and tan(π) ≈ 0, so this optimization would be incorrect.

Suggested change
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.Tan, ps.ArcTan,
ps.Sqrt, ps.Deg2Rad, ps.Rad2Deg, ps.ArcSin, ps.ArcTan,

Copilot uses AI. Check for mistakes.
ps.ArcCosh, ps.Sinh, ps.ArcSinh, ps.Tanh, ps.ArcTanh
)

MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ps.Reciprocal should not be included in MONOTONIC_DECREASING. The reciprocal function (1/x) is not globally monotonic - it's discontinuous at 0 and changes sign. For example, if x = [-2, -1, 1, 2], then 1/x = [-0.5, -1, 1, 0.5], giving argmax(1/x) = 2 but argmin(x) = 0, which are different. The function is only monotonically decreasing on each of (0, ∞) and (-∞, 0) separately, not globally.

Suggested change
MONOTONIC_DECREASING = (ps.Neg, ps.Reciprocal, ps.ArcCos)
MONOTONIC_DECREASING = (ps.Neg, ps.ArcCos)

Copilot uses AI. Check for mistakes.
Comment on lines +3900 to +3912
"""Check if node represents argmin by detecting Argmax(Neg(...))"""
if not isinstance(node.op, Argmax):
return False

input_node = node.inputs[0]
if not input_node.owner:
return False

# argmin(x) becomes Argmax(Neg(x)) or Argmax(imax - x) or Argmax(~x)
inner_op = input_node.owner.op
if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg):
return True

Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment suggests that _is_argmin handles Argmax(imax - x) and Argmax(~x) patterns (for uint and bool dtypes), but the implementation only checks for Argmax(Neg(x)) (line 3910). While this may not be a practical issue (since most monotonic functions return floats), the comment should be updated to accurately reflect what the function actually detects, or the implementation should be extended to handle these additional patterns.

Suggested change
"""Check if node represents argmin by detecting Argmax(Neg(...))"""
if not isinstance(node.op, Argmax):
return False
input_node = node.inputs[0]
if not input_node.owner:
return False
# argmin(x) becomes Argmax(Neg(x)) or Argmax(imax - x) or Argmax(~x)
inner_op = input_node.owner.op
if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg):
return True
"""Check if node represents argmin by detecting Argmax(Neg(x))."""
if not isinstance(node.op, Argmax):
return False
input_node = node.inputs[0]
if not input_node.owner:
return False
# Currently we only treat argmin(x) represented as Argmax(Neg(x)) as argmin
inner_op = input_node.owner.op
if isinstance(inner_op, Elemwise) and isinstance(inner_op.scalar_op, ps.Neg):
return True

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant