@@ -3267,10 +3267,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
32673267 raise NotImplementedError ()
32683268
32693269
3270- def aten_max (self : TensorType ) -> TensorType :
3270+ @torch_op ("aten::max" , trace_only = True )
3271+ def aten_max (
3272+ self : TReal , dim_or_other : Union [TReal , INT64 ] = None , keepdim : BOOL = None
3273+ ) -> TReal :
32713274 """max(Tensor self) -> Tensor"""
32723275
3273- raise NotImplementedError ()
3276+ self_rank = op .Size (op .Shape (self ))
3277+ if self_rank == 0 :
3278+ self = op .Reshape (self , op .Constant (value_int = [- 1 ]))
3279+
3280+ output = 1
3281+
3282+ if op .OptionalHasElement (dim_or_other ):
3283+ if isinstance (dim_or_other , int ):
3284+ if not op .OptionalHasElement (keepdim ):
3285+ keepdim = False
3286+ result , indices = _aten_max_with_dim (self , dim_or_other , keepdim )
3287+ output = 2
3288+ else : # dim_or_other is tensor
3289+ result = _aten_max_with_other (self , dim_or_other )
3290+ else :
3291+ result = _aten_max_with_no_dim (self )
3292+
3293+ if self_rank == 0 :
3294+ result = op .Squeeze (result )
3295+
3296+ if output == 2 :
3297+ if self_rank == 0 :
3298+ indices = op .Squeeze (indices ) # type: ignore[has-type]
3299+ return result , indices
3300+ return result
3301+
3302+
3303+ @torch_op ("aten::max" , overload = True )
3304+ def _aten_max_with_no_dim (self : TReal ) -> TReal :
3305+ result = op .ReduceMax (self , keepdims = 0 )
3306+ return result
3307+
3308+
3309+ @torch_op ("aten::max" , overload = True )
3310+ def _aten_max_with_other (self : TReal , other : TReal ) -> TReal :
3311+ result = op .Max (self , other )
3312+ return result
3313+
3314+
3315+ @torch_op ("aten::max" , overload = True )
3316+ # def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
3317+ def _aten_max_with_dim (self : TReal , dim : int , keepdim : bool ):
3318+ dims = op .Reshape (dim , op .Constant (value_int = [- 1 ]))
3319+ result = op .ReduceMax (self , dims , keepdims = keepdim )
3320+ indices = op .ArgMax (self , axis = dim , keepdims = keepdim )
3321+ return result , indices
32743322
32753323
32763324def aten_max_pool1d (
0 commit comments