|
31 | 31 | UINT32,
|
32 | 32 | UINT64,
|
33 | 33 | graph,
|
| 34 | + ir, |
34 | 35 | )
|
35 | 36 | from onnxscript.function_libs.torch_lib.ops import common as common_ops
|
36 | 37 | from onnxscript.function_libs.torch_lib.registration import torch_op
|
@@ -4749,28 +4750,10 @@ def aten_layer_norm(
|
4749 | 4750 | start_axis = -len(normalized_shape)
|
4750 | 4751 |
|
4751 | 4752 | if weight is None:
|
4752 |
| - one = op.Constant(value_float=1.0) |
| 4753 | + one = op.Constant(value=ir.tensor(1, dtype=input.dtype)) |
4753 | 4754 | weight = op.Expand(one, op.Shape(input, start=start_axis))
|
4754 | 4755 |
|
4755 |
| - if bias is None: |
4756 |
| - zero = op.Constant(value_float=0.0) |
4757 |
| - bias = op.Expand(zero, op.Shape(input, start=start_axis)) |
4758 |
| - |
4759 |
| - return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps) |
4760 |
| - |
4761 |
| - |
4762 |
| -@torch_op("aten::layer_norm", private=True) |
4763 |
| -def _aten_layer_norm_onnx( |
4764 |
| - input: TReal, |
4765 |
| - weight: TReal, |
4766 |
| - bias: TReal, |
4767 |
| - axis: int, |
4768 |
| - eps: float = 1e-05, |
4769 |
| -) -> TReal: |
4770 |
| - """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor""" |
4771 |
| - |
4772 |
| - # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982 |
4773 |
| - result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps) |
| 4756 | + result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps) |
4774 | 4757 | return result
|
4775 | 4758 |
|
4776 | 4759 |
|
|
0 commit comments