 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,73 @@ def batched_dot(a, b):
     return batched_dot


-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+# @jax_funcify.register(Max)
+# @jax_funcify.register(Argmax)
+# def jax_funcify_MaxAndArgmax(op, **kwargs):
+#     axis = op.axis
+
+#     def maxandargmax(x, axis=axis):
+#         if axis is None:
+#             axes = tuple(range(x.ndim))
+#         else:
+#             axes = tuple(int(ax) for ax in axis)
+
+#         max_res = jnp.max(x, axis)
+
+#         # NumPy does not support multiple axes for argmax; this is a
+#         # work-around
+#         keep_axes = jnp.array(
+#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
+#         )
+#         # Not-reduced axes in front
+#         transposed_x = jnp.transpose(
+#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+#         )
+#         kept_shape = transposed_x.shape[: len(keep_axes)]
+#         reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+#         # Otherwise reshape would complain citing float arg
+#         new_shape = (
+#             *kept_shape,
+#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+#         )
+#         reshaped_x = transposed_x.reshape(new_shape)
+
+#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+#         return max_res, max_idx_res
+
+#     return maxandargmax
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis

-    def maxandargmax(x, axis=axis):
+    def max(x, axis=axis):
+        # if axis is None:
+        #     axes = tuple(range(x.ndim))
+        # else:
+        #     axes = tuple(int(ax) for ax in axis)
+
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x, axis=axis):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)

-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +193,6 @@ def maxandargmax(x, axis=axis):

         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

-        return max_res, max_idx_res
+        return max_idx_res

-    return maxandargmax
+    return argmax
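
Outside the diff itself, here is a minimal standalone sketch of the multi-axis argmax work-around that the new jax_funcify_Argmax dispatch relies on: the non-reduced axes are moved to the front, the reduced axes are flattened into a single trailing axis, and one jnp.argmax over that axis returns the flat index within the reduced block. The helper name argmax_over_axes and the test values are illustrative only, not part of the change.

import jax.numpy as jnp

def argmax_over_axes(x, axes):
    # Illustrative helper, not part of the PR: same transpose-then-reshape
    # trick as the dispatched argmax above, written with static Python tuples.
    axes = tuple(int(ax) for ax in axes)
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    # Put the kept (non-reduced) axes first, the reduced axes last
    transposed_x = jnp.transpose(x, keep_axes + axes)
    kept_shape = transposed_x.shape[: len(keep_axes)]
    # Collapse all reduced axes into one trailing axis
    reshaped_x = transposed_x.reshape(*kept_shape, -1)
    # A single argmax over the collapsed axis gives the flat index
    return jnp.argmax(reshaped_x, axis=-1)

x = jnp.arange(24).reshape(2, 3, 4)
# Reducing over a single axis agrees with jnp.argmax on that axis
assert jnp.array_equal(argmax_over_axes(x, (2,)), jnp.argmax(x, axis=2))
# Reducing over several axes agrees with argmax over the flattened block
assert jnp.array_equal(argmax_over_axes(x, (1, 2)), jnp.argmax(x.reshape(2, -1), axis=1))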