Break MaxAndArgmax Op into separate TensorMax Op and Argmax Op #731

Merged: 4 commits, merged on Jun 13, 2024
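This PR splits the fused MaxAndArgmax Op into a standalone Max Op and an Argmax Op, and updates the IfElse rewrite plus the JAX and Numba dispatchers accordingly. A minimal sketch of the user-facing graph behavior, assuming the existing public pytensor.tensor helpers pt.max / pt.argmax:

import pytensor.tensor as pt

x = pt.matrix("x")
max_out = pt.max(x, axis=1)        # now backed by its own Max Op
argmax_out = pt.argmax(x, axis=1)  # now a separate Argmax Op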
3 changes: 2 additions & 1 deletion pytensor/ifelse.py
@@ -477,7 +477,8 @@ def cond_make_inplace(fgraph, node):
     Reshape,
     Unbroadcast,
     pt.math.Dot,
-    pt.math.MaxAndArgmax,
+    pt.math.Max,
+    pt.math.Argmax,
     pt.subtensor.Subtensor,
     pt.subtensor.IncSubtensor,
     pt.basic.Alloc,
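For context, this list whitelists Ops that the IfElse inplace rewrite may run inside a branch; Max and Argmax now have to be listed individually. A hedged sketch of a graph this rewrite can now handle, assuming the standard pytensor.ifelse API:

import pytensor.tensor as pt
from pytensor.ifelse import ifelse

x = pt.vector("x")
# Both branches have the same dtype, so IfElse type-checks; the Max Op
# inside may now be made inplace by cond_make_inplace.
out = ifelse(pt.gt(x.sum(), 0), x.max(), x.min())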
26 changes: 18 additions & 8 deletions pytensor/link/jax/dispatch/nlinalg.py
@@ -2,7 +2,7 @@
 
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,28 @@
     return batched_dot
 
 
-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis
 
-    def maxandargmax(x, axis=axis):
+    def max(x):
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)
 
-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +148,6 @@
 
         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
 
-        return max_res, max_idx_res
+        return max_idx_res

-    return maxandargmax
+    return argmax
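The reshape trick above works around jnp.argmax reducing only a single axis: the kept axes are transposed to the front, the reduced axes are flattened into one trailing axis, and argmax runs over axis=-1. An illustrative NumPy sketch of the same idea (names here are for exposition, not the dispatch code itself):

import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
axes = (1, 2)  # axes to reduce
keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
# Kept axes in front, reduced axes flattened into one trailing axis
transposed = np.transpose(x, keep_axes + axes)
flat = transposed.reshape(transposed.shape[: len(keep_axes)] + (-1,))
idx = np.argmax(flat, axis=-1)  # flat index into the collapsed (3, 4) block
assert idx.shape == (2,)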
32 changes: 8 additions & 24 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -44,7 +44,7 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar
 
@@ -985,8 +985,8 @@ def log_softmax_py_fn(x):
     return log_softmax
 
 
-@numba_funcify.register(MaxAndArgmax)
-def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+@numba_funcify.register(Argmax)
+def numba_funcify_Argmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
@@ -996,8 +996,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     if x_ndim == 0:
 
         @numba_basic.numba_njit(inline="always")
-        def maxandargmax(x):
-            return x, 0
+        def argmax(x):
+            return 0
 
     else:
         axes = tuple(int(ax) for ax in axis)
@@ -1006,20 +1006,6 @@ def maxandargmax(x):
         # work-around
         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
 
-        reduce_max_py_fn = create_multiaxis_reducer(
-            scalar_maximum,
-            -np.inf,
-            axes,
-            x_ndim,
-            x_dtype,
-            return_scalar=False,
-        )
-        reduce_max = jit_compile_reducer(
-            Apply(node.op, node.inputs, [node.outputs[0].clone()]),
-            reduce_max_py_fn,
-            reduce_to_scalar=False,
-        )
-
         reduced_x_ndim = x_ndim - len(axes) + 1
         argmax_axis = create_axis_apply_fn(
             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
@@ -1030,9 +1016,7 @@ def maxandargmax(x):
         sl2 = slice(len(keep_axes), None)
 
         @numba_basic.numba_njit
-        def maxandargmax(x):
-            max_res = reduce_max(x)
-
+        def argmax(x):
            # Not-reduced axes in front
             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
             kept_shape = transposed_x.shape[sl1]
@@ -1048,6 +1032,6 @@ def maxandargmax(x):
 
             max_idx_res = argmax_axis(reshaped_x)
 
-            return max_res, max_idx_res
+            return max_idx_res
 
-    return maxandargmax
+    return argmax
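A hedged end-to-end check of the Numba path: compiling an argmax graph should now dispatch to the new numba_funcify_Argmax above. This assumes the standard "NUMBA" compilation mode string:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
f = pytensor.function([x], pt.argmax(x, axis=0), mode="NUMBA")
print(f(np.array([[1.0, 5.0], [3.0, 2.0]])))  # expected: [1 0]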