Commit a278272

Added changes for separating the MaxAndArgmax Op
Scalar problem solved; finalise changes to separate the MaxAndArgmax Op
1 parent 950a5ea commit a278272

15 files changed (+307, -1272 lines)

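The user-facing effect of the split, as a minimal sketch (assuming the existing pt.max and pt.argmax helpers now build the separate Max and Argmax Ops instead of a single MaxAndArgmax node):

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
max_val = pt.max(x, axis=0)        # expected to introduce a Max node after this change
argmax_val = pt.argmax(x, axis=0)  # expected to introduce a separate Argmax node

# Both results can still be requested from a single compiled function.
f = pytensor.function([x], [max_val, argmax_val])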

pytensor/compile/function/__init__.py

+1

@@ -312,6 +312,7 @@ def opt_log1p(node):
     else:
         # note: pfunc will also call orig_function -- orig_function is
         # a choke point that all compilation must pass through
+
         fn = pfunc(
             params=inputs,
             outputs=outputs,

pytensor/compile/function/types.py

+1

@@ -1758,6 +1758,7 @@ def orig_function(
             name=name,
             fgraph=fgraph,
         )
+        print(m)
         with config.change_flags(compute_test_value="off"):
             fn = m.create(defaults)
     finally:

pytensor/graph/op.py

+1

@@ -291,6 +291,7 @@ def __call__(
 
         """
        node = self.make_node(*inputs, **kwargs)
+
        if name is not None:
            if len(node.outputs) == 1:
                node.outputs[0].name = name

pytensor/ifelse.py

+1 -1

@@ -477,7 +477,7 @@ def cond_make_inplace(fgraph, node):
     Reshape,
     Unbroadcast,
     pt.math.Dot,
-    pt.math.TensorMax,
+    pt.math.Max,
     pt.math.Argmax,
     pt.subtensor.Subtensor,
     pt.subtensor.IncSubtensor,

pytensor/link/jax/dispatch/nlinalg.py

+63 -8

@@ -2,7 +2,7 @@
 
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,73 @@ def batched_dot(a, b):
     return batched_dot
 
 
-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+# @jax_funcify.register(Max)
+# @jax_funcify.register(Argmax)
+# def jax_funcify_MaxAndArgmax(op, **kwargs):
+#     axis = op.axis
+
+#     def maxandargmax(x, axis=axis):
+#         if axis is None:
+#             axes = tuple(range(x.ndim))
+#         else:
+#             axes = tuple(int(ax) for ax in axis)
+
+#         max_res = jnp.max(x, axis)
+
+#         # NumPy does not support multiple axes for argmax; this is a
+#         # work-around
+#         keep_axes = jnp.array(
+#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
+#         )
+#         # Not-reduced axes in front
+#         transposed_x = jnp.transpose(
+#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+#         )
+#         kept_shape = transposed_x.shape[: len(keep_axes)]
+#         reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+#         # Otherwise reshape would complain citing float arg
+#         new_shape = (
+#             *kept_shape,
+#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+#         )
+#         reshaped_x = transposed_x.reshape(new_shape)
+
+#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+#         return max_res, max_idx_res
+
+#     return maxandargmax
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis
 
-    def maxandargmax(x, axis=axis):
+    def max(x, axis=axis):
+        # if axis is None:
+        #     axes = tuple(range(x.ndim))
+        # else:
+        #     axes = tuple(int(ax) for ax in axis)
+
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x, axis=axis):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)
 
-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +193,6 @@ def maxandargmax(x, axis=axis):
 
         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
 
-        return max_res, max_idx_res
+        return max_idx_res
 
-    return maxandargmax
+    return argmax
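The new Argmax dispatcher keeps the transpose-and-reshape work-around because jnp.argmax only reduces over a single axis. A standalone sketch of the same trick (illustrative names, not the dispatcher itself):

import jax.numpy as jnp

def argmax_over_axes(x, axes):
    # Move the non-reduced axes to the front, collapse the reduced axes
    # into one trailing dimension, then take a single-axis argmax.
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    transposed = jnp.transpose(x, keep_axes + tuple(axes))
    kept_shape = transposed.shape[: len(keep_axes)]
    reshaped = transposed.reshape((*kept_shape, -1))
    return jnp.argmax(reshaped, axis=-1)

x = jnp.arange(24).reshape(2, 3, 4)
print(argmax_over_axes(x, (1, 2)))  # index into each flattened (3, 4) block -> [11 11]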

pytensor/link/numba/dispatch/elemwise.py

+8 -24

@@ -44,7 +44,7 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar
 
@@ -985,8 +985,8 @@ def log_softmax_py_fn(x):
     return log_softmax
 
 
-@numba_funcify.register(MaxAndArgmax)
-def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+@numba_funcify.register(Argmax)
+def numba_funcify_Argmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
@@ -996,8 +996,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     if x_ndim == 0:
 
         @numba_basic.numba_njit(inline="always")
-        def maxandargmax(x):
-            return x, 0
+        def argmax(x):
+            return 0
 
     else:
         axes = tuple(int(ax) for ax in axis)
@@ -1006,20 +1006,6 @@ def maxandargmax(x):
         # work-around
         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
 
-        reduce_max_py_fn = create_multiaxis_reducer(
-            scalar_maximum,
-            -np.inf,
-            axes,
-            x_ndim,
-            x_dtype,
-            return_scalar=False,
-        )
-        reduce_max = jit_compile_reducer(
-            Apply(node.op, node.inputs, [node.outputs[0].clone()]),
-            reduce_max_py_fn,
-            reduce_to_scalar=False,
-        )
-
         reduced_x_ndim = x_ndim - len(axes) + 1
         argmax_axis = create_axis_apply_fn(
             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
@@ -1030,9 +1016,7 @@ def maxandargmax(x):
         sl2 = slice(len(keep_axes), None)
 
         @numba_basic.numba_njit
-        def maxandargmax(x):
-            max_res = reduce_max(x)
-
+        def argmax(x):
             # Not-reduced axes in front
             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
             kept_shape = transposed_x.shape[sl1]
@@ -1048,6 +1032,6 @@ def maxandargmax(x):
 
             max_idx_res = argmax_axis(reshaped_x)
 
-            return max_res, max_idx_res
+            return max_idx_res
 
-    return maxandargmax
+    return argmax
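With the max computation dropped from the Numba argmax kernel, a graph that needs both values now gets them from two separate Ops. A NumPy-only sketch (not part of this diff) showing the two results stay consistent when computed independently:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))

idx = np.argmax(x, axis=1)                           # what the Argmax kernel returns
vals = np.take_along_axis(x, idx[:, None], axis=1)   # recover the max values from the indices
assert np.allclose(vals[:, 0], np.max(x, axis=1))    # agrees with a separate Max reduction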
