@@ -343,6 +343,177 @@ def grad(self, inp, grads):
         return (g_x,)
 
 
+class TensorMax(COp):
+    """
+    Calculate the max over a given axis or over all axes.
+
+    """
+
+    nin = 1  # tensor
+    nout = 1  # max val
+    E_axis = "invalid axis"
+    params_type = Generic()
+    __props__ = ("axis",)
+    _f16_ok = True
+
+    def __init__(self, axis):
+        assert isinstance(axis, tuple | list)
+        self.axis = tuple(axis)
+
+    def get_params(self, node):
+        return self.axis
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+
+        # Keep the original shapes for axes on which we do not perform the max.
+        all_axes = set(self.axis)
+        inputs = [x]
+        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
+        outputs = [
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+        ]
+        return Apply(self, inputs, outputs)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if len(node.inputs) == 2:
+            raise ValueError(
+                "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
+            )
+
+    def perform(self, node, inp, outs):
+        x = inp[0]
+        axes = self.axis
+        (max,) = outs
+        if axes is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)
+        max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype)
+
+    def c_code(self, node, name, inp, out, sub):
+        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
+            raise NotImplementedError(
+                "NumPy C-API can compute max only for 1 axis or for all axes."
+            )
+        x = inp[0]
+        axis = sub["params"]
+        (max,) = out
+        fail = sub["fail"]
+        ret = """
+        #if PY_MAJOR_VERSION >= 3
+            #ifndef PyInt_AS_LONG
+                #define PyInt_AS_LONG PyLong_AS_LONG
+            #endif
+        #endif
+
+        int axis;
+
+        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
+            axis = NPY_MAXDIMS;
+        } else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
+            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
+            axis = (int)PyInt_AS_LONG(axis_object);
+            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
+                PyErr_SetString(PyExc_ValueError,
+                "TensorMax: bad axis argument");
+                %(fail)s
+            }
+        } else {
+            PyErr_SetString(PyExc_NotImplementedError,
+            "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes.");
+            %(fail)s
+        }
+
+        Py_CLEAR(%(max)s);
+
+        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
+        if (%(max)s == NULL) {
+            %(fail)s;
+        }
+        if (!PyArray_CheckExact(%(max)s)) {
+            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
+            if(%(max)s == NULL){
+                %(fail)s;
+            }
+        }
+        """
+        return ret % locals()
+
+    def c_code_cache_version(self):
+        return (5,)
+
+    def infer_shape(self, fgraph, node, shapes):
+        ishape = shapes[0]
+        rval = tuple(
+            ishape[i]
+            for (i, b) in enumerate(node.inputs[0].type.broadcastable)
+            if i not in self.axis
+        )
+        return [rval]
+
+    def R_op(self, inputs, eval_points):
+        if eval_points[0] is None:
+            return [None]
+
+        if len(self.axis) != 1:
+            raise ValueError("R_op supported for max only for one axis!")
+        if self.axis[0] > 1:
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
+        if inputs[0].ndim != 2:
+            raise ValueError("R_op supported for max only when input is a matrix")
+        # The R_op of max picks the perturbation entry at the argmax position
+        # along the reduced axis.
+        max_pos = Argmax(self.axis)(*inputs)
+        if self.axis[0] == 0:
+            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]]
+        else:
+            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]]
+
+    def grad(self, inp, grads):
+        # The strict sense mathematical gradient of the maximum function is
+        # not calculated here because it is not defined at every point where
+        # some coordinates are identical. However, since the latter set has
+        # null Lebesgue measure, the result may be interpreted as a weak gradient.
+
+        # g_max has one dimension less than x, so it must be broadcast back
+        # to x's shape; the DimShuffle pattern below does exactly that.
+        x = inp[0]
+        axis = as_tensor_variable(self.axis)
+        (g_max,) = grads
+
+        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+
+        # if the op is totally disconnected, so are its inputs
+        if g_max_disconnected:
+            return [DisconnectedType()()]
+
+        if NoneConst.equals(axis):
+            axis_ = list(range(x.ndim))
+        else:
+            axis_ = axis
+        xmax = max(x, axis_)
+
+        # Raise g_max and xmax to the same number of dimensions as the input.
+        pattern = []
+        out_dim = 0
+        if NoneConst.equals(axis):
+            # We are taking the max over all dimensions.
+            axis = None
+        for i in range(x.ndim):
+            if axis is None or i in axis.data:
+                pattern.append("x")
+            else:
+                pattern.append(out_dim)
+                out_dim += 1
+        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
+        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+
+        # Set the grad to the correct position.
+        g_x = eq(xmax_pad, x) * g_max_pad
+        return (g_x,)
+
+
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.
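For intuition, the weak gradient that `TensorMax.grad` implements can be sketched in plain NumPy. This is an illustrative example, not part of the diff; the helper name `weak_max_grad` is hypothetical:

```python
import numpy as np

def weak_max_grad(x, g_max, axis):
    # Forward max, kept broadcastable against x (the role of xmax_pad).
    xmax = np.max(x, axis=axis, keepdims=True)
    # Broadcast g_max back to x's rank (the role of g_max_pad).
    g_max_pad = np.expand_dims(g_max, axis)
    # eq(xmax_pad, x) * g_max_pad: route the gradient to every position
    # that attains the maximum.
    return (x == xmax) * g_max_pad

x = np.array([[3.0, 1.0], [3.0, 2.0]])
print(weak_max_grad(x, np.ones(2), axis=0))
# [[1. 0.]
#  [1. 1.]]
```

Note that on ties (both 3.0 entries in the first column) every maximal entry receives the full output gradient, which is exactly the behavior of `eq(xmax_pad, x) * g_max_pad` above.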
@@ -357,8 +528,7 @@ class Argmax(COp):
     params_type = ParamsType(c_axis=ps.int64)
 
     def __init__(self, axis):
-        if axis is not None:
-            axis = tuple(axis)
+        assert isinstance(axis, tuple | list)
         self.axis = tuple(axis)
 
     def get_params(self, node):
@@ -395,6 +565,8 @@ def perform(self, node, inp, outs):
         (max_idx,) = outs
         if axes is None:
             axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)
 
         # Numpy does not support multiple axes for argmax
         # Work around
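Since `np.argmax` accepts only a single axis, the workaround referenced here moves the kept axes to the front, flattens all reduced axes into one trailing axis, and takes the argmax over that axis. A self-contained NumPy sketch of the same idea (the helper name `argmax_over_axes` is hypothetical):

```python
import numpy as np

def argmax_over_axes(x, axes):
    # Move the non-reduced axes to the front.
    keep_axes = [i for i in range(x.ndim) if i not in axes]
    transposed = np.transpose(x, keep_axes + list(axes))
    kept_shape = transposed.shape[: len(keep_axes)]
    # Flatten the reduced axes into a single trailing axis. np.prod returns
    # 1.0 for an empty sequence, so cast to int64 to keep reshape happy.
    new_shape = (*kept_shape, np.prod(transposed.shape[len(keep_axes):], dtype="int64"))
    # A single argmax over the flattened axis indexes into the reduced block.
    return np.argmax(transposed.reshape(new_shape), axis=-1).astype("int64")

x = np.arange(24).reshape(2, 3, 4)
print(argmax_over_axes(x, (1, 2)))  # [11 11]: last entry of each 3x4 block
```

The returned indices are offsets into the flattened reduced block, matching what `Argmax.perform` produces.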
@@ -477,7 +649,6 @@ def grad(self, inp, grads):
 
 
 @_vectorize_node.register(Argmax)
-@_vectorize_node.register(MaxAndArgmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
     batch_ndim = batch_x.type.ndim - core_ndim
@@ -600,7 +771,8 @@ def max_and_argmax(a, axis=None, keepdims=False):
     axis = check_and_normalize_axes(a, axis)
     if len(axis) == 0:
         axis = list(range(a.type.ndim))
-    out, argout = MaxAndArgmax(axis)(a)
+    out = TensorMax(axis)(a)
+    argout = Argmax(axis)(a)
 
     if keepdims:
         out = makeKeepDims(a, out, axis)
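With this change, `max_and_argmax` builds two independent nodes (a `TensorMax` and an `Argmax`) instead of one fused `MaxAndArgmax`. A quick usage sketch, assuming this diff targets pytensor's `tensor/math.py` as the surrounding identifiers suggest:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mx, amx = pt.max_and_argmax(x, axis=0)  # now two separate graph nodes
f = pytensor.function([x], [mx, amx])
print(f(np.array([[1.0, 5.0], [4.0, 2.0]])))
# [array([4., 5.]), array([1, 0])]
```

Because the two outputs no longer come from a single Apply node, graph rewrites can drop whichever of the pair a compiled function does not actually use.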