@@ -2893,17 +2893,27 @@ def aten_kthvalue(
28932893 raise NotImplementedError ()
28942894
28952895
@torch_op("aten::layer_norm", trace_only=True)
def aten_layer_norm(
    input: TReal,
    normalized_shape: Sequence[int],
    weight: Optional[TReal] = None,
    bias: Optional[TReal] = None,
    eps: float = 1e-05,
    cudnn_enable: bool = True,  # kept to match the ATen schema; has no ONNX equivalent
) -> TReal:
    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
    # trace_only: normalized_shape is a Python list evaluated at trace time, so the
    # normalization starts at the axis counted from the end of the input shape.
    start_axis = -len(normalized_shape)

    # In trace-only mode an omitted optional input is a Python None, so test with
    # `is None` (op.OptionalHasElement would emit a graph node, not a Python bool).
    if weight is None:
        # Default scale: ones broadcast over the normalized dimensions.
        one = op.Constant(value_float=1.0)
        weight = op.Expand(one, op.Shape(input, start=start_axis))
    if bias is None:
        # Default shift: zeros broadcast over the normalized dimensions.
        zero = op.Constant(value_float=0.0)
        bias = op.Expand(zero, op.Shape(input, start=start_axis))

    # LayerNormalization also returns mean and inv-std-dev; aten::layer_norm
    # only exposes the normalized output.
    result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
    return result
29072917
29082918
29092919def aten_lcm (self : TensorType , other : TensorType ) -> TensorType :
@@ -3966,12 +3976,13 @@ def aten_native_layer_norm(
39663976 # where D is the dimension of normalized_shape. For example, if normalized_shape is
39673977 # (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
39683978 # over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
3969- axes = [- i for i in range (len (normalized_shape ), 0 , - 1 )]
3970- if weight is None :
3979+ axes_list = [- i for i in range (len (normalized_shape ), 0 , - 1 )]
3980+ axes = op .Constant (value_ints = axes_list )
3981+ if not op .OptionalHasElement (weight ):
39713982 weight = op .CastLike (1 , input )
3972- if bias is None :
3983+ if not op . OptionalHasElement ( bias ) :
39733984 bias = op .CastLike (0 , input )
3974- return _aten_native_layer_norm_onnx (input , weight , bias , axes = axes , eps = eps )
3985+ return _aten_native_layer_norm_onnx (input , weight , bias , axes , eps )
39753986
39763987
39773988@torch_op ("aten::native_layer_norm" , overload = True )
@@ -3984,18 +3995,18 @@ def _aten_native_layer_norm_onnx(
39843995) -> Tuple [TReal , TReal , TReal ]:
39853996
39863997 # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
3987- mean = opset17 .ReduceMean (input , axes = axes )
3988- numerator = opset17 .Sub (input , mean )
3989- power_num = opset17 .Pow (numerator , 2.0 )
3990- variance = opset17 .ReduceMean (power_num , axes = axes )
3991- variance_eps = opset17 .Add (variance , eps )
3992- denominator = opset17 .Sqrt (variance_eps )
3993- result = opset17 .Div (numerator , denominator )
3994- weight = opset17 .CastLike (weight , result )
3995- result = opset17 .Mul (result , weight )
3996- bias = opset17 .CastLike (bias , result )
3997- result = opset17 .Add (result , bias )
3998- rdenominator = opset17 .Reciprocal (denominator )
3998+ mean = op .ReduceMean (input , axes )
3999+ numerator = op .Sub (input , mean )
4000+ power_num = op .Pow (numerator , 2.0 )
4001+ variance = op .ReduceMean (power_num , axes )
4002+ variance_eps = op .Add (variance , eps )
4003+ denominator = op .Sqrt (variance_eps )
4004+ result = op .Div (numerator , denominator )
4005+ weight = op .CastLike (weight , result )
4006+ result = op .Mul (result , weight )
4007+ bias = op .CastLike (bias , result )
4008+ result = op .Add (result , bias )
4009+ rdenominator = op .Reciprocal (denominator )
39994010 return result , mean , rdenominator
40004011
40014012
0 commit comments