@@ -2895,17 +2895,27 @@ def aten_kthvalue(
     raise NotImplementedError()


+@torch_op("aten::layer_norm", trace_only=True)
 def aten_layer_norm(
-    input: TensorType,
+    input: TReal,
     normalized_shape: Sequence[int],
-    weight: Optional[TensorType] = None,
-    bias: Optional[TensorType] = None,
+    weight: Optional[TReal] = None,
+    bias: Optional[TReal] = None,
     eps: float = 1e-05,
-    cudnn_enable: bool = True,
-) -> TensorType:
+) -> TReal:
     """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""

-    raise NotImplementedError()
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    start_axis = axes_list[0]
+    if not op.OptionalHasElement(weight):
+        one = op.Constant(value_float=1.0)
+        weight = op.Expand(one, op.Shape(input, start=start_axis))
+    if not op.OptionalHasElement(bias):
+        zero = op.Constant(value_float=0.0)
+        bias = op.Expand(zero, op.Shape(input, start=start_axis))
+
+    result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
+    return result


 def aten_lcm(self: TensorType, other: TensorType) -> TensorType:
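A quick check of the axis arithmetic in the new implementation, as an editorial sketch (plain Python plus torch; not part of the diff): for a trailing normalized_shape, axes_list enumerates the last len(normalized_shape) dimensions, and start_axis is the first of them, which is the axis ONNX LayerNormalization expects.

import torch

x = torch.randn(2, 3, 4, 5)
normalized_shape = [4, 5]
axes_list = [-i for i in range(len(normalized_shape), 0, -1)]  # [-2, -1]
start_axis = axes_list[0]  # -2: normalize over the last two dimensions
# LayerNormalization with axis=-2 should match the eager result:
expected = torch.nn.functional.layer_norm(x, normalized_shape)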
@@ -3259,10 +3269,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
     raise NotImplementedError()


-def aten_max(self: TensorType) -> TensorType:
+@torch_op("aten::max", trace_only=True)
+def aten_max(
+    self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None
+) -> TReal:
     """max(Tensor self) -> Tensor"""

-    raise NotImplementedError()
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+
+    output = 1
+
+    if op.OptionalHasElement(dim_or_other):
+        if isinstance(dim_or_other, int):
+            if not op.OptionalHasElement(keepdim):
+                keepdim = False
+            result, indices = _aten_max_with_dim(self, dim_or_other, keepdim)
+            output = 2
+        else:  # dim_or_other is a tensor
+            result = _aten_max_with_other(self, dim_or_other)
+    else:
+        result = _aten_max_with_no_dim(self)
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    if output == 2:
+        if self_rank == 0:
+            indices = op.Squeeze(indices)  # type: ignore[has-type]
+        return result, indices
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_no_dim(self: TReal) -> TReal:
+    result = op.ReduceMax(self, keepdims=0)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_other(self: TReal, other: TReal) -> TReal:
+    result = op.Max(self, other)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
+def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool):
+    dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+    result = op.ReduceMax(self, dims, keepdims=keepdim)
+    indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    return result, indices

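The dispatcher above folds three ATen max overloads into one trace-only entry point. A hedged usage sketch (assumes torch; not part of the diff) of the call each branch corresponds to:

import torch

t = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
torch.max(t)                       # no dim: full reduction -> _aten_max_with_no_dim
torch.max(t, torch.zeros_like(t))  # tensor argument: elementwise -> _aten_max_with_other
torch.max(t, dim=1)                # int dim: (values, indices) -> _aten_max_with_dim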
 def aten_max_pool1d(
@@ -3920,12 +3978,13 @@ def aten_native_layer_norm(
     # where D is the dimension of normalized_shape. For example, if normalized_shape is
     # (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
     # over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
-    axes = [-i for i in range(len(normalized_shape), 0, -1)]
-    if weight is None:
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    axes = op.Constant(value_ints=axes_list)
+    if not op.OptionalHasElement(weight):
         weight = op.CastLike(1, input)
-    if bias is None:
+    if not op.OptionalHasElement(bias):
         bias = op.CastLike(0, input)
-    return _aten_native_layer_norm_onnx(input, weight, bias, axes=axes, eps=eps)
+    return _aten_native_layer_norm_onnx(input, weight, bias, axes, eps)

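For reference, a sketch (assuming torch; not part of the diff) of the contract this wrapper implements: aten::native_layer_norm returns an (output, mean, rstd) triple, and a missing weight or bias behaves like ones or zeros, which is what the CastLike defaults above supply.

import torch

x = torch.randn(2, 3, 4)
out, mean, rstd = torch.native_layer_norm(x, (3, 4), None, None, 1e-5)
assert out.shape == (2, 3, 4)
assert mean.shape == rstd.shape == (2, 1, 1)  # normalized dims kept as size 1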
 @torch_op("aten::native_layer_norm", overload=True)
@@ -3938,18 +3997,18 @@ def _aten_native_layer_norm_onnx(
 ) -> Tuple[TReal, TReal, TReal]:

     # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
-    mean = opset17.ReduceMean(input, axes=axes)
-    numerator = opset17.Sub(input, mean)
-    power_num = opset17.Pow(numerator, 2.0)
-    variance = opset17.ReduceMean(power_num, axes=axes)
-    variance_eps = opset17.Add(variance, eps)
-    denominator = opset17.Sqrt(variance_eps)
-    result = opset17.Div(numerator, denominator)
-    weight = opset17.CastLike(weight, result)
-    result = opset17.Mul(result, weight)
-    bias = opset17.CastLike(bias, result)
-    result = opset17.Add(result, bias)
-    rdenominator = opset17.Reciprocal(denominator)
+    mean = op.ReduceMean(input, axes)
+    numerator = op.Sub(input, mean)
+    power_num = op.Pow(numerator, 2.0)
+    variance = op.ReduceMean(power_num, axes)
+    variance_eps = op.Add(variance, eps)
+    denominator = op.Sqrt(variance_eps)
+    result = op.Div(numerator, denominator)
+    weight = op.CastLike(weight, result)
+    result = op.Mul(result, weight)
+    bias = op.CastLike(bias, result)
+    result = op.Add(result, bias)
+    rdenominator = op.Reciprocal(denominator)
     return result, mean, rdenominator

@@ -5055,20 +5114,10 @@ def aten_square(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::squeeze", trace_only=True)
-def aten_squeeze(self: TTensor, dim: Optional[int] = None) -> TTensor:
+def aten_squeeze(self: TensorType) -> TensorType:
     """squeeze(Tensor(a) self) -> Tensor(a)"""

-    if op.OptionalHasElement(dim):
-        rank = op.Size(op.Shape(self))
-        if rank == 0:
-            self = op.Reshape(self, op.Constant(value_ints=[-1]))
-        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
-        result = op.Squeeze(self, dims)
-    else:
-        result = op.Squeeze(self)
-
-    return result
+    raise NotImplementedError()


 def aten_squeeze_copy(self: TensorType) -> TensorType: