Commit 151f954

Break MaxAndArgmax into TensorMax and Argmax separately

1 parent d34760d commit 151f954
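For context, a minimal sketch (not part of the commit; a hypothetical session assuming this branch is installed) of what the split means for user code. pt.max_and_argmax still returns both results, but each is now produced by its own apply node instead of one two-output MaxAndArgmax node:

import pytensor.tensor as pt

x = pt.matrix("x")
mx, amx = pt.max_and_argmax(x, axis=0)

# Before this commit, mx and amx were the two outputs of a single
# MaxAndArgmax apply node; now each comes from its own single-output op.
print(type(mx.owner.op).__name__)   # TensorMax
print(type(amx.owner.op).__name__)  # Argmax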

6 files changed: +990 −84 lines

pytensor/ifelse.py

Lines changed: 2 additions & 1 deletion

@@ -477,7 +477,8 @@ def cond_make_inplace(fgraph, node):
     Reshape,
     Unbroadcast,
     pt.math.Dot,
-    pt.math.MaxAndArgmax,
+    pt.math.TensorMax,
+    pt.math.Argmax,
     pt.subtensor.Subtensor,
     pt.subtensor.IncSubtensor,
     pt.basic.Alloc,

pytensor/tensor/math.py

Lines changed: 214 additions & 4 deletions

@@ -149,6 +149,8 @@ class MaxAndArgmax(COp):

     def __init__(self, axis):
         assert isinstance(axis, tuple | list)
+        # print(axis)
+        # assert 0
         self.axis = tuple(axis)

     def get_params(self, node):
@@ -343,6 +345,208 @@ def grad(self, inp, grads):
         return (g_x,)


+class TensorMax(COp):
+    """
+    Calculate the max over a given axis or over all axes.
+
+    """
+
+    nin = 2  # tensor, axis
+    nout = 1  # max val
+    E_axis = "invalid axis"
+    params_type = Generic()
+    __props__ = ("axis",)
+    _f16_ok = True
+
+    def __init__(self, axis):
+        assert isinstance(axis, tuple | list)
+        self.axis = tuple(axis)
+
+    def get_params(self, node):
+        return self.axis
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+
+        # Keep the original shapes for axes on which we do not perform the max/argmax.
+        all_axes = set(self.axis)
+        inputs = [x]
+        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
+        outputs = [
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+        ]
+        return Apply(self, inputs, outputs)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if len(node.inputs) == 2:
+            raise ValueError(
+                "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
+            )
+
+    def perform(self, node, inp, outs):
+        x = inp[0]
+        axes = self.axis
+        # max, max_idx = outs
+        (max,) = outs
+        if axes is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)
+        max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype)
+        # # Numpy does not support multiple axes for argmax
+        # # Work around
+        # keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
+        # # Not-reduced axes in front
+        # transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
+        # kept_shape = transposed_x.shape[: len(keep_axes)]
+        # reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+        # # Otherwise reshape would complain citing float arg
+        # new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
+        # reshaped_x = transposed_x.reshape(new_shape)
+
+        # max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
+
+    def c_code(self, node, name, inp, out, sub):
+        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
+            raise NotImplementedError(
+                "NumPy C-API can compute max only for 1 axis or for all axes."
+            )
+        x = inp[0]
+        axis = sub["params"]
+        # max, argmax = out
+        (max,) = out
+        fail = sub["fail"]
+        ret = """
+        #if PY_MAJOR_VERSION >= 3
+            #ifndef PyInt_AS_LONG
+                #define PyInt_AS_LONG PyLong_AS_LONG
+            #endif
+        #endif
+
+        int axis;
+
+        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
+            axis = NPY_MAXDIMS;
+        } else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
+            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
+            axis = (int)PyInt_AS_LONG(axis_object);
+            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
+                PyErr_SetString(PyExc_ValueError,
+                "TensorMax: bad axis argument");
+                %(fail)s
+            }
+        } else {
+            PyErr_SetString(PyExc_NotImplementedError,
+            "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes.");
+            %(fail)s
+        }
+
+        Py_CLEAR(%(max)s);
+
+        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
+        if (%(max)s == NULL) {
+            %(fail)s;
+        }
+        if (!PyArray_CheckExact(%(max)s)) {
+            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
+            if(%(max)s == NULL){
+                %(fail)s;
+            }
+        }
+        """
+        return ret % locals()
+
+    def c_code_cache_version(self):
+        return (5,)
+
+    def infer_shape(self, fgraph, node, shapes):
+        ishape = shapes[0]
+        rval = tuple(
+            ishape[i]
+            for (i, b) in enumerate(node.inputs[0].type.broadcastable)
+            if i not in self.axis
+        )
+        return [rval]
+
+    def R_op(self, inputs, eval_points):
+        if eval_points[0] is None:
+            return [None, None]
+
+        if len(self.axis) != 1:
+            raise ValueError("R_op supported for arg_max only for one axis!")
+        if self.axis[0] > 1:
+            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+        if inputs[0].ndim != 2:
+            raise ValueError("R_op supported for arg_max only when input is a matrix")
+        # max_vals, max_pos = self.make_node(*inputs).outputs
+        # max_vals = self.make_node(*inputs).outputs
+        if self.axis[0] == 0:
+            return [eval_points[0][arange(eval_points[0].shape[1])], None]
+        else:
+            return [eval_points[0][arange(eval_points[0].shape[0])], None]
+
+    def grad(self, inp, grads):
+        # The strict sense mathematical gradient of the maximum function is
+        # not calculated here for it is not defined at every point where some
+        # coordinates are identical. However, since the latter set has null
+        # Lebesgue measure, the result may be interpreted as weak gradient.
+
+        # @note: This function should work correctly for L{vector}s.
+        # (x, y), (gz, gw)
+        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
+        # gMax * dMax/dx + gArgMax * dArgMax/dx,
+        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
+        # g_max has one less dimension than x, so you need to complete
+        # g_max to x's shape when axis=0 the broadcasting mechanism
+        # does it automatically
+        x = inp[0]
+        axis = as_tensor_variable(self.axis)
+        # g_max, g_max_idx = grads
+        (g_max,) = grads
+
+        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+        # g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType)
+
+        # # if the op is totally disconnected, so are its inputs
+        # if g_max_disconnected and g_max_idx_disconnected:
+        #     return [DisconnectedType()(), DisconnectedType()()]
+
+        # if the op is totally disconnected, so are its inputs
+        if g_max_disconnected:
+            return [DisconnectedType()()]
+
+        # if the max is disconnected but the argmax is not,
+        # the gradient on its inputs is zero
+        # if g_max_disconnected:
+        #     return [x.zeros_like()]
+        if NoneConst.equals(axis):
+            axis_ = list(range(x.ndim))
+        else:
+            axis_ = axis
+        xmax = max(x, axis_)
+
+        # Raise the g_max and xmax to the same number of dim as the input.
+        pattern = []
+        out_dim = 0
+        if NoneConst.equals(axis):
+            # We are taking the max/argmax over all dimensions.
+            axis = None
+        for i in range(x.ndim):
+            if axis is None or i in axis.data:
+                pattern.append("x")
+            else:
+                pattern.append(out_dim)
+                out_dim += 1
+        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
+        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+
+        # Set the grad to the correct position.
+        g_x = eq(xmax_pad, x) * g_max_pad
+        return (g_x,)
+
+
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.
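The grad method carried over into TensorMax above implements the usual weak gradient of max: broadcast the reduced maximum back to the input's shape and route g_max to every position that attains it (eq(xmax_pad, x) * g_max_pad). A NumPy sketch of the same computation, for illustration only (not part of the diff):

import numpy as np

x = np.array([[1.0, 3.0, 2.0],
              [4.0, 0.0, 4.0]])
g_max = np.array([1.0, 1.0])  # incoming gradient on max(x, axis=1)

xmax_pad = x.max(axis=1, keepdims=True)  # xmax broadcast back to x's ndim
g_max_pad = g_max[:, None]               # likewise for the output gradient
g_x = (x == xmax_pad) * g_max_pad        # eq(xmax_pad, x) * g_max_pad

print(g_x)
# [[0. 1. 0.]
#  [1. 0. 1.]]  -- on ties, every maximal position receives the gradient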
@@ -357,8 +561,10 @@ class Argmax(COp):
     params_type = ParamsType(c_axis=ps.int64)

     def __init__(self, axis):
-        if axis is not None:
-            axis = tuple(axis)
+        # if axis is not None:
+        #     axis = tuple(axis)
+        assert isinstance(axis, tuple | list)
+        # print(axis)
         self.axis = tuple(axis)

     def get_params(self, node):
@@ -395,6 +601,8 @@ def perform(self, node, inp, outs):
         (max_idx,) = outs
         if axes is None:
             axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)

         # Numpy does not support multiple axes for argmax
         # Work around
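The workaround referenced here (live in Argmax.perform, and left commented out in TensorMax.perform above) reduces a multi-axis argmax to a single-axis one: transpose the kept axes to the front, collapse the reduced axes into one trailing axis, and take argmax over that axis. A standalone NumPy sketch of the same trick:

import numpy as np

def argmax_multi_axis(x, axes):
    # NumPy's argmax handles a single axis only; collapse the reduced
    # axes into one trailing axis and argmax over it.
    keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
    transposed = np.transpose(x, np.concatenate((keep_axes, axes)))
    kept_shape = transposed.shape[: len(keep_axes)]
    reduced_shape = transposed.shape[len(keep_axes):]
    # np.prod returns 1.0 for an empty arg, so cast to int64 for reshape
    new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
    return np.argmax(transposed.reshape(new_shape), axis=-1).astype("int64")

x = np.arange(24).reshape(2, 3, 4)
print(argmax_multi_axis(x, (1, 2)))  # [11 11]: flat index inside each collapsed (3, 4) block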
@@ -477,7 +685,7 @@ def grad(self, inp, grads):


 @_vectorize_node.register(Argmax)
-@_vectorize_node.register(MaxAndArgmax)
+# @_vectorize_node.register(MaxAndArgmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
     batch_ndim = batch_x.type.ndim - core_ndim
@@ -600,7 +808,9 @@ def max_and_argmax(a, axis=None, keepdims=False):
     axis = check_and_normalize_axes(a, axis)
     if len(axis) == 0:
         axis = list(range(a.type.ndim))
-    out, argout = MaxAndArgmax(axis)(a)
+    out = TensorMax(axis)(a)
+    argout = Argmax(axis)(a)
+    # out, argout = MaxAndArgmax(axis)(a)

     if keepdims:
         out = makeKeepDims(a, out, axis)
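A side effect of emitting the two nodes separately (a sketch under the assumption that this branch is installed): when a caller uses only the max output, the Argmax node is unreachable from the compiled outputs and is never computed, which is consistent with retiring the local_max_and_argmax rewrite in the next file:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mx, _ = pt.max_and_argmax(x, axis=0)  # argmax output deliberately unused

# Only ops reachable from the requested outputs are compiled, so no
# Argmax node should appear in the printed graph (inspect to confirm).
f = pytensor.function([x], mx)
pytensor.dprint(f)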

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 25 additions & 24 deletions

@@ -31,33 +31,34 @@

 """

-from pytensor import scalar as ps
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import Alloc, alloc, constant
-from pytensor.tensor.elemwise import CAReduce, DimShuffle
-from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
+from pytensor.tensor.elemwise import DimShuffle
+
+# from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
+from pytensor.tensor.math import Min, TensorMax, neg
 from pytensor.tensor.rewriting.basic import register_uncanonicalize
 from pytensor.tensor.shape import Reshape, reshape
 from pytensor.tensor.subtensor import Subtensor


-@register_uncanonicalize
-@node_rewriter([MaxAndArgmax])
-def local_max_and_argmax(fgraph, node):
-    """
-    If we don't use the argmax, change it to a max only.
-    """
-    if isinstance(node.op, MaxAndArgmax):
-        axis = node.op.axis
-        if len(fgraph.clients[node.outputs[1]]) == 0:
-            new = Max(axis)(node.inputs[0])
-            copy_stack_trace(node.outputs[0], new)
-            return [new, None]
+# @register_uncanonicalize
+# @node_rewriter([MaxAndArgmax])
+# def local_max_and_argmax(fgraph, node):
+#     """
+#     If we don't use the argmax, change it to a max only.
+#     """
+#     if isinstance(node.op, MaxAndArgmax):
+#         axis = node.op.axis
+#         if len(fgraph.clients[node.outputs[1]]) == 0:
+#             new = Max(axis)(node.inputs[0])
+#             copy_stack_trace(node.outputs[0], new)
+#             return [new, None]

-        if len(fgraph.clients[node.outputs[0]]) == 0:
-            new = Argmax(axis)(node.inputs[0])
-            copy_stack_trace(node.outputs[0], new)
-            return [None, new]
+#     if len(fgraph.clients[node.outputs[0]]) == 0:
+#         new = Argmax(axis)(node.inputs[0])
+#         copy_stack_trace(node.outputs[0], new)
+#         return [None, new]


 @register_uncanonicalize
@@ -74,13 +75,13 @@ def local_max_to_min(fgraph, node):
     the interface put only MaxAndArgmax into the graph.

     """
+    # pytensor.dprint(node)
+    # print()
+    # print(node.op == neg)
     if node.op == neg and node.inputs[0].owner:
         max = node.inputs[0]
-        if (
-            max.owner
-            and isinstance(max.owner.op, CAReduce)
-            and max.owner.op.scalar_op == ps.scalar_maximum
-        ):
+        # print(max.owner.op.scalar_op)
+        if max.owner and isinstance(max.owner.op, TensorMax):
             neg_node = max.owner.inputs[0]
             if neg_node.owner and neg_node.owner.op == neg:
                 new = Min(max.owner.op.axis)(neg_node.owner.inputs[0])
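For reference, local_max_to_min rests on the identity min(x) = -max(-x); a quick NumPy check (illustrative only) of the equivalence the rewrite exploits:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))

# The rewrite turns neg(TensorMax(neg(x))) into Min(x); numerically:
assert np.allclose(-np.max(-x, axis=1), np.min(x, axis=1))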
