Skip to content

Commit d953a0d

Browse files
Deprecate and raise AttributeError for MaxAndArgmax
1 parent 25af747 commit d953a0d

File tree

4 files changed

+10
-17
lines changed

4 files changed

+10
-17
lines changed

pytensor/tensor/math.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@
110110
def __getattr__(name):
111111
if name == "MaxAndArgmax":
112112
raise AttributeError(
113-
"The class `MaxAndArgmax` has been deprecated. "
114-
"Call `Max` and `Argmax` separately as an alternative."
113+
"The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative."
115114
)
115+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
116116

117117

118118
def _get_atol_rtol(a, b):
@@ -565,16 +565,6 @@ def max(x, axis=None, keepdims=False):
565565
We return an error as numpy when we reduce a dim with a shape of 0.
566566
567567
"""
568-
569-
# We have a choice of implementing this call with the
570-
# CAReduce op or the MaxAndArgmax op.
571-
572-
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
573-
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
574-
# with CAReduce at compile time, so at this stage the important
575-
# thing is supporting all user interface features, not speed.
576-
# Some cases can be implemented only with CAReduce.
577-
578568
out = max_and_argmax(x, axis)[0]
579569

580570
if keepdims:

pytensor/tensor/rewriting/uncanonicalize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def local_max_to_min(fgraph, node):
5252
Notes
5353
-----
5454
We don't need an opt that will do the reverse as by default
55-
the interface put only MaxAndArgmax into the graph.
55+
the interface put only Max into the graph.
5656
5757
"""
5858
if node.op == neg and node.inputs[0].owner:

tests/link/jax/test_nlinalg.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytensor.link.jax import JAXLinker
1212
from pytensor.tensor import blas as pt_blas
1313
from pytensor.tensor import nlinalg as pt_nlinalg
14-
from pytensor.tensor.math import MaxAndArgmax, maximum
14+
from pytensor.tensor.math import Argmax, Max, maximum
1515
from pytensor.tensor.math import max as pt_max
1616
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni():
8888
# Test that a single output of a multi-output `Op` can be used as input to
8989
# another `Op`
9090
x = dvector()
91-
mx, amx = MaxAndArgmax([0])(x)
91+
mx = Max([0])(x)
92+
amx = Argmax([0])(x)
9293
out = mx * amx
9394
out_fg = FunctionGraph([x], [out])
9495
compare_jax_and_py(out_fg, [np.r_[1, 2]])

tests/tensor/test_math.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,6 @@ def test_vectorize(self, core_axis, batch_axis):
10321032
x = tensor(shape=(5, 5, 5, 5))
10331033
batch_x = tensor(shape=(3, 5, 5, 5, 5))
10341034

1035-
# Test MaxAndArgmax
10361035
argmax_x = argmax(x, axis=core_axis)
10371036

10381037
arg_max_node = argmax_x.owner
@@ -1423,7 +1422,10 @@ def test_bool(self):
14231422

14241423

14251424
def test_MaxAndArgmax_deprecated():
1426-
with pytest.raises(AttributeError):
1425+
with pytest.raises(
1426+
AttributeError,
1427+
match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.",
1428+
):
14271429
pytensor.tensor.math.MaxAndArgmax
14281430

14291431

0 commit comments

Comments
 (0)